ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Gather.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2021 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/Gather.h"
19#include "kernels/Utils.h"
20#include "PALGather.h"
21
22#include <stdexcept>
23#include <cassert>
24
25namespace luci_interpreter
26{
27
28namespace kernels
29{
30
31Gather::Gather(const Tensor *params, const Tensor *indices, Tensor *output,
32 const GatherParams &gparams)
33 : KernelWithParams<GatherParams>({params, indices}, {output}, gparams)
34{
35}
36
38{
39 if (params()->element_type() == DataType::FLOAT32)
40 {
41 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
42 }
43 else if (params()->element_type() == DataType::S32)
44 {
45 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S32);
46 }
47 else
48 {
49 throw std::runtime_error("luci-intp Gather(1) Unsupported type.");
50 }
51
52 LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 ||
53 indices()->element_type() == DataType::S64);
54
55 // refer tensorflow/lite/kernels/gather.cc
56
57 const Shape &params_shape = params()->shape();
58 Shape indices_shape = indices()->shape();
59 {
60 // scalar index is treated as a tensor with the shape of [1]
61 if (indices_shape.num_dims() == 0)
62 {
63 indices_shape = Shape({1});
64 }
65 }
66
67 int axis = _params.axis;
68 if (axis < 0)
69 {
70 axis += params_shape.num_dims();
71 }
72 LUCI_INTERPRETER_CHECK(0 <= axis && axis < params_shape.num_dims());
73
74 int batch_dims = _params.batch_dims;
75 // batch_dims should be in range: [-rank(indices), rank(indices)].
76 // Negative batch_dims is added with rank of positions.
77 if (batch_dims < 0)
78 {
79 batch_dims += indices_shape.num_dims();
80 }
81 LUCI_INTERPRETER_CHECK(batch_dims <= axis);
82 LUCI_INTERPRETER_CHECK(0 <= batch_dims && batch_dims < params_shape.num_dims());
83 LUCI_INTERPRETER_CHECK(batch_dims <= indices_shape.num_dims());
84 for (int i = 0; i < batch_dims; ++i)
85 {
86 LUCI_INTERPRETER_CHECK(params_shape.dim(i) == indices_shape.dim(i));
87 }
88
89 const int num_dimensions = params_shape.num_dims() + indices_shape.num_dims() - 1 - batch_dims;
90
91 Shape output_shape(num_dimensions);
92 int output_index = 0;
93 for (int i = 0; i < axis; ++i)
94 {
95 output_shape.dim(output_index++) = params_shape.dim(i);
96 }
97 for (int i = batch_dims; i < indices_shape.num_dims(); ++i)
98 {
99 output_shape.dim(output_index++) = indices_shape.dim(i);
100 }
101 for (int i = axis + 1; i < params_shape.num_dims(); ++i)
102 {
103 output_shape.dim(output_index++) = params_shape.dim(i);
104 }
106}
107
108void Gather::execute() const
109{
110 switch (params()->element_type())
111 {
112 case DataType::FLOAT32:
113 eval<float>();
114 break;
115 case DataType::S32:
116 eval<int32_t>();
117 break;
118 default:
119 throw std::runtime_error("luci-intp Gather(2) Unsupported type.");
120 }
121}
122
123template <typename T> void Gather::eval() const
124{
125 assert(indices()->element_type() == DataType::S32 || indices()->element_type() == DataType::S64);
126
127 const auto params_data = getTensorData<T>(params());
128 auto output_data = getTensorData<T>(output());
129
130 tflite::GatherParams tparams;
131 tparams.axis = _params.axis;
132 tparams.batch_dims = _params.batch_dims;
133
134 if (indices()->element_type() == DataType::S32)
135 {
136 const auto indices_data = getTensorData<int32_t>(indices());
137
138 luci_interpreter_pal::Gather<T, int32_t>(tparams, getTensorShape(params()), params_data,
139 getTensorShape(indices()), indices_data,
140 getTensorShape(output()), output_data);
141 }
142 else
143 {
144 const auto indices_data = getTensorData<int64_t>(indices());
145
146 luci_interpreter_pal::Gather<T, int64_t>(tparams, getTensorShape(params()), params_data,
147 getTensorShape(indices()), indices_data,
148 getTensorShape(output()), output_data);
149 }
150}
151
152} // namespace kernels
153} // namespace luci_interpreter
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const Tensor * indices() const
Definition Gather.h:34
Gather(const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams)
Definition Gather.cpp:31
const Tensor * params() const
Definition Gather.h:33
void execute() const override
Definition Gather.cpp:108
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194