18#ifndef LUCI_INTERPRETER_PAL_ARITHMETICOPCOMMON_H
19#define LUCI_INTERPRETER_PAL_ARITHMETICOPCOMMON_H
28template <
typename T>
struct AddFn
32template <
typename T>
struct SubFn
36template <
typename T>
struct MulFn
40template <
typename T>
struct DivFn
46template <
typename T,
typename Fn>
48 const T *input2_data, T *output_data)
50 T activation_min, activation_max;
54 for (
int i = 0; i < flat_size; ++i)
56 std::min(std::max(func(input1_data[i], input2_data[i]), activation_min), activation_max);
59template <
typename T,
typename Fn>
61 const T *input_data,
const T scalar_value, T *output_data)
63 T activation_min, activation_max;
66 for (
int i = 0; i < flat_size; ++i)
68 std::min(std::max(func(input_data[i], scalar_value), activation_min), activation_max);
71template <
typename T,
typename Fn>
83 T activation_min, activation_max;
98 for (
int b = 0; b < extended_output_shape.
dims(0); ++b)
100 for (
int y = 0; y < extended_output_shape.
dims(1); ++y)
102 for (
int x = 0; x < extended_output_shape.
dims(2); ++x)
104 for (
int c = 0; c < extended_output_shape.
dims(3); ++c)
106 const int output_data_offset =
107 ((b * extended_output_shape.
dims(1) + y) * extended_output_shape.
dims(2) + x) *
108 extended_output_shape.
dims(3) +
111 output_data[output_data_offset] =
int32_t dims(int i) const
static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
const luci_interpreter::RuntimeShape output_shape
void ArithmeticOp(const ArithmeticParams ¶ms, const int flat_size, const T *input1_data, const T *input2_data, T *output_data)
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
void getActivationParams(const P ¶ms, int32_t *min, int32_t *max)
void NdArrayDescsForElementwiseBroadcast(const luci_interpreter::RuntimeShape &input0_shape, const luci_interpreter::RuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void ArithmeticOpScalar(const ArithmeticParams ¶ms, const int flat_size, const T *input_data, const T scalar_value, T *output_data)
void BroadcastArithmeticOp4DSlow(const ArithmeticParams ¶ms, const luci_interpreter::RuntimeShape &input1_shape, const T *input1_data, const luci_interpreter::RuntimeShape &input2_shape, const T *input2_data, const luci_interpreter::RuntimeShape &output_shape, T *output_data)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)