ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Transpose.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "kernels/Transpose.h"
18
19#include "kernels/Utils.h"
20
21#include <tensorflow/lite/kernels/internal/reference/transpose.h>
22
23#include <stdexcept>
24
25namespace luci_interpreter
26{
27
28namespace kernels
29{
30
31Transpose::Transpose(const Tensor *input, const Tensor *perm, Tensor *output)
32 : Kernel({input, perm}, {output})
33{
34}
35
37{
38 // Transpose op only supports 1D-4D input arrays.
39 int dims = input()->shape().num_dims();
40 const int32_t *perm_data = getTensorData<int32_t>(perm());
41
42 assert(input()->shape().num_dims() <= 4);
43 assert(input()->element_type() == output()->element_type());
44
45 assert(perm()->shape().num_dims() == 1);
46 assert(perm()->shape().dim(0) == dims);
47
48 Shape output_shape(dims);
49 for (int i = 0; i < dims; i++)
50 {
51 assert(perm_data[i] < dims && perm_data[i] >= 0);
52 output_shape.dim(i) = input()->shape().dim(perm_data[i]);
53 }
54
56}
57
59{
60 tflite::TransposeParams params{};
61 const int32_t *perm_data = getTensorData<int32_t>(perm());
62 const int32_t size = perm()->shape().dim(0);
63 params.perm_count = size;
64 for (int i = 0; i < size; i++)
65 params.perm[i] = perm_data[i];
66 switch (input()->element_type())
67 {
68 case DataType::FLOAT32:
69 tflite::reference_ops::Transpose(params, getTensorShape(input()),
70 getTensorData<float>(input()), getTensorShape(output()),
71 getTensorData<float>(output()));
72 break;
73 case DataType::S64:
74 tflite::reference_ops::Transpose(params, getTensorShape(input()),
75 getTensorData<int64_t>(input()), getTensorShape(output()),
76 getTensorData<int64_t>(output()));
77 break;
78 case DataType::U8:
79 tflite::reference_ops::Transpose(params, getTensorShape(input()),
80 getTensorData<uint8_t>(input()), getTensorShape(output()),
81 getTensorData<uint8_t>(output()));
82 break;
83 default:
84 throw std::runtime_error("luci-intp Transpose Unsupported type.");
85 }
86}
87
88} // namespace kernels
89} // namespace luci_interpreter
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const Tensor * input() const
Definition Transpose.h:33
Transpose(const Tensor *input, const Tensor *perm, Tensor *output)
Definition Transpose.cpp:31
const Tensor * perm() const
Definition Transpose.h:34
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
int32_t size[5]
Definition Slice.cpp:35