18#ifndef ONERT_MICRO_EXECUTE_PAL_COMPARISONS_H
19#define ONERT_MICRO_EXECUTE_PAL_COMPARISONS_H
36struct BroadcastComparison4DSlowCommon
38 const core::OMRuntimeShape output_shape;
43inline BroadcastComparison4DSlowCommon
44BroadcastComparison4DSlowPreprocess(
const core::OMRuntimeShape &unextended_input1_shape,
45 const core::OMRuntimeShape &unextended_input2_shape,
46 const core::OMRuntimeShape &unextended_output_shape)
/// Strict "less than" predicate used by the elementwise comparison kernels.
/// @param lhs left operand
/// @param rhs right operand
/// @return true when lhs < rhs
template <class T> inline bool LessFn(T lhs, T rhs)
{
  const bool is_less = lhs < rhs;
  return is_less;
}
/// "Less than or equal" predicate used by the elementwise comparison kernels.
/// @param lhs left operand
/// @param rhs right operand
/// @return true when lhs <= rhs
template <class T> inline bool LessEqualFn(T lhs, T rhs)
{
  const bool is_less_or_equal = lhs <= rhs;
  return is_less_or_equal;
}
/// Equality predicate used by the elementwise comparison kernels.
/// @param lhs left operand
/// @param rhs right operand
/// @return true when lhs == rhs
template <class T> inline bool EqualFn(T lhs, T rhs)
{
  const bool is_equal = lhs == rhs;
  return is_equal;
}
/// Strict "greater than" predicate used by the elementwise comparison kernels.
/// @param lhs left operand
/// @param rhs right operand
/// @return true when lhs > rhs
template <class T> inline bool GreaterFn(T lhs, T rhs)
{
  const bool is_greater = lhs > rhs;
  return is_greater;
}
/// "Greater than or equal" predicate used by the elementwise comparison kernels.
/// @param lhs left operand
/// @param rhs right operand
/// @return true when lhs >= rhs
template <class T> inline bool GreaterEqualFn(T lhs, T rhs)
{
  const bool is_greater_or_equal = lhs >= rhs;
  return is_greater_or_equal;
}
/// Inequality predicate used by the elementwise comparison kernels.
/// @param lhs left operand
/// @param rhs right operand
/// @return true when lhs != rhs
template <class T> inline bool NotEqualFn(T lhs, T rhs)
{
  const bool is_not_equal = lhs != rhs;
  return is_not_equal;
}
66 bool *output_data,
bool F(T, T))
68 for (int64_t i = 0; i < flat_size; ++i)
70 output_data[i] = F(input1_data[i], input2_data[i]);
74template <
typename T,
typename AccType>
78 const core::OMRuntimeShape &unextended_output_shape,
bool *output_data,
bool F(AccType, AccType))
80 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
81 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
91 for (
int b = 0; b < dims.output_shape.dims(0); ++b)
93 for (
int y = 0; y < dims.output_shape.dims(1); ++y)
95 for (
int x = 0; x < dims.output_shape.dims(2); ++x)
97 for (
int c = 0; c < dims.output_shape.dims(3); ++c)
99 const int32_t input1_val =
101 const int32_t input2_val =
103 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
104 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
106 shifted_input1_val, input1_multiplier, input1_shift);
108 shifted_input2_val, input2_multiplier, input2_shift);
110 const int output_data_offset =
111 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
112 dims.output_shape.dims(3) +
114 output_data[output_data_offset] = F(scaled_input1_val, scaled_input2_val);
121template <
typename T,
typename AccType>
123 const T *input1_data,
const T *input2_data,
bool *output_data,
124 bool F(AccType, AccType))
134 for (int64_t i = 0; i < flat_size; ++i)
136 const int32_t input1_val = input1_offset + input1_data[i];
137 const int32_t input2_val = input2_offset + input2_data[i];
138 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
139 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
141 shifted_input1_val, input1_multiplier, input1_shift);
143 shifted_input2_val, input2_multiplier, input2_shift);
144 output_data[i] = F(scaled_input1_val, scaled_input2_val);
151 const T *input1_data,
const core::OMRuntimeShape &unextended_input2_shape,
const T *input2_data,
154 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
155 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
157 for (
int b = 0; b < dims.output_shape.dims(0); ++b)
159 for (
int y = 0; y < dims.output_shape.dims(1); ++y)
161 for (
int x = 0; x < dims.output_shape.dims(2); ++x)
163 for (
int c = 0; c < dims.output_shape.dims(3); ++c)
165 const int output_data_offset =
166 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
167 dims.output_shape.dims(3) +
169 output_data[output_data_offset] =
static OMRuntimeShape extendedShape(int new_shape_size, const OMRuntimeShape &shape)
void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(T, T))
bool NotEqualFn(T lhs, T rhs)
bool LessEqualFn(T lhs, T rhs)
bool GreaterFn(T lhs, T rhs)
bool EqualFn(T lhs, T rhs)
int32_t multiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
void NdArrayDescsForElementwiseBroadcast(const core::OMRuntimeShape &input0_shape, const core::OMRuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void ComparisonWithScaling(const core::ComparisonParams &op_params, const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(AccType, AccType))
void BroadcastComparison4DSlowWithScaling(const core::ComparisonParams &op_params, const core::OMRuntimeShape &unextended_input1_shape, const T *input1_data, const core::OMRuntimeShape &unextended_input2_shape, const T *input2_data, const core::OMRuntimeShape &unextended_output_shape, bool *output_data, bool F(AccType, AccType))
bool GreaterEqualFn(T lhs, T rhs)
void BroadcastComparison4DSlowNoScaling(const core::ComparisonParams &op_params, const core::OMRuntimeShape &unextended_input1_shape, const T *input1_data, const core::OMRuntimeShape &unextended_input2_shape, const T *input2_data, const core::OMRuntimeShape &unextended_output_shape, bool *output_data, bool F(T, T))
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
bool LessFn(T lhs, T rhs)
int32_t input2_multiplier
int32_t input1_multiplier