18#ifndef LUCI_INTERPRETER_PAL_BINARYOPCOMMON_H
19#define LUCI_INTERPRETER_PAL_BINARYOPCOMMON_H
28template <typename T, std::enable_if_t<std::is_floating_point<T>::value,
bool> =
true>
33 return std::floor(
static_cast<double>(lhs) /
static_cast<double>(rhs));
36template <typename T, std::enable_if_t<std::is_floating_point<T>::value,
bool> =
true>
41 T trunc_mod = std::fmod(lhs, rhs);
42 return (trunc_mod != 0) && ((rhs < 0) != (trunc_mod < 0)) ? (trunc_mod + rhs) : trunc_mod;
47 T
operator()(T lhs, T rhs) {
return std::max(lhs, rhs); }
51 T
operator()(T lhs, T rhs) {
return std::min(lhs, rhs); }
55template <
typename T,
typename Fn>
56inline void BinaryOp(
const int flat_size,
const T *input1_data,
const T *input2_data,
60 for (
int i = 0; i < flat_size; ++i)
62 output_data[i] = func(input1_data[i], input2_data[i]);
66template <
typename T,
typename Fn>
68 const float *input1_data,
70 const float *input2_data,
94 for (
int b = 0; b < extended_output_shape.
dims(0); ++b)
96 for (
int y = 0; y < extended_output_shape.
dims(1); ++y)
98 for (
int x = 0; x < extended_output_shape.
dims(2); ++x)
100 for (
int c = 0; c < extended_output_shape.
dims(3); ++c)
102 const int output_data_offset =
103 ((b * extended_output_shape.
dims(1) + y) * extended_output_shape.
dims(2) + x) *
104 extended_output_shape.
dims(3) +
int32_t dims(int i) const
static RuntimeShape extendedShape(int new_shape_size, const RuntimeShape &shape)
const luci_interpreter::RuntimeShape output_shape
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
void BroadcastBinaryOp4DSlow(const luci_interpreter::RuntimeShape &input1_shape, const float *input1_data, const luci_interpreter::RuntimeShape &input2_shape, const float *input2_data, const luci_interpreter::RuntimeShape &output_shape, float *output_data)
void NdArrayDescsForElementwiseBroadcast(const luci_interpreter::RuntimeShape &input0_shape, const luci_interpreter::RuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
void BinaryOp(const int flat_size, const T *input1_data, const T *input2_data, T *output_data)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)