18#ifndef LUCI_INTERPRETER_PAL_COMPARISONS_H
19#define LUCI_INTERPRETER_PAL_COMPARISONS_H
30struct BroadcastComparison4DSlowCommon
37inline BroadcastComparison4DSlowCommon
/// @brief Strict "less than" predicate.
/// @return true iff `lhs < rhs` under T's `operator<`.
/// Its shape matches the `bool F(T, T)` comparator parameter used by the
/// comparison helpers declared in this header.
template <class T> inline auto LessFn(T lhs, T rhs) -> bool
{
  return lhs < rhs;
}
/// @brief "Less than or equal" predicate.
/// @return true iff `lhs <= rhs` under T's `operator<=`.
/// Its shape matches the `bool F(T, T)` comparator parameter used by the
/// comparison helpers declared in this header.
template <class T> inline auto LessEqualFn(T lhs, T rhs) -> bool
{
  return lhs <= rhs;
}
/// @brief Equality predicate.
/// @return true iff `lhs == rhs` under T's `operator==`.
/// Its shape matches the `bool F(T, T)` comparator parameter used by the
/// comparison helpers declared in this header.
template <class T> inline auto EqualFn(T lhs, T rhs) -> bool
{
  return lhs == rhs;
}
/// @brief Strict "greater than" predicate.
/// @return true iff `lhs > rhs` under T's `operator>`.
/// Its shape matches the `bool F(T, T)` comparator parameter used by the
/// comparison helpers declared in this header.
template <class T> inline auto GreaterFn(T lhs, T rhs) -> bool
{
  return lhs > rhs;
}
/// @brief "Greater than or equal" predicate.
/// @return true iff `lhs >= rhs` under T's `operator>=`.
/// Its shape matches the `bool F(T, T)` comparator parameter used by the
/// comparison helpers declared in this header.
template <class T> inline auto GreaterEqualFn(T lhs, T rhs) -> bool
{
  return lhs >= rhs;
}
/// @brief Inequality predicate.
/// @return true iff `lhs != rhs` under T's `operator!=`.
/// Its shape matches the `bool F(T, T)` comparator parameter used by the
/// comparison helpers declared in this header.
template <class T> inline auto NotEqualFn(T lhs, T rhs) -> bool
{
  return lhs != rhs;
}
60 bool *output_data,
bool F(T, T))
62 for (int64_t i = 0; i < flat_size; ++i)
64 output_data[i] = F(input1_data[i], input2_data[i]);
73 bool *output_data,
bool F(T, T))
75 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
76 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
86 for (
int b = 0; b < dims.output_shape.dims(0); ++b)
88 for (
int y = 0; y < dims.output_shape.dims(1); ++y)
90 for (
int x = 0; x < dims.output_shape.dims(2); ++x)
92 for (
int c = 0; c < dims.output_shape.dims(3); ++c)
94 const int32_t input1_val =
96 const int32_t input2_val =
98 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
99 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
101 shifted_input1_val, input1_multiplier, input1_shift);
103 shifted_input2_val, input2_multiplier, input2_shift);
105 const int output_data_offset =
106 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
107 dims.output_shape.dims(3) +
109 output_data[output_data_offset] = F(scaled_input1_val, scaled_input2_val);
118 const T *input1_data,
const T *input2_data,
bool *output_data,
129 for (int64_t i = 0; i < flat_size; ++i)
131 const int32_t input1_val = input1_offset + input1_data[i];
132 const int32_t input2_val = input2_offset + input2_data[i];
133 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
134 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
136 shifted_input1_val, input1_multiplier, input1_shift);
138 shifted_input2_val, input2_multiplier, input2_shift);
139 output_data[i] = F(scaled_input1_val, scaled_input2_val);
148 bool *output_data,
bool F(T, T))
150 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
151 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
153 for (
int b = 0; b < dims.output_shape.dims(0); ++b)
155 for (
int y = 0; y < dims.output_shape.dims(1); ++y)
157 for (
int x = 0; x < dims.output_shape.dims(2); ++x)
159 for (
int c = 0; c < dims.output_shape.dims(3); ++c)
161 const int output_data_offset =
162 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
163 dims.output_shape.dims(3) +
165 output_data[output_data_offset] =
static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
bool LessFn(T lhs, T rhs)
void BroadcastComparison4DSlowWithScaling(const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape, const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape, const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape, bool *output_data, bool F(T, T))
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
bool EqualFn(T lhs, T rhs)
void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(T, T))
int32_t multiplyByQuantizedMultiplierSmallerThanOneExp(int32_t x, int32_t quantized_multiplier, int left_shift)
bool LessEqualFn(T lhs, T rhs)
void ComparisonWithScaling(const ComparisonParams &op_params, const int64_t flat_size, const T *input1_data, const T *input2_data, bool *output_data, bool F(T, T))
void NdArrayDescsForElementwiseBroadcast(const luci_interpreter::RuntimeShape &input0_shape, const luci_interpreter::RuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void BroadcastComparison4DSlowNoScaling(const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape, const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape, const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape, bool *output_data, bool F(T, T))
bool NotEqualFn(T lhs, T rhs)
bool GreaterEqualFn(T lhs, T rhs)
bool GreaterFn(T lhs, T rhs)
int32_t input2_multiplier
int32_t input1_multiplier