ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Pack.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/Pack.h"
19#include "kernels/Utils.h"
20
21#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
22
23#include <stdexcept>
24
25namespace luci_interpreter
26{
27namespace kernels
28{
29
30Pack::Pack(std::vector<const Tensor *> inputs, Tensor *output, const PackParams &params)
31 : KernelWithParams<PackParams>(std::move(inputs), {output}, params)
32{
33}
34
36{
37 LUCI_INTERPRETER_CHECK(_inputs.size() == static_cast<uint32_t>(params().values_count));
38 const Tensor *t0 = _inputs[0];
39 const int dimension_size = t0->shape().num_dims() + 1;
40 int axis = params().axis;
41 if (axis < 0)
42 {
43 axis += dimension_size;
44 }
45 LUCI_INTERPRETER_CHECK(axis >= 0 && axis <= t0->shape().num_dims());
46
47 if (t0->element_type() != DataType::S32 && t0->element_type() != DataType::FLOAT32 &&
48 t0->element_type() != DataType::U8 && t0->element_type() != DataType::S8 &&
49 t0->element_type() != DataType::S16 && t0->element_type() != DataType::S64)
50 {
51 throw std::runtime_error("luci-intp Pack(1) Unsupported type.");
52 }
53
54 for (uint32_t i = 1; i < _inputs.size(); ++i)
55 {
56 const Tensor *tensor = _inputs[i];
57 LUCI_INTERPRETER_CHECK(tensor->element_type() == t0->element_type());
58 LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == t0->shape().num_dims());
59 for (int d = 0; d < t0->shape().num_dims(); ++d)
60 {
61 LUCI_INTERPRETER_CHECK(tensor->shape().dim(d) == t0->shape().dim(d));
62 }
63 }
64
65 Shape output_shape(dimension_size);
66 int i = 0;
67 for (int index = 0; index < dimension_size; ++index)
68 {
69 if (index == axis)
70 {
71 output_shape.dim(index) = params().values_count;
72 }
73 else
74 {
75 output_shape.dim(index) = t0->shape().dim(i++);
76 }
77 }
78
79 if (t0->element_type() == DataType::U8 || t0->element_type() == DataType::S8 ||
80 t0->element_type() == DataType::S16)
81 {
82 LUCI_INTERPRETER_CHECK(output()->zero_point() == t0->zero_point());
83 LUCI_INTERPRETER_CHECK(output()->scale() == t0->scale());
84 // Guarantee input/output quantization params match as we do not support
85 // packing quantized tensors.
86 for (int i = 0; i < params().values_count; i++)
87 {
88 LUCI_INTERPRETER_CHECK(_inputs[i]->zero_point() == t0->zero_point());
89 LUCI_INTERPRETER_CHECK(_inputs[i]->scale() == t0->scale());
90 }
91 }
92
94}
95
96void Pack::execute() const
97{
98 switch (_inputs[0]->element_type())
99 {
100 case DataType::FLOAT32:
101 evalGeneric<float>();
102 break;
103 case DataType::U8:
104 evalGeneric<uint8_t>();
105 break;
106 case DataType::S8:
107 evalGeneric<int8_t>();
108 break;
109 case DataType::S16:
110 evalGeneric<int16_t>();
111 break;
112 case DataType::S32:
113 evalGeneric<int32_t>();
114 break;
115 case DataType::S64:
116 evalGeneric<int64_t>();
117 break;
118 default:
119 throw std::runtime_error("luci-intp Pack(2) Unsupported type.");
120 }
121}
122
123template <typename T> void Pack::evalGeneric() const
124{
125 const Tensor *t0 = _inputs[0];
126 const int dimension_size = t0->shape().num_dims() + 1;
127 int axis = params().axis;
128 if (axis < 0)
129 {
130 axis += dimension_size;
131 }
132
133 VectorOfTensors<T, true> inputs(_inputs);
134 tflite::PackParams params{};
135 params.axis = axis;
136 params.inputs_count = _inputs.size();
137 tflite::reference_ops::Pack<T>(params, inputs.shapes(), inputs.data(), getTensorShape(output()),
138 getTensorData<T>(output()));
139}
140
141} // namespace kernels
142} // namespace luci_interpreter
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
float scale() const
Definition Tensor.h:109
DataType element_type() const
Definition Tensor.h:105
int32_t zero_point() const
Definition Tensor.h:115
Tensor * output() const
Definition Pack.h:34
void configure() override
Definition Pack.cpp:35
Pack(std::vector< const Tensor * > inputs, Tensor *output, const PackParams &params)
Definition Pack.cpp:30
void execute() const override
Definition Pack.cpp:96
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194