ONE - On-device Neural Engine
Sum.cpp
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/Sum.h"

#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/reduce.h>

#include <stdexcept>

namespace luci_interpreter
{
namespace kernels
{

// Returns the number of axes that will be reduced. Removes duplicates.
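// For example, with a rank-4 input, axes {1, -3, 3} resolve to {1, 1, 3}; the duplicate
// leaves two distinct axes, so the reduction count is 2.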
static int getAxisReductionCount(const int32_t *axes_data, int num_axes, int input_num_dims)
{
  int reduction_count = num_axes;
  for (int i = 0; i < num_axes; ++i)
  {
    int current = axes_data[i] >= 0 ? axes_data[i] : axes_data[i] + input_num_dims;
    assert(current >= 0 && current < input_num_dims);
    for (int j = 0; j < i; j++)
    {
      int previous = axes_data[j] >= 0 ? axes_data[j] : axes_data[j] + input_num_dims;
      // This checks for duplicate axis
      if (current == previous)
      {
        --reduction_count;
        break;
      }
    }
  }
  return reduction_count;
}

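// Computes the shape of the reduced output. With keep_dims, each reduced dimension is
// kept with size 1 (e.g. [2, 3, 4] reduced over axis 1 becomes [2, 1, 4]); otherwise
// reduced dimensions are dropped ([2, 3, 4] becomes [2, 4]).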
static Shape getOutputShape(const Shape &input_shape, const int32_t *axes_data, int num_axes,
                            bool keep_dims)
{
  int input_num_dims = input_shape.num_dims();
  if (input_num_dims == 0)
  {
    return Shape(0);
  }

  if (keep_dims)
  {
    Shape output_shape(input_num_dims);
    for (int idx = 0; idx < input_num_dims; ++idx)
    {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
      {
        if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
        {
          is_axis = true;
          break;
        }
      }
      if (is_axis)
      {
        output_shape.dim(idx) = 1;
      }
      else
      {
        output_shape.dim(idx) = input_shape.dim(idx);
      }
    }
    return output_shape;
  }
  else
  {
    int num_reduce_axes = getAxisReductionCount(axes_data, num_axes, input_num_dims);
    Shape output_shape(input_num_dims - num_reduce_axes);
    int num_skip_axes = 0;
    for (int idx = 0; idx < input_num_dims; ++idx)
    {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx)
      {
        if (axes_data[axis_idx] == idx || axes_data[axis_idx] + input_num_dims == idx)
        {
          ++num_skip_axes;
          is_axis = true;
          break;
        }
      }
      if (!is_axis)
      {
        output_shape.dim(idx - num_skip_axes) = input_shape.dim(idx);
      }
    }
    return output_shape;
  }
}

Sum::Sum(const Tensor *input, const Tensor *axes, Tensor *output, Tensor *temp_index,
         Tensor *resolved_axes, const ReducerParams &params)
  : KernelWithParams<ReducerParams>({input, axes}, {output, temp_index, resolved_axes}, params)
{
}

void Sum::configure()
{
  LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
  LUCI_INTERPRETER_CHECK(axes()->element_type() == DataType::S32);

  const Shape &input_shape = input()->shape();
  int input_num_dims = input_shape.num_dims();

  const auto *axes_data = getTensorData<int32_t>(axes());
  int num_axes = axes()->shape().num_elements();
  LUCI_INTERPRETER_CHECK(num_axes <= 4);

  // We compute shapes of outputs in configure, assuming that outputs have
  // static shape
  // TODO Support dynamic shape
  Shape output_shape = getOutputShape(input_shape, axes_data, num_axes, _params.keep_dims);
  output()->resize(output_shape);

  auto temp_index = getOutputTensors()[1];
  auto resolved_axes = getOutputTensors()[2];

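  // temp_index and resolved_axes are scratch tensors for the TFLite reference reducer:
  // one index slot per input dimension and one slot per requested axis.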
  temp_index->resize(Shape(input_num_dims));
  resolved_axes->resize(Shape(num_axes));
}

void Sum::execute() const
{
  switch (input()->element_type())
  {
    case DataType::FLOAT32:
      evalFloat();
      break;
    default:
      throw std::runtime_error("luci-intp Sum Unsupported type.");
  }
}

void Sum::evalFloat() const
{
  const auto *axes_data = getTensorData<int32_t>(axes());
  int num_axes = axes()->shape().num_elements();

  auto temp_index = getOutputTensors()[1];
  auto resolved_axes = getOutputTensors()[2];

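  // ResolveAxis normalizes negative axis values and removes duplicates before reducing.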
  int num_resolved_axis = 0;
  LUCI_INTERPRETER_CHECK(
    tflite::reference_ops::ResolveAxis(input()->shape().num_dims(), axes_data, num_axes,
                                       getTensorData<int>(resolved_axes), &num_resolved_axis));

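  // The reduction starts from an init value of 0 and accumulates with the addition
  // lambda below, which is what makes this generic reduction a sum.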
  float init_value = 0.0;
  tflite::reference_ops::ReduceGeneric<float>(
    getTensorData<float>(input()), getTensorShape(input()).DimsData(), input()->shape().num_dims(),
    getTensorData<float>(output()), getTensorShape(output()).DimsData(),
    output()->shape().num_dims(), axes_data, num_axes, _params.keep_dims,
    getTensorData<int>(temp_index), getTensorData<int>(resolved_axes), init_value,
    [](const float current, const float in) -> float { return current + in; });
}

} // namespace kernels
} // namespace luci_interpreter
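
// A minimal usage sketch (tensor setup assumed, not part of this file): construct the
// kernel with the input, axes, output and the two scratch tensors, then configure and run.
//
//   luci_interpreter::kernels::Sum kernel(&input, &axes, &output, &temp_index,
//                                         &resolved_axes, params);
//   kernel.configure();
//   kernel.execute();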