ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Mean.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Builders.h"
18#include "kernels/Utils.h"
19#include "TISOKernel.h"
20
21#include "PALMean.h"
22
23#include <cassert>
24
25namespace luci_interpreter
26{
27namespace
28{
29const int kMaxNumberOfAxis = 5;
30const int kMaxNumberOfReducedAxis = 2;
31
32void ResolveAxis(const int *axis_data, int axis_count, luci_interpreter_pal::MeanParams *op_params)
33{
34 int i = 0;
35 for (; i < axis_count; ++i)
36 {
37 op_params->axis[i] = static_cast<int16_t>(axis_data[i]);
38 }
39 for (; i < 4; ++i)
40 {
41 op_params->axis[i] = 1;
42 }
43 op_params->axis_count = axis_count;
44}
45
46} // namespace
47
48void configure_kernel_CircleMean(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
49{
50 kernels::TISOKernel kernel(cur_op, runtime_graph);
51
52 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
53 Tensor::element_type(kernel.output()));
54 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32);
55
56 const int32_t axis_value =
57 kernels::getTensorData<int>(runtime_graph->getConstDataByTensor(kernel.input2()))[0];
58 LUCI_INTERPRETER_CHECK(axis_value >= 0);
59}
60
// Execute a CircleMean node: reduce input1 by averaging over the axes held in
// input2, writing the result into the output tensor.
// Only FLOAT32 is handled (and only when DIS_FLOAT is not defined); any other
// element type reaches the default branch and trips assert(false).
// NOTE(review): this listing skips source lines 81, 94, 96 and 101. They
// appear to hold the op_params declaration and the heads of the two PAL Mean
// calls; confirm against the real Mean.cpp before editing.
61void execute_kernel_CircleMean(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
62{
63 kernels::TISOKernel kernel(cur_op, runtime_graph);
64 kernels::TISOData tiso_data = kernel.readData();
65
66 const auto *input = kernel.input1();
67 const auto *axis = kernel.input2();
68 const auto *output = kernel.output();
69
  // NOTE(review): `options` is dereferenced below without a null check;
  // presumably it is null when the operator carries no ReducerOptions.
70 const auto *options = cur_op->builtin_options_as_ReducerOptions();
71
72 int num_axis = static_cast<int>(Tensor::num_elements(axis));
  // Fixed-size stack scratch buffers for the generic PAL Mean. Inputs whose
  // rank or number of distinct axes exceed these sizes would presumably
  // overflow them; TODO confirm that configure rejects such inputs.
73 int temp_index[kMaxNumberOfAxis];
74 int resolved_axis[kMaxNumberOfReducedAxis];
75
76 switch (Tensor::element_type(kernel.input1()))
77 {
78#ifndef DIS_FLOAT
79 case DataType::FLOAT32:
80 {
  // Copy the axes into op_params (used only by the specialized 4D path).
82 ResolveAxis(kernels::getTensorData<int>(tiso_data.input2_data), num_axis, &op_params);
83
84 // Special case mean implementation exists for 4D mean across axes 1
85 // and 2.
86 bool special_case_4d_axes_1_and_2 = Tensor::num_dims(input) == 4 &&
87 op_params.axis_count == 2 &&
88 ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
89 (op_params.axis[0] == 2 && op_params.axis[1] == 1));
90
91 // Defer to specialized implementation for 4D Mean across axes 1 & 2.
  // The fast path is taken only when keep_dims is also set.
92 if (options->keep_dims() && special_case_4d_axes_1_and_2)
93 {
95 kernels::getTensorData<float>(tiso_data.input1_data),
97 kernels::getTensorData<float>(tiso_data.output_data));
98 }
99 else
100 {
  // Generic path. Shapes are passed as raw int arrays reinterpreted from the
  // flatbuffer shape vectors.
  // NOTE(review): the output buffer is passed twice, as output_data and as
  // the final temp_sum argument, so accumulation happens in place in output.
  // The bool result of this Mean overload is discarded, so a failure (if
  // any is reported) goes unnoticed.
102 kernels::getTensorData<float>(tiso_data.input1_data),
103 reinterpret_cast<const int *>(wrap(input->shape()).data()), Tensor::num_dims(input),
104 kernels::getTensorData<float>(tiso_data.output_data),
105 reinterpret_cast<const int *>(wrap(output->shape()).data()), Tensor::num_dims(output),
106 kernels::getTensorData<int>(tiso_data.input2_data), num_axis, options->keep_dims(),
107 temp_index, resolved_axis, kernels::getTensorData<float>(tiso_data.output_data));
108 }
109 }
110 break;
111#endif // DIS_FLOAT
112 default:
113 assert(false && "Unsupported type");
114 }
115}
116
117} // namespace luci_interpreter
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * output() const
Definition TISOKernel.h:62
const circle::Tensor * input2() const
Definition TISOKernel.h:61
const circle::Tensor * input1() const
Definition TISOKernel.h:60
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
bool Mean(const T *input_data, const int *input_dims, const int input_num_dims, T *output_data, const int *output_dims, const int output_num_dims, const int *axis, const int num_axis_dimensions, bool, int *temp_index, int *resolved_axis, U *temp_sum)
Definition PALMean.h:108
void execute_kernel_CircleMean(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Mean.cpp:61
void configure_kernel_CircleMean(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Mean.cpp:48
VectorWrapper< T > wrap(const flatbuffers::Vector< T > *vec)
bool ResolveAxis(const int num_dims, const std::vector< int > &axes, int *out_axis, int *out_num_axis)
Definition Reduce.h:169