ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALComparisons.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_COMPARISONS_H
19#define LUCI_INTERPRETER_PAL_COMPARISONS_H
20
21#include "Params.h"
23#include "PALUtils.h"
24
26{
27namespace
28{
29
30struct BroadcastComparison4DSlowCommon
31{
35};
36
37inline BroadcastComparison4DSlowCommon
38BroadcastComparison4DSlowPreprocess(const luci_interpreter::RuntimeShape &unextended_input1_shape,
39 const luci_interpreter::RuntimeShape &unextended_input2_shape,
40 const luci_interpreter::RuntimeShape &unextended_output_shape)
41{
44 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
45 &desc2);
46 return {luci_interpreter::RuntimeShape::extendedShape(4, unextended_output_shape), desc1, desc2};
47}
48
49} // namespace
50
// Element-wise predicates used as the `F` argument of the comparison kernels.
// Each one applies exactly the relational operator in its name to (a, b).

// a < b
template <typename T> inline bool LessFn(const T a, const T b)
{
  return a < b;
}

// a <= b
template <typename T> inline bool LessEqualFn(const T a, const T b)
{
  return a <= b;
}

// a == b
template <typename T> inline bool EqualFn(const T a, const T b)
{
  return a == b;
}

// a > b
template <typename T> inline bool GreaterFn(const T a, const T b)
{
  return a > b;
}

// a >= b
template <typename T> inline bool GreaterEqualFn(const T a, const T b)
{
  return a >= b;
}

// a != b
template <typename T> inline bool NotEqualFn(const T a, const T b)
{
  return a != b;
}
57
/// Applies the predicate F element-wise to two flat, equally sized inputs,
/// without any requantization.
///
/// @param flat_size    number of elements; a value <= 0 writes nothing.
/// @param input1_data  left-hand operands, read for flat_size elements.
/// @param input2_data  right-hand operands, read for flat_size elements.
/// @param output_data  receives F(input1_data[k], input2_data[k]) at index k.
/// @param F            comparison predicate.
template <typename T>
inline void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data,
                                bool *output_data, bool F(T, T))
{
  // Advance all three buffers in lock-step; the countdown never starts when
  // flat_size is zero or negative.
  for (int64_t remaining = flat_size; remaining > 0; --remaining)
  {
    *output_data++ = F(*input1_data++, *input2_data++);
  }
}
67
68template <typename T>
70 const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape,
71 const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape,
72 const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape,
73 bool *output_data, bool F(T, T))
74{
75 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
76 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
77
78 int left_shift = op_params.left_shift;
79 int32_t input1_offset = op_params.input1_offset;
80 int32_t input1_multiplier = op_params.input1_multiplier;
81 int input1_shift = op_params.input1_shift;
82 int32_t input2_offset = op_params.input2_offset;
83 int32_t input2_multiplier = op_params.input2_multiplier;
84 int input2_shift = op_params.input2_shift;
85
86 for (int b = 0; b < dims.output_shape.dims(0); ++b)
87 {
88 for (int y = 0; y < dims.output_shape.dims(1); ++y)
89 {
90 for (int x = 0; x < dims.output_shape.dims(2); ++x)
91 {
92 for (int c = 0; c < dims.output_shape.dims(3); ++c)
93 {
94 const int32_t input1_val =
95 input1_offset + input1_data[subscriptToIndex(dims.desc1, b, y, x, c)];
96 const int32_t input2_val =
97 input2_offset + input2_data[subscriptToIndex(dims.desc2, b, y, x, c)];
98 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
99 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
100 const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
101 shifted_input1_val, input1_multiplier, input1_shift);
102 const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
103 shifted_input2_val, input2_multiplier, input2_shift);
104
105 const int output_data_offset =
106 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
107 dims.output_shape.dims(3) +
108 c;
109 output_data[output_data_offset] = F(scaled_input1_val, scaled_input2_val);
110 }
111 }
112 }
113 }
114}
115
116template <typename T>
117inline void ComparisonWithScaling(const ComparisonParams &op_params, const int64_t flat_size,
118 const T *input1_data, const T *input2_data, bool *output_data,
119 bool F(T, T))
120{
121 int left_shift = op_params.left_shift;
122 int32_t input1_offset = op_params.input1_offset;
123 int32_t input1_multiplier = op_params.input1_multiplier;
124 int input1_shift = op_params.input1_shift;
125 int32_t input2_offset = op_params.input2_offset;
126 int32_t input2_multiplier = op_params.input2_multiplier;
127 int input2_shift = op_params.input2_shift;
128
129 for (int64_t i = 0; i < flat_size; ++i)
130 {
131 const int32_t input1_val = input1_offset + input1_data[i];
132 const int32_t input2_val = input2_offset + input2_data[i];
133 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
134 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
135 const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
136 shifted_input1_val, input1_multiplier, input1_shift);
137 const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
138 shifted_input2_val, input2_multiplier, input2_shift);
139 output_data[i] = F(scaled_input1_val, scaled_input2_val);
140 }
141}
142
143template <typename T>
145 const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape,
146 const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape,
147 const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape,
148 bool *output_data, bool F(T, T))
149{
150 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
151 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
152
153 for (int b = 0; b < dims.output_shape.dims(0); ++b)
154 {
155 for (int y = 0; y < dims.output_shape.dims(1); ++y)
156 {
157 for (int x = 0; x < dims.output_shape.dims(2); ++x)
158 {
159 for (int c = 0; c < dims.output_shape.dims(3); ++c)
160 {
161 const int output_data_offset =
162 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
163 dims.output_shape.dims(3) +
164 c;
165 output_data[output_data_offset] =
166 F(input1_data[subscriptToIndex(dims.desc1, b, y, x, c)],
167 input2_data[subscriptToIndex(dims.desc2, b, y, x, c)]);
168 }
169 }
170 }
171 }
172}
173
174} // namespace luci_interpreter_pal
175
176#endif // LUCI_INTERPRETER_PAL_COMPARISONS_H
static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
Definition Tensor.h:95
NdArrayDesc< 4 > desc1
NdArrayDesc< 4 > desc2
bool LessFn(T lhs, T rhs)
void BroadcastComparison4DSlowWithScaling(const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape, const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape, const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape, bool *output_data, bool F(T, T))
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
bool EqualFn(T lhs, T rhs)
void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(T, T))
int32_t multiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
Definition PALUtils.h:85
bool LessEqualFn(T lhs, T rhs)
void ComparisonWithScaling(const ComparisonParams &op_params, const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(T, T))
void NdArrayDescsForElementwiseBroadcast(const luci_interpreter::RuntimeShape &input0_shape, const luci_interpreter::RuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void BroadcastComparison4DSlowNoScaling(const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape, const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape, const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape, bool *output_data, bool F(T, T))
bool NotEqualFn(T lhs, T rhs)
bool GreaterEqualFn(T lhs, T rhs)
bool GreaterFn(T lhs, T rhs)