ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
PALComparisons.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef ONERT_MICRO_EXECUTE_PAL_COMPARISONS_H
19#define ONERT_MICRO_EXECUTE_PAL_COMPARISONS_H
20
21#include "OMStatus.h"
22#include "core/OMRuntimeShape.h"
23#include "PALUtils.h"
25
26namespace onert_micro
27{
28namespace execute
29{
30namespace pal
31{
32
33namespace
34{
35
36struct BroadcastComparison4DSlowCommon
37{
38 const core::OMRuntimeShape output_shape;
41};
42
43inline BroadcastComparison4DSlowCommon
44BroadcastComparison4DSlowPreprocess(const core::OMRuntimeShape &unextended_input1_shape,
45 const core::OMRuntimeShape &unextended_input2_shape,
46 const core::OMRuntimeShape &unextended_output_shape)
47{
50 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
51 &desc2);
52 return {core::OMRuntimeShape::extendedShape(4, unextended_output_shape), desc1, desc2};
53}
54
55} // namespace
56
// Scalar comparison predicates. Each one is intended to be passed (as a function pointer)
// as the `F` argument of the comparison kernels below, and simply returns the result of
// applying its operator to (lhs, rhs).

template <typename T> inline bool LessFn(T lhs, T rhs)
{
  const bool is_less = lhs < rhs;
  return is_less;
}

template <typename T> inline bool LessEqualFn(T lhs, T rhs)
{
  const bool is_less_or_equal = lhs <= rhs;
  return is_less_or_equal;
}

template <typename T> inline bool EqualFn(T lhs, T rhs)
{
  const bool is_equal = lhs == rhs;
  return is_equal;
}

template <typename T> inline bool GreaterFn(T lhs, T rhs)
{
  const bool is_greater = lhs > rhs;
  return is_greater;
}

template <typename T> inline bool GreaterEqualFn(T lhs, T rhs)
{
  const bool is_greater_or_equal = lhs >= rhs;
  return is_greater_or_equal;
}

template <typename T> inline bool NotEqualFn(T lhs, T rhs)
{
  const bool is_not_equal = lhs != rhs;
  return is_not_equal;
}
63
/**
 * @brief Element-wise comparison of two flat buffers without broadcasting or rescaling.
 *
 * For every index in [0, flat_size) writes F(input1_data[idx], input2_data[idx]) to
 * output_data[idx]. Nothing is written when flat_size <= 0.
 *
 * @param flat_size    number of elements to compare
 * @param input1_data  left-hand operands (at least flat_size elements)
 * @param input2_data  right-hand operands (at least flat_size elements)
 * @param output_data  destination for the boolean results (at least flat_size elements)
 * @param F            comparison predicate, e.g. LessFn<T>
 */
template <typename T>
inline void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data,
                                bool *output_data, bool F(T, T))
{
  int64_t idx = 0;
  while (idx < flat_size)
  {
    const T lhs = input1_data[idx];
    const T rhs = input2_data[idx];
    output_data[idx] = F(lhs, rhs);
    ++idx;
  }
}
73
74template <typename T, typename AccType>
76 const core::ComparisonParams &op_params, const core::OMRuntimeShape &unextended_input1_shape,
77 const T *input1_data, const core::OMRuntimeShape &unextended_input2_shape, const T *input2_data,
78 const core::OMRuntimeShape &unextended_output_shape, bool *output_data, bool F(AccType, AccType))
79{
80 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
81 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
82
83 int left_shift = op_params.left_shift;
84 int32_t input1_offset = op_params.input1_offset;
85 int32_t input1_multiplier = op_params.input1_multiplier;
86 int input1_shift = op_params.input1_shift;
87 int32_t input2_offset = op_params.input2_offset;
88 int32_t input2_multiplier = op_params.input2_multiplier;
89 int input2_shift = op_params.input2_shift;
90
91 for (int b = 0; b < dims.output_shape.dims(0); ++b)
92 {
93 for (int y = 0; y < dims.output_shape.dims(1); ++y)
94 {
95 for (int x = 0; x < dims.output_shape.dims(2); ++x)
96 {
97 for (int c = 0; c < dims.output_shape.dims(3); ++c)
98 {
99 const int32_t input1_val =
100 input1_offset + input1_data[subscriptToIndex(dims.desc1, b, y, x, c)];
101 const int32_t input2_val =
102 input2_offset + input2_data[subscriptToIndex(dims.desc2, b, y, x, c)];
103 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
104 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
105 const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
106 shifted_input1_val, input1_multiplier, input1_shift);
107 const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
108 shifted_input2_val, input2_multiplier, input2_shift);
109
110 const int output_data_offset =
111 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
112 dims.output_shape.dims(3) +
113 c;
114 output_data[output_data_offset] = F(scaled_input1_val, scaled_input2_val);
115 }
116 }
117 }
118 }
119}
120
121template <typename T, typename AccType>
122inline void ComparisonWithScaling(const core::ComparisonParams &op_params, const int64_t flat_size,
123 const T *input1_data, const T *input2_data, bool *output_data,
124 bool F(AccType, AccType))
125{
126 int left_shift = op_params.left_shift;
127 int32_t input1_offset = op_params.input1_offset;
128 int32_t input1_multiplier = op_params.input1_multiplier;
129 int input1_shift = op_params.input1_shift;
130 int32_t input2_offset = op_params.input2_offset;
131 int32_t input2_multiplier = op_params.input2_multiplier;
132 int input2_shift = op_params.input2_shift;
133
134 for (int64_t i = 0; i < flat_size; ++i)
135 {
136 const int32_t input1_val = input1_offset + input1_data[i];
137 const int32_t input2_val = input2_offset + input2_data[i];
138 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
139 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
140 const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
141 shifted_input1_val, input1_multiplier, input1_shift);
142 const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
143 shifted_input2_val, input2_multiplier, input2_shift);
144 output_data[i] = F(scaled_input1_val, scaled_input2_val);
145 }
146}
147
148template <typename T>
150 const core::ComparisonParams &op_params, const core::OMRuntimeShape &unextended_input1_shape,
151 const T *input1_data, const core::OMRuntimeShape &unextended_input2_shape, const T *input2_data,
152 const core::OMRuntimeShape &unextended_output_shape, bool *output_data, bool F(T, T))
153{
154 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
155 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
156
157 for (int b = 0; b < dims.output_shape.dims(0); ++b)
158 {
159 for (int y = 0; y < dims.output_shape.dims(1); ++y)
160 {
161 for (int x = 0; x < dims.output_shape.dims(2); ++x)
162 {
163 for (int c = 0; c < dims.output_shape.dims(3); ++c)
164 {
165 const int output_data_offset =
166 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
167 dims.output_shape.dims(3) +
168 c;
169 output_data[output_data_offset] =
170 F(input1_data[subscriptToIndex(dims.desc1, b, y, x, c)],
171 input2_data[subscriptToIndex(dims.desc2, b, y, x, c)]);
172 }
173 }
174 }
175 }
176}
177
178} // namespace pal
179} // namespace execute
180} // namespace onert_micro
181
182#endif // ONERT_MICRO_EXECUTE_PAL_BINARYOP_COMMON_H
static OMRuntimeShape extendedShape(int new_shape_size, const OMRuntimeShape &shape)
NdArrayDesc< 4 > desc1
NdArrayDesc< 4 > desc2
void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(T, T))
bool NotEqualFn(T lhs, T rhs)
bool LessEqualFn(T lhs, T rhs)
bool GreaterFn(T lhs, T rhs)
bool EqualFn(T lhs, T rhs)
int32_t multiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
Definition PALUtils.h:112
void NdArrayDescsForElementwiseBroadcast(const core::OMRuntimeShape &input0_shape, const core::OMRuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void ComparisonWithScaling(const core::ComparisonParams &op_params, const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(AccType, AccType))
void BroadcastComparison4DSlowWithScaling(const core::ComparisonParams &op_params, const core::OMRuntimeShape &unextended_input1_shape, const T *input1_data, const core::OMRuntimeShape &unextended_input2_shape, const T *input2_data, const core::OMRuntimeShape &unextended_output_shape, bool *output_data, bool F(AccType, AccType))
bool GreaterEqualFn(T lhs, T rhs)
void BroadcastComparison4DSlowNoScaling(const core::ComparisonParams &op_params, const core::OMRuntimeShape &unextended_input1_shape, const T *input1_data, const core::OMRuntimeShape &unextended_input2_shape, const T *input2_data, const core::OMRuntimeShape &unextended_output_shape, bool *output_data, bool F(T, T))
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
bool LessFn(T lhs, T rhs)