18#ifndef __NNFW_CKER_COMPARISON_H__
19#define __NNFW_CKER_COMPARISON_H__
// Equality predicate for two values of the same type.
//
// Returns the result of `lhs == rhs` using T's own operator==. Both operands
// are taken by value. This is a plain function template (not a functor) so a
// specialization such as EqualFn<T> can be supplied as a `bool (*)(T, T)`
// template argument, e.g. ComparisonImpl<T, EqualFn>.
template <typename T> inline bool EqualFn(T lhs, T rhs)
{
  return lhs == rhs;
}
// Inequality predicate for two values of the same type.
//
// Returns the result of `lhs != rhs` using T's own operator!=. Both operands
// are taken by value. Written as a function template so that a specialization
// can be used as a `bool (*)(T, T)` template argument, e.g.
// ComparisonImpl<T, NotEqualFn>.
template <typename T> inline bool NotEqualFn(T lhs, T rhs)
{
  return lhs != rhs;
}
// Strict "greater than" predicate for two values of the same type.
//
// Returns the result of `lhs > rhs` using T's own operator>. Both operands
// are taken by value. Written as a function template so that a specialization
// can be used as a `bool (*)(T, T)` template argument, e.g.
// ComparisonImpl<T, GreaterFn>.
template <typename T> inline bool GreaterFn(T lhs, T rhs)
{
  return lhs > rhs;
}
// "Greater than or equal" predicate for two values of the same type.
//
// Returns the result of `lhs >= rhs` using T's own operator>=. Both operands
// are taken by value. Written as a function template so that a specialization
// can be used as a `bool (*)(T, T)` template argument, e.g.
// ComparisonImpl<T, GreaterEqualFn>.
template <typename T> inline bool GreaterEqualFn(T lhs, T rhs)
{
  return lhs >= rhs;
}
// Strict "less than" predicate for two values of the same type.
//
// Returns the result of `lhs < rhs` using T's own operator<. Both operands
// are taken by value. Written as a function template so that a specialization
// can be used as a `bool (*)(T, T)` template argument, e.g.
// ComparisonImpl<T, LessFn>.
template <typename T> inline bool LessFn(T lhs, T rhs)
{
  return lhs < rhs;
}
// "Less than or equal" predicate for two values of the same type.
//
// Returns the result of `lhs <= rhs` using T's own operator<=. Both operands
// are taken by value. Written as a function template so that a specialization
// can be used as a `bool (*)(T, T)` template argument, e.g.
// ComparisonImpl<T, LessEqualFn>.
template <typename T> inline bool LessEqualFn(T lhs, T rhs)
{
  return lhs <= rhs;
}
39template <
typename T, ComparisonFn<T> F>
41 const Shape &input2_shape,
const T *input2_data,
44 const int64_t flatsize =
46 for (int64_t i = 0; i < flatsize; ++i)
48 output_data[i] = F(input1_data[i], input2_data[i]);
52template <ComparisonFn<
float> F>
54 const Shape &input2_shape,
const float *input2_data,
57 ComparisonImpl<float, F>(input1_shape, input1_data, input2_shape, input2_data,
output_shape,
61template <
typename T, ComparisonFn<
int32_t> F>
63 const T *input1_data,
const Shape &input2_shape,
75 for (int64_t i = 0; i < flatsize; ++i)
77 const int32_t input1_val = input1_offset + input1_data[i];
78 const int32_t input2_val = input2_offset + input2_data[i];
79 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
80 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
82 shifted_input1_val, input1_multiplier, input1_shift);
84 shifted_input2_val, input2_multiplier, input2_shift);
85 output_data[i] = F(scaled_input1_val, scaled_input2_val);
89template <
typename T, ComparisonFn<T> F>
92 const Shape &unextended_input2_shape,
const T *input2_data,
93 const Shape &unextended_output_shape,
bool *output_data)
122template <
typename T, ComparisonFn<T> F>
124 const Shape &input2_shape,
const T *input2_data,
127 BroadcastComparison4DSlowImpl<T, F>(input1_shape, input1_data, input2_shape, input2_data,
131template <
typename T, ComparisonFn<
int32_t> F>
133 const Shape &input1_shape,
const T *input1_data,
134 const Shape &input2_shape,
const T *input2_data,
161 const int32_t input1_val =
163 const int32_t input2_val =
165 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
166 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
168 shifted_input1_val, input1_multiplier, input1_shift);
170 shifted_input2_val, input2_multiplier, input2_shift);
171 output_data[
Offset(
output_shape, b, y, x, c)] = F(scaled_input1_val, scaled_input2_val);
178#define TFLITE_COMPARISON_OP(name) \
179 template <typename T> \
180 inline void name(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, \
181 const T *input2_data, const Shape &output_shape, bool *output_data) \
183 Comparison<name##Fn>(input1_shape, input1_data, input2_shape, input2_data, output_shape, \
186 template <typename T> \
187 inline void name##NoScaling(const Shape &input1_shape, const T *input1_data, \
188 const Shape &input2_shape, const T *input2_data, \
189 const Shape &output_shape, bool *output_data) \
191 ComparisonImpl<T, name##Fn>(input1_shape, input1_data, input2_shape, input2_data, \
192 output_shape, output_data); \
194 template <typename T> \
195 inline void name##WithScaling( \
196 ComparisonParams ¶ms, const Shape &input1_shape, const T *input1_data, \
197 const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data) \
199 ComparisonWithScaling<T, name##Fn>(params, input1_shape, input1_data, input2_shape, \
200 input2_data, output_shape, output_data); \
202 template <typename T> \
203 inline void Broadcast4DSlow##name##NoScaling(const Shape &input1_shape, const T *input1_data, \
204 const Shape &input2_shape, const T *input2_data, \
205 const Shape &output_shape, bool *output_data) \
207 BroadcastComparison4DSlowImpl<T, name##Fn>(input1_shape, input1_data, input2_shape, \
208 input2_data, output_shape, output_data); \
210 template <typename T> \
211 inline void Broadcast4DSlow##name(const Shape &input1_shape, const T *input1_data, \
212 const Shape &input2_shape, const T *input2_data, \
213 const Shape &output_shape, bool *output_data) \
215 BroadcastComparison4DSlow<T, name##Fn>(input1_shape, input1_data, input2_shape, input2_data, \
216 output_shape, output_data); \
218 template <typename T> \
219 inline void Broadcast4DSlow##name##WithScaling( \
220 ComparisonParams ¶ms, const Shape &input1_shape, const T *input1_data, \
221 const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data) \
223 BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
224 params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data); \
233#undef TFLITE_COMPARISON_OP
int32_t DimensionsCount() const
#define TFLITE_COMPARISON_OP(name)
const luci_interpreter::RuntimeShape output_shape
void Comparison(const Shape &input1_shape, const float *input1_data, const Shape &input2_shape, const float *input2_data, const Shape &output_shape, bool *output_data)
bool LessFn(T lhs, T rhs)
bool EqualFn(T lhs, T rhs)
void BroadcastComparison4DSlowWithScaling(ComparisonParams &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
bool(*)(T, T) ComparisonFn
void NdArrayDescsForElementwiseBroadcast(const Shape &input0_shape, const Shape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
bool GreaterEqualFn(T lhs, T rhs)
void ComparisonWithScaling(ComparisonParams &params, const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data)
void ComparisonImpl(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data)
bool LessEqualFn(T lhs, T rhs)
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
int SubscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
void BroadcastComparison4DSlowImpl(const Shape &unextended_input1_shape, const T *input1_data, const Shape &unextended_input2_shape, const T *input2_data, const Shape &unextended_output_shape, bool *output_data)
void BroadcastComparison4DSlow(const Shape &input1_shape, const T *input1_data, const Shape &input2_shape, const T *input2_data, const Shape &output_shape, bool *output_data)
bool NotEqualFn(T lhs, T rhs)
bool GreaterFn(T lhs, T rhs)
int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
int32_t input2_multiplier
int32_t input1_multiplier