ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
PALReduceCommon.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_TANH_H
19#define LUCI_INTERPRETER_PAL_TANH_H
20
21#include "PALUtils.h"
22
24{
25namespace
26{
// Validates the reduction axes in 'axis': negative indices are mapped into
// [0, num_dims) by adding num_dims, and duplicate axes are dropped.
//
// num_dims:     rank of the tensor being reduced. A rank of 0 (scalar)
//               short-circuits: the axes are ignored and true is returned
//               with *out_num_axis == 0.
// axis:         num_axis axis indices; each may be negative (counted from
//               the last dimension).
// num_axis:     number of entries in 'axis'.
// out_num_axis: receives the number of unique resolved axes.
// out_axis:     optional. When non-null it receives the unique resolved axes
//               in first-occurrence order. It must have room for at least
//               min(num_axis, num_dims) entries.
//
// Returns false as soon as an axis lies outside [-num_dims, num_dims). In that
// case *out_num_axis (and out_axis) reflect only the axes resolved before the
// offending entry.
inline bool resolveAxis(const int num_dims, const int *axis, const int64_t num_axis,
                        int *out_num_axis, int *out_axis = nullptr)
{
  *out_num_axis = 0; // Just in case.
  // Short-circuit axis resolution for scalars; the axis will go unused.
  if (num_dims == 0)
  {
    return true;
  }

  // A positive index 'p_idx' can be written as the negative index
  // p_idx - num_dims, e.g. for num_dims=3, [0, 1, 2] == [-3, -2, -1].
  const auto normalize = [num_dims](const int a) { return a < 0 ? a + num_dims : a; };

  // O(n^2) is fine since num_axis should be really small, mostly <= 4.
  for (int64_t idx = 0; idx < num_axis; ++idx)
  {
    const int current = normalize(axis[idx]);
    if (current < 0 || current >= num_dims)
    {
      return false;
    }
    // Detect duplicates by comparing against the earlier *input* entries (all
    // of which are already known to be valid). This needs no scratch buffer,
    // unlike the former fixed-size local array that overflowed once more
    // than two unique axes were given.
    bool is_dup = false;
    for (int64_t j = 0; j < idx && !is_dup; ++j)
    {
      is_dup = (normalize(axis[j]) == current);
    }
    if (!is_dup)
    {
      if (out_axis != nullptr)
      {
        out_axis[*out_num_axis] = current;
      }
      *out_num_axis += 1;
    }
  }
  return true;
}
67
68} // namespace
69
70// Computes the generic value (i.e., sum/max/min/prod) of elements across
71// dimensions given in axis. It needs to pass in init_value and reducer.
72template <typename T>
73inline void ReduceGeneric(const T *input_data, const int *input_dims, const int input_num_dims,
74 T *output_data, const int *axis, const int64_t num_axis_dimensions,
75 T init_value, const int output_flat_size, T reducer(const T, const T))
76{
77 // Return early when input shape has zero dim.
78 for (int i = 0; i < input_num_dims; ++i)
79 {
80 if (input_dims[i] == 0)
81 return;
82 }
83
84 for (size_t idx = 0; idx < output_flat_size; ++idx)
85 {
86 output_data[idx] = init_value;
87 }
88
89 // Resolve axis.
90 int num_resolved_axis = 0;
91 if (!resolveAxis(input_num_dims, axis, num_axis_dimensions, &num_resolved_axis))
92 {
93 return;
94 }
95
96 int temp_index[5];
97 // Reset input iterator.
98 for (int idx = 0; idx < input_num_dims; ++idx)
99 {
100 temp_index[idx] = 0;
101 }
102 // Iterate through input_data.
103 do
104 {
105 size_t input_offset = reducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr);
106 size_t output_offset =
107 reducedOutputOffset(input_num_dims, input_dims, temp_index, num_resolved_axis, axis);
108 output_data[output_offset] = reducer(output_data[output_offset], input_data[input_offset]);
109 } while (nextIndex(input_num_dims, input_dims, temp_index));
110}
111
112} // namespace luci_interpreter_pal
113
114#endif // LUCI_INTERPRETER_PAL_TANH_H
void ReduceGeneric(const T *input_data, const int *input_dims, const int input_num_dims, T *output_data, const int *axis, const int64_t num_axis_dimensions, T init_value, const int output_flat_size, T reducer(const T, const T))
bool nextIndex(const int num_dims, const int *dims, int *current)
Definition PALUtils.h:148
size_t reducedOutputOffset(const int num_dims, const int *dims, const int *index, const int num_axis, const int *axis)
Definition PALUtils.h:116
bool resolveAxis(const int num_dims, const int *axis, const int64_t num_axis, int *out_axis, int *out_num_axis)