ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Gather.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2021 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/Gather.h"
19#include "kernels/Utils.h"
20#include "PALGather.h"
21
22#include <stdexcept>
23#include <cassert>
24
25namespace luci_interpreter
26{
27
28namespace kernels
29{
30
31Gather::Gather(const Tensor *params, const Tensor *indices, Tensor *output,
32 const GatherParams &gparams)
33 : KernelWithParams<GatherParams>({params, indices}, {output}, gparams)
34{
35}
36
38{
39 if (params()->element_type() == DataType::FLOAT32)
40 {
41 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
42 }
43 else
44 {
45 throw std::runtime_error("luci-intp Gather(1) Unsupported type.");
46 }
47
48 LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 ||
49 indices()->element_type() == DataType::S64);
50
51 // refer tensorflow/lite/kernels/gather.cc
52
53 const Shape &params_shape = params()->shape();
54 const Shape &indices_shape = indices()->shape();
55
56 int axis = _params.axis;
57 if (axis < 0)
58 {
59 axis += params_shape.num_dims();
60 }
61 LUCI_INTERPRETER_CHECK(0 <= axis && axis < params_shape.num_dims());
62
63 int batch_dims = _params.batch_dims;
64 // batch_dims should be in range: [-rank(indices), rank(indices)].
65 // Negative batch_dims is added with rank of positions.
66 if (batch_dims < 0)
67 {
68 batch_dims += indices_shape.num_dims();
69 }
70 LUCI_INTERPRETER_CHECK(batch_dims <= axis);
71 LUCI_INTERPRETER_CHECK(0 <= batch_dims && batch_dims < params_shape.num_dims());
72 LUCI_INTERPRETER_CHECK(batch_dims <= indices_shape.num_dims());
73 for (int i = 0; i < batch_dims; ++i)
74 {
75 LUCI_INTERPRETER_CHECK(params_shape.dim(i) == indices_shape.dim(i));
76 }
77
78 const int num_dimensions = params_shape.num_dims() + indices_shape.num_dims() - 1 - batch_dims;
79
80 Shape output_shape(num_dimensions);
81 int output_index = 0;
82 for (int i = 0; i < axis; ++i)
83 {
84 output_shape.dim(output_index++) = params_shape.dim(i);
85 }
86 for (int i = batch_dims; i < indices_shape.num_dims(); ++i)
87 {
88 output_shape.dim(output_index++) = indices_shape.dim(i);
89 }
90 for (int i = axis + 1; i < params_shape.num_dims(); ++i)
91 {
92 output_shape.dim(output_index++) = params_shape.dim(i);
93 }
95}
96
97void Gather::execute() const
98{
99 switch (params()->element_type())
100 {
101 case DataType::FLOAT32:
102 evalFloat();
103 break;
104 default:
105 throw std::runtime_error("luci-intp Gather(2) Unsupported type.");
106 }
107}
108
109void Gather::evalFloat() const
110{
111 assert(indices()->element_type() == DataType::S32 || indices()->element_type() == DataType::S64);
112
113 const auto params_data = getTensorData<float>(params());
114 auto output_data = getTensorData<float>(output());
115
116 tflite::GatherParams tparams;
117 tparams.axis = _params.axis;
118 tparams.batch_dims = _params.batch_dims;
119
120 if (indices()->element_type() == DataType::S32)
121 {
122 const auto indices_data = getTensorData<int32_t>(indices());
123
124 luci_interpreter_pal::Gather<float, int32_t>(tparams, getTensorShape(params()), params_data,
125 getTensorShape(indices()), indices_data,
126 getTensorShape(output()), output_data);
127 }
128 else
129 {
130 const auto indices_data = getTensorData<int64_t>(indices());
131
132 luci_interpreter_pal::Gather<float, int64_t>(tparams, getTensorShape(params()), params_data,
133 getTensorShape(indices()), indices_data,
134 getTensorShape(output()), output_data);
135 }
136}
137
138} // namespace kernels
139} // namespace luci_interpreter
int32_t dim(int i) const
Definition Tensor.h:41
int num_dims() const
Definition Tensor.h:39
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
const Tensor * indices() const
Definition Gather.h:34
Gather(const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams)
Definition Gather.cpp:31
const Tensor * params() const
Definition Gather.h:33
void execute() const override
Definition Gather.cpp:97
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194