ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Concatenation.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "Builders.h"
19#include "kernels/Utils.h"
20
21#include "PALConcatenation.h"
22
23namespace luci_interpreter
24{
25
26namespace
27{
28
29template <typename T>
30void evalGeneric(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
31{
32 const auto output_index = cur_op->outputs()->operator[](0);
33
34 assert(output_index != -1);
35
36 auto output = runtime_graph->getCircleTensorByIndex(output_index);
37
38 const auto *options = cur_op->builtin_options_as_ConcatenationOptions();
39
40 int axis = options->axis();
41 if (axis < 0)
42 axis += Tensor::num_dims(output);
43
44 const auto input_sizes = cur_op->inputs()->size();
45
46 std::vector<const T *> all_input_data;
47 std::vector<luci_interpreter::RuntimeShape> all_shape;
48 std::vector<luci_interpreter::RuntimeShape *> all_shape_ptr;
49
50 for (int32_t i = 0; i < input_sizes; ++i)
51 {
52 auto input_index = cur_op->inputs()->operator[](i);
53 const auto *tensor = runtime_graph->getCircleTensorByIndex(input_index);
54
55 const auto *tensor_data = runtime_graph->getDataByTensor(tensor);
56 if (tensor_data == nullptr)
57 tensor_data = runtime_graph->getConstDataByTensor(tensor);
58
59 auto *data = reinterpret_cast<const T *>(tensor_data);
60
61 auto runtime_shape = kernels::getTensorRuntimeShape(tensor, runtime_graph);
62
63 all_input_data.push_back(data);
64 all_shape.push_back(runtime_shape);
65 }
66
67 for (luci_interpreter::RuntimeShape &shape : all_shape)
68 {
69 all_shape_ptr.push_back(&shape);
70 }
71
72 auto *output_data = reinterpret_cast<T *>(runtime_graph->getDataByTensor(output));
73
75 params.axis = axis;
76 params.inputs_count = all_shape.size();
77 luci_interpreter_pal::Concatenation(params, all_shape_ptr.data(), all_input_data.data(),
78 kernels::getTensorShape(output), output_data);
79}
80
81} // namespace
82
83void configure_kernel_CircleConcatenation(const circle::Operator *cur_op,
84 BaseRuntimeGraph *runtime_graph)
85{
86 const int num_inputs = cur_op->inputs()->size();
87 LUCI_INTERPRETER_CHECK(num_inputs > 0);
88
89 auto input_index = cur_op->inputs()->operator[](0);
90 auto output_index = cur_op->outputs()->operator[](0);
91
92 assert(input_index != -1);
93 assert(output_index != -1);
94
95 const auto *t0 = runtime_graph->getCircleTensorByIndex(input_index);
96 const auto *output = runtime_graph->getCircleTensorByIndex(output_index);
97
98 const auto *params = cur_op->builtin_options_as_ConcatenationOptions();
99
100 // TODO: Support concat with fused activation function
101 LUCI_INTERPRETER_CHECK(luci_actfunc(params->fused_activation_function()) == FusedActFunc::NONE);
102
103 int axis = params->axis();
104 if (axis < 0)
105 axis += Tensor::num_dims(t0);
106 LUCI_INTERPRETER_CHECK(axis >= 0 && axis < Tensor::num_dims(t0));
107
108 for (int i = 1; i < num_inputs; ++i)
109 {
110 input_index = cur_op->inputs()->operator[](i);
111 const auto *tensor = runtime_graph->getCircleTensorByIndex(input_index);
112 LUCI_INTERPRETER_CHECK(Tensor::element_type(tensor) == Tensor::element_type(t0));
113 LUCI_INTERPRETER_CHECK(Tensor::num_dims(tensor) == Tensor::num_dims(t0));
114 }
115
116#ifndef DIS_QUANT
117 // If input tensors are INT8 type then quantization parameters of all input tensors and the output
118 // should be the same
119 for (int i = 1; i < num_inputs; ++i)
120 {
121 input_index = cur_op->inputs()->operator[](i);
122 const auto *tensor = runtime_graph->getCircleTensorByIndex(input_index);
123 if (Tensor::element_type(tensor) == DataType::S8)
124 {
125 LUCI_INTERPRETER_CHECK(Tensor::quantized_dimension(tensor) ==
126 Tensor::quantized_dimension(output));
127
128 LUCI_INTERPRETER_CHECK(Tensor::zero_points(tensor).size() == Tensor::scales(tensor).size());
129 LUCI_INTERPRETER_CHECK(Tensor::zero_points(tensor) == Tensor::zero_points(output));
130 LUCI_INTERPRETER_CHECK(Tensor::scales(tensor) == Tensor::scales(output));
131 }
132 }
133#endif // DIS_QUANT
134}
135
136void execute_kernel_CircleConcatenation(const circle::Operator *cur_op,
137 BaseRuntimeGraph *runtime_graph)
138{
139 int num_inputs = cur_op->inputs()->size();
140 LUCI_INTERPRETER_CHECK(num_inputs > 0);
141
142 const auto input_index = cur_op->inputs()->operator[](0);
143 assert(input_index != -1);
144 const auto *t0 = runtime_graph->getCircleTensorByIndex(input_index);
145
146 switch (Tensor::element_type(t0))
147 {
148#ifndef DIS_FLOAT
149 case DataType::FLOAT32:
150 evalGeneric<float>(cur_op, runtime_graph);
151 break;
152#endif // DIS_FLOAT
153#ifndef DIS_QUANT
154 case DataType::S8:
155 evalGeneric<int8_t>(cur_op, runtime_graph);
156 break;
157#endif // DIS_QUANT
158 case DataType::S32:
159 evalGeneric<int32_t>(cur_op, runtime_graph);
160 break;
161 case DataType::S64:
162 evalGeneric<int64_t>(cur_op, runtime_graph);
163 break;
164 default:
165 assert(false && "Unsupported type.");
166 }
167}
168
169} // namespace luci_interpreter
const circle::Tensor * getCircleTensorByIndex(int32_t index)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const T * data(const std::vector< T, Alloc > &v)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
Definition Utils.cpp:29
void Concatenation(const ConcatenationParams &params, const luci_interpreter::RuntimeShape *const *input_shapes, const Scalar *const *input_data, const luci_interpreter::RuntimeShape &output_shape, Scalar *output_data)
void execute_kernel_CircleConcatenation(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
RuntimeGraph BaseRuntimeGraph
void configure_kernel_CircleConcatenation(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)
int32_t size[5]
Definition Slice.cpp:35