ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Concatenation.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
#include "kernels/Concatenation.h"
#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/concatenation.h>

#include <stdexcept>
#include <utility>
24
25namespace luci_interpreter
26{
27namespace kernels
28{
29
30Concatenation::Concatenation(std::vector<const Tensor *> inputs, Tensor *output,
31 const ConcatenationParams &params)
32 : KernelWithParams<ConcatenationParams>(std::move(inputs), {output}, params)
33{
34}
35
37{
38 const int num_inputs = _inputs.size();
39 LUCI_INTERPRETER_CHECK(num_inputs > 0);
40 const Tensor *t0 = _inputs[0];
41
42 // TODO: Support concat with fused activation function
44
45 int axis = _params.axis;
46 if (axis < 0)
47 axis += t0->shape().num_dims();
48 LUCI_INTERPRETER_CHECK(axis >= 0 && axis < t0->shape().num_dims());
49
50 int32_t sum_axis = t0->shape().dim(axis);
51 for (int i = 1; i < num_inputs; ++i)
52 {
53 const Tensor *tensor = _inputs[i];
54 LUCI_INTERPRETER_CHECK(tensor->element_type() == t0->element_type());
55 LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == t0->shape().num_dims());
56 for (int d = 0; d < t0->shape().num_dims(); ++d)
57 {
58 if (d == axis)
59 {
60 sum_axis += tensor->shape().dim(axis);
61 }
62 else
63 {
64 LUCI_INTERPRETER_CHECK(tensor->shape().dim(d) == t0->shape().dim(d));
65 }
66 }
67 }
68
69 Shape output_shape = t0->shape();
70 output_shape.dim(axis) = sum_axis;
71
72 // If input tensors are INT8 type then quantization parameters of all input tensors and the output
73 // should be the same
74 for (auto current_tensor : _inputs)
75 {
76 if (current_tensor->element_type() == DataType::S8)
77 {
78 LUCI_INTERPRETER_CHECK(current_tensor->quantized_dimension() ==
79 output()->quantized_dimension());
80
81 LUCI_INTERPRETER_CHECK(current_tensor->zero_points().size() ==
82 current_tensor->scales().size());
83 LUCI_INTERPRETER_CHECK(current_tensor->zero_points() == output()->zero_points());
84 LUCI_INTERPRETER_CHECK(current_tensor->scales() == output()->scales());
85 }
86 }
88}
89
91{
92 switch (_inputs[0]->element_type())
93 {
94 case DataType::FLOAT32:
95 evalGeneric<float>();
96 break;
97 case DataType::U8:
98 evalQuantized();
99 break;
100 case DataType::S8:
101 evalGeneric<int8_t>();
102 break;
103 case DataType::S32:
104 evalGeneric<int32_t>();
105 break;
106 case DataType::S64:
107 evalGeneric<int64_t>();
108 break;
109 default:
110 throw std::runtime_error("luci-intp Concatenation Unsupported type.");
111 }
112}
113
114template <typename T> void Concatenation::evalGeneric() const
115{
116 int axis = _params.axis;
117 if (axis < 0)
118 axis += output()->shape().num_dims();
119
121 tflite::ConcatenationParams params{};
122 params.axis = axis;
123 params.inputs_count = _inputs.size();
124 tflite::reference_ops::Concatenation(params, inputs.shapes(), inputs.data(),
125 getTensorShape(output()), getTensorData<T>(output()));
126}
127
128void Concatenation::evalQuantized() const
129{
130 int axis = _params.axis;
131 if (axis < 0)
132 axis += output()->shape().num_dims();
133
134 VectorOfQuantizedTensors<true> inputs(_inputs);
135 tflite::ConcatenationParams params{};
136 params.axis = axis;
137 params.input_zeropoint = inputs.zero_point();
138 params.input_scale = inputs.scale();
139 params.inputs_count = _inputs.size();
140 params.output_zeropoint = output()->zero_point();
141 params.output_scale = output()->scale();
142
143 tflite::reference_ops::ConcatenationWithScaling(params, inputs.shapes(), inputs.data(),
145 getTensorData<uint8_t>(output()));
146}
147
148} // namespace kernels
149} // namespace luci_interpreter
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52
const ConcatenationParams & params() const
Definition Kernel.h:67
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
float scale() const
Definition Tensor.h:109
DataType element_type() const
Definition Tensor.h:105
int32_t zero_point() const
Definition Tensor.h:115
Concatenation(std::vector< const Tensor * > inputs, Tensor *output, const ConcatenationParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194