18#ifndef __NNFW_CKER_TRANSPOSE_H__
19#define __NNFW_CKER_TRANSPOSE_H__
34 const T *input_data,
const Shape &unextended_output_shape, T *output_data)
36 const int unextended_output_size = unextended_output_shape.
DimensionsCount();
38 assert(unextended_output_size <= 6);
39 assert(unextended_output_size == params.
perm_count);
40 const Shape input_shape = Shape::ExtendedShape(6, unextended_input_shape);
42 const int input_ext_size = 6 - unextended_input_shape.
DimensionsCount();
43 const int output_ext_size = 6 - unextended_output_size;
46 for (
int i = 0; i < output_ext_size; ++i)
50 for (
int i = 0; i < unextended_output_size; ++i)
52 extended_perm[i + output_ext_size] = params.
perm[i] + input_ext_size;
56 for (
int k = 0; k < 6; k++)
63 for (o[5] = 0; o[5] < out_sizes[5]; o[5]++)
65 i[extended_perm[5]] = o[5];
66 for (o[4] = 0; o[4] < out_sizes[4]; o[4]++)
68 i[extended_perm[4]] = o[4];
69 for (o[3] = 0; o[3] < out_sizes[3]; o[3]++)
71 i[extended_perm[3]] = o[3];
72 for (o[2] = 0; o[2] < out_sizes[2]; o[2]++)
74 i[extended_perm[2]] = o[2];
75 for (o[1] = 0; o[1] < out_sizes[1]; o[1]++)
77 i[extended_perm[1]] = o[1];
78 for (o[0] = 0; o[0] < out_sizes[0]; o[0]++)
80 i[extended_perm[0]] = o[0];
92 const T *input_data,
const Shape &unextended_output_shape, T *output_data)
100 TransposeImpl<int8_t>(params, unextended_input_shape,
101 reinterpret_cast<const int8_t *
>(input_data), unextended_output_shape,
102 reinterpret_cast<int8_t *
>(output_data));
105 TransposeImpl<int16_t>(params, unextended_input_shape,
106 reinterpret_cast<const int16_t *
>(input_data), unextended_output_shape,
107 reinterpret_cast<int16_t *
>(output_data));
111 TransposeImpl<int32_t>(params, unextended_input_shape,
112 reinterpret_cast<const int32_t *
>(input_data), unextended_output_shape,
113 reinterpret_cast<int32_t *
>(output_data));
116 TransposeImpl<int64_t>(params, unextended_input_shape,
117 reinterpret_cast<const int64_t *
>(input_data), unextended_output_shape,
118 reinterpret_cast<int64_t *
>(output_data));
127bool IsTranspose2DApplicable(
const TransposeParams ¶ms,
const Shape &input_shape,
int *dim0,
130 const int dims_cnt = input_shape.DimensionsCount();
134 *dim0 = input_shape.Dims(0);
135 *dim1 = input_shape.Dims(1);
139 const int first_perm = params.perm[0];
140 for (
int i = 1; i < dims_cnt; ++i)
142 int rebased = params.perm[i] - first_perm;
154 for (
int i = 0; i < dims_cnt; ++i)
158 *dim0 *= input_shape.Dims(i);
162 *dim1 *= input_shape.Dims(i);
170 const int dims_cnt = input_shape->DimensionsCount();
171 assert(params->perm_count == dims_cnt);
173 bool foundOneSizeDim =
false;
174 for (
int i = 0; i < dims_cnt; ++i)
176 if (input_shape->Dims(i) == 1)
178 foundOneSizeDim =
true;
184 if (!foundOneSizeDim)
188 if (input_shape->FlatSize() == 1)
190 input_shape->Resize(1);
191 input_shape->SetDim(0, 1);
194 params->perm_count = 1;
200 int new_dims_cnt = 0;
201 for (
int i = 0; i < dims_cnt; ++i)
203 if (input_shape->Dims(i) == 1)
207 input_shape->SetDim(new_dims_cnt, input_shape->Dims(i));
210 input_shape->Resize(new_dims_cnt);
213 TransposeParams new_params;
215 for (
int i = 0; i < dims_cnt; ++i)
221 new_params.perm[new_dims_cnt] = params->perm[i];
226 new_params.perm_count = new_dims_cnt;
228 for (
int i = 0; i < new_dims_cnt; ++i)
230 int min_val_idx = -1;
231 for (
int j = 0; j < new_dims_cnt; ++j)
233 if (new_params.perm[j] >= i &&
234 (min_val_idx == -1 || new_params.perm[min_val_idx] > new_params.perm[j]))
239 new_params.perm[min_val_idx] = i;
241 *params = new_params;
245 Shape *non_flatten_input_shape,
Shape *non_flatten_output_shape,
246 TransposeParams *non_flatten_params)
249 int skip_dims_cnt = 0;
250 size_t flat_size = input_shape.FlatSize();
251 for (
int i = 0; i < params.perm_count; ++i)
253 if (params.perm[i] == i)
255 flat_size /= input_shape.Dims(i);
265 const int new_dims_cnt = params.perm_count - skip_dims_cnt;
266 non_flatten_input_shape->Resize(new_dims_cnt);
267 non_flatten_output_shape->Resize(new_dims_cnt);
268 non_flatten_params->perm_count = new_dims_cnt;
270 for (
int i = skip_dims_cnt; i < params.perm_count; ++i)
272 non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i));
273 non_flatten_output_shape->SetDim(i - skip_dims_cnt,
output_shape.Dims(i));
274 non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i];
276 for (
int i = 0; i < new_dims_cnt; ++i)
278 int min_val_idx = -1;
279 for (
int j = 0; j < new_dims_cnt; ++j)
281 if (non_flatten_params->perm[j] >= i &&
282 (min_val_idx == -1 ||
283 non_flatten_params->perm[min_val_idx] > non_flatten_params->perm[j]))
288 non_flatten_params->perm[min_val_idx] = i;
306 const int d0 = input_shape.
DimsData()[0];
307 const int d1 = input_shape.
DimsData()[1];
308 const int kLines = 4;
309 const int kSkipSize = (kLines - 1) * d1;
311 const T *input = input_data;
314 for (; i <= d0 - kLines; i += kLines)
316 T *output = output_data + i;
318 const T *input_ptr = input;
328 for (; j <= d1 - kLines; j += kLines)
331 const T a00 = input_ptr[0];
332 const T a01 = input_ptr[1];
333 const T a02 = input_ptr[2];
334 const T a03 = input_ptr[3];
336 const T a10 = input_ptr[0];
337 const T a11 = input_ptr[1];
338 const T a12 = input_ptr[2];
339 const T a13 = input_ptr[3];
341 const T a20 = input_ptr[0];
342 const T a21 = input_ptr[1];
343 const T a22 = input_ptr[2];
344 const T a23 = input_ptr[3];
346 const T a30 = input_ptr[0];
347 const T a31 = input_ptr[1];
348 const T a32 = input_ptr[2];
349 const T a33 = input_ptr[3];
383 for (
int p = 0;
p < kLines; ++
p)
385 for (
int q = 0; q < d1 - j; ++q)
387 *(output + q * d0 +
p) = *(input +
p * d1 + q);
390 input += (d1 - j) + kSkipSize;
395 T *output = output_data + i;
396 for (
int j = 0; j < d1; ++j)
409 const T *input_data,
const Shape &, T *output_data)
412 s2 = input_shape.
Dims(1);
413 s3 = input_shape.
Dims(2);
419 if (params.
perm[0] == 2)
423 else if (params.
perm[1] == 2)
432 if (params.
perm[0] == 1)
436 else if (params.
perm[1] == 1)
445 if (params.
perm[0] == 0)
449 else if (params.
perm[1] == 0)
459 o_s[0] = input_shape.
Dims(params.
perm[0]);
460 o_s[1] = input_shape.
Dims(params.
perm[1]);
461 o_s[2] = input_shape.
Dims(params.
perm[2]);
463 for (
int i1 = 0; i1 < o_s[0]; ++i1)
465 for (
int i2 = 0; i2 < o_s[1]; ++i2)
467 for (
int i3 = 0; i3 < o_s[2]; ++i3)
469 const int i = i1 * p1 + i2 * p2 + i3 * p3;
470 const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
471 output_data[o] = input_data[i];
484 if (IsTranspose2DApplicable(params, input_shape, &dim0, &dim1))
513 const T *input_data,
const Shape &unshrunk_output_shape, T *output_data)
517 assert(output_size <= 6);
518 assert(output_size == unshrunk_params.
perm_count);
520 Shape shrunk_input_shape =
Shape(unshrunk_input_shape);
522 Shape shrunk_output_shape =
Shape(unshrunk_output_shape);
528 RemoveOneSizeDimensions(&shrunk_input_shape, &shrunk_output_shape, &shrunk_params);
533 bool identical =
true;
534 for (
int i = 0; i < shrunk_params.
perm_count; ++i)
537 if (shrunk_params.
perm[i] != i)
546 memcpy(output_data, input_data, unshrunk_input_shape.
FlatSize() *
sizeof(T));
551 if (shrunk_params.
perm[0] == 0 && output_size >= 3)
554 Shape non_flatten_input_shape;
555 Shape non_flatten_output_shape;
557 const int total_size = shrunk_input_shape.
FlatSize();
559 const int non_flatten_size =
560 Flatten(shrunk_input_shape, shrunk_output_shape, shrunk_params,
562 &non_flatten_input_shape, &non_flatten_output_shape, &non_flatten_params);
563 assert(non_flatten_params.
perm[0] != 0);
565 for (
int i = 0; i < total_size; i += non_flatten_size)
567 TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
568 non_flatten_output_shape, output_data + i);
574 TransposeImpl(shrunk_params, shrunk_input_shape, input_data, shrunk_output_shape,
int32_t DimensionsCount() const
int32_t Dims(int i) const
const luci_interpreter::RuntimeShape output_shape
void Transpose(const TransposeParams &params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data)
void TransposeImpl(const TransposeParams &params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data)
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
void TransposeImpl(const TransposeParams &params, const Shape &input_shape, const T *input_data, const Shape &output_shape, T *output_data)
void Transpose3D(const TransposeParams &params, const Shape &input_shape, const T *input_data, const Shape &, T *output_data)
void Transpose2D(const Shape &input_shape, const T *input_data, const Shape &output_shape, T *output_data)
void Transpose(const TransposeParams &unshrunk_params, const Shape &unshrunk_input_shape, const T *input_data, const Shape &unshrunk_output_shape, T *output_data)
void optimized_ops_preload_l1_keep(const T *ptr)