18#ifndef __NNFW_CKER_TRANSPOSE_H__
19#define __NNFW_CKER_TRANSPOSE_H__
34 const T *input_data,
const Shape &unextended_output_shape, T *output_data)
36 const int unextended_output_size = unextended_output_shape.
DimensionsCount();
38 assert(unextended_output_size <= 4);
39 assert(unextended_output_size == params.
perm_count);
40 const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
42 const int input_ext_size = 4 - unextended_input_shape.
DimensionsCount();
43 const int output_ext_size = 4 - unextended_output_size;
48 for (
int i = 0; i < output_ext_size; ++i)
52 for (
int i = 0; i < unextended_output_size; ++i)
54 extended_perm[i + output_ext_size] = params.
perm[i] + input_ext_size;
60 for (
int k = 0; k < 4; k++)
68 for (o[3] = 0; o[3] < out_sizes[3]; o[3]++)
70 i[extended_perm[3]] = o[3];
71 for (o[2] = 0; o[2] < out_sizes[2]; o[2]++)
73 i[extended_perm[2]] = o[2];
74 for (o[1] = 0; o[1] < out_sizes[1]; o[1]++)
76 i[extended_perm[1]] = o[1];
77 for (o[0] = 0; o[0] < out_sizes[0]; o[0]++)
79 i[extended_perm[0]] = o[0];
89 const T *input_data,
const Shape &unextended_output_shape, T *output_data)
97 TransposeImpl<int8_t>(params, unextended_input_shape,
98 reinterpret_cast<const int8_t *
>(input_data), unextended_output_shape,
99 reinterpret_cast<int8_t *
>(output_data));
102 TransposeImpl<int16_t>(params, unextended_input_shape,
103 reinterpret_cast<const int16_t *
>(input_data), unextended_output_shape,
104 reinterpret_cast<int16_t *
>(output_data));
108 TransposeImpl<int32_t>(params, unextended_input_shape,
109 reinterpret_cast<const int32_t *
>(input_data), unextended_output_shape,
110 reinterpret_cast<int32_t *
>(output_data));
113 TransposeImpl<int64_t>(params, unextended_input_shape,
114 reinterpret_cast<const int64_t *
>(input_data), unextended_output_shape,
115 reinterpret_cast<int64_t *
>(output_data));
124bool IsTranspose2DApplicable(
const TransposeParams &params,
const Shape &input_shape,
int *dim0,
127 const int dims_cnt = input_shape.DimensionsCount();
131 *dim0 = input_shape.Dims(0);
132 *dim1 = input_shape.Dims(1);
136 const int first_perm = params.perm[0];
137 for (
int i = 1; i < dims_cnt; ++i)
139 int rebased = params.perm[i] - first_perm;
151 for (
int i = 0; i < dims_cnt; ++i)
155 *dim0 *= input_shape.Dims(i);
159 *dim1 *= input_shape.Dims(i);
167 const int dims_cnt = input_shape->DimensionsCount();
168 assert(params->perm_count == dims_cnt);
170 bool foundOneSizeDim =
false;
171 for (
int i = 0; i < dims_cnt; ++i)
173 if (input_shape->Dims(i) == 1)
175 foundOneSizeDim =
true;
181 if (!foundOneSizeDim)
185 if (input_shape->FlatSize() == 1)
187 input_shape->Resize(1);
188 input_shape->SetDim(0, 1);
191 params->perm_count = 1;
197 int new_dims_cnt = 0;
198 for (
int i = 0; i < dims_cnt; ++i)
200 if (input_shape->Dims(i) == 1)
204 input_shape->SetDim(new_dims_cnt, input_shape->Dims(i));
207 input_shape->Resize(new_dims_cnt);
210 TransposeParams new_params;
212 for (
int i = 0; i < dims_cnt; ++i)
218 new_params.perm[new_dims_cnt] = params->perm[i];
223 new_params.perm_count = new_dims_cnt;
225 for (
int i = 0; i < new_dims_cnt; ++i)
227 int min_val_idx = -1;
228 for (
int j = 0; j < new_dims_cnt; ++j)
230 if (new_params.perm[j] >= i &&
231 (min_val_idx == -1 || new_params.perm[min_val_idx] > new_params.perm[j]))
236 new_params.perm[min_val_idx] = i;
238 *params = new_params;
242 Shape *non_flatten_input_shape,
Shape *non_flatten_output_shape,
243 TransposeParams *non_flatten_params)
246 int skip_dims_cnt = 0;
247 size_t flat_size = input_shape.FlatSize();
248 for (
int i = 0; i < params.perm_count; ++i)
250 if (params.perm[i] == i)
252 flat_size /= input_shape.Dims(i);
262 const int new_dims_cnt = params.perm_count - skip_dims_cnt;
263 non_flatten_input_shape->Resize(new_dims_cnt);
264 non_flatten_output_shape->Resize(new_dims_cnt);
265 non_flatten_params->perm_count = new_dims_cnt;
267 for (
int i = skip_dims_cnt; i < params.perm_count; ++i)
269 non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i));
270 non_flatten_output_shape->SetDim(i - skip_dims_cnt,
output_shape.Dims(i));
271 non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i];
273 for (
int i = 0; i < new_dims_cnt; ++i)
275 int min_val_idx = -1;
276 for (
int j = 0; j < new_dims_cnt; ++j)
278 if (non_flatten_params->perm[j] >= i &&
279 (min_val_idx == -1 ||
280 non_flatten_params->perm[min_val_idx] > non_flatten_params->perm[j]))
285 non_flatten_params->perm[min_val_idx] = i;
303 const int d0 = input_shape.
DimsData()[0];
304 const int d1 = input_shape.
DimsData()[1];
305 const int kLines = 4;
306 const int kSkipSize = (kLines - 1) * d1;
308 const T *input = input_data;
311 for (; i <= d0 - kLines; i += kLines)
313 T *output = output_data + i;
315 const T *input_ptr = input;
325 for (; j <= d1 - kLines; j += kLines)
328 const T a00 = input_ptr[0];
329 const T a01 = input_ptr[1];
330 const T a02 = input_ptr[2];
331 const T a03 = input_ptr[3];
333 const T a10 = input_ptr[0];
334 const T a11 = input_ptr[1];
335 const T a12 = input_ptr[2];
336 const T a13 = input_ptr[3];
338 const T a20 = input_ptr[0];
339 const T a21 = input_ptr[1];
340 const T a22 = input_ptr[2];
341 const T a23 = input_ptr[3];
343 const T a30 = input_ptr[0];
344 const T a31 = input_ptr[1];
345 const T a32 = input_ptr[2];
346 const T a33 = input_ptr[3];
380 for (
int p = 0; p < kLines; ++p)
382 for (
int q = 0; q < d1 - j; ++q)
384 *(output + q * d0 + p) = *(input + p * d1 + q);
387 input += (d1 - j) + kSkipSize;
392 T *output = output_data + i;
393 for (
int j = 0; j < d1; ++j)
406 const T *input_data,
const Shape &, T *output_data)
409 s2 = input_shape.
Dims(1);
410 s3 = input_shape.
Dims(2);
416 if (params.
perm[0] == 2)
420 else if (params.
perm[1] == 2)
429 if (params.
perm[0] == 1)
433 else if (params.
perm[1] == 1)
442 if (params.
perm[0] == 0)
446 else if (params.
perm[1] == 0)
456 o_s[0] = input_shape.
Dims(params.
perm[0]);
457 o_s[1] = input_shape.
Dims(params.
perm[1]);
458 o_s[2] = input_shape.
Dims(params.
perm[2]);
460 for (
int i1 = 0; i1 < o_s[0]; ++i1)
462 for (
int i2 = 0; i2 < o_s[1]; ++i2)
464 for (
int i3 = 0; i3 < o_s[2]; ++i3)
466 const int i = i1 * p1 + i2 * p2 + i3 * p3;
467 const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3;
468 output_data[o] = input_data[i];
481 if (IsTranspose2DApplicable(params, input_shape, &dim0, &dim1))
510 const T *input_data,
const Shape &unshrunk_output_shape, T *output_data)
514 assert(output_size <= 4);
515 assert(output_size == unshrunk_params.
perm_count);
517 Shape shrunk_input_shape =
Shape(unshrunk_input_shape);
519 Shape shrunk_output_shape =
Shape(unshrunk_output_shape);
525 RemoveOneSizeDimensions(&shrunk_input_shape, &shrunk_output_shape, &shrunk_params);
530 bool identical =
true;
531 for (
int i = 0; i < shrunk_params.
perm_count; ++i)
534 if (shrunk_params.
perm[i] != i)
543 memcpy(output_data, input_data, unshrunk_input_shape.
FlatSize() *
sizeof(T));
548 if (shrunk_params.
perm[0] == 0 && output_size >= 3)
551 Shape non_flatten_input_shape;
552 Shape non_flatten_output_shape;
554 const int total_size = shrunk_input_shape.
FlatSize();
556 const int non_flatten_size =
557 Flatten(shrunk_input_shape, shrunk_output_shape, shrunk_params,
559 &non_flatten_input_shape, &non_flatten_output_shape, &non_flatten_params);
560 assert(non_flatten_params.
perm[0] != 0);
562 for (
int i = 0; i < total_size; i += non_flatten_size)
564 TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
565 non_flatten_output_shape, output_data + i);
571 TransposeImpl(shrunk_params, shrunk_input_shape, input_data, shrunk_output_shape,
int32_t DimensionsCount() const
int32_t Dims(int i) const
const luci_interpreter::RuntimeShape output_shape
void Transpose(const TransposeParams &params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data)
void TransposeImpl(const TransposeParams &params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_output_shape, T *output_data)
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
void TransposeImpl(const TransposeParams &params, const Shape &input_shape, const T *input_data, const Shape &output_shape, T *output_data)
void Transpose3D(const TransposeParams &params, const Shape &input_shape, const T *input_data, const Shape &, T *output_data)
void Transpose2D(const Shape &input_shape, const T *input_data, const Shape &output_shape, T *output_data)
void Transpose(const TransposeParams &unshrunk_params, const Shape &unshrunk_input_shape, const T *input_data, const Shape &unshrunk_output_shape, T *output_data)
void optimized_ops_preload_l1_keep(const T *ptr)