18#ifndef ONERT_MICRO_EXECUTE_PAL_BINARYOP_COMMON_H
19#define ONERT_MICRO_EXECUTE_PAL_BINARYOP_COMMON_H
34template <typename T, std::enable_if_t<std::is_floating_point<T>::value,
bool> =
true>
39 return std::floor(
static_cast<double>(lhs) /
static_cast<double>(rhs));
42template <typename T, std::enable_if_t<std::is_floating_point<T>::value,
bool> =
true>
47 T trunc_mod = std::fmod(lhs, rhs);
48 return (trunc_mod != 0) && ((rhs < 0) != (trunc_mod < 0)) ? (trunc_mod + rhs) : trunc_mod;
53 T
operator()(T lhs, T rhs) {
return std::max(lhs, rhs); }
57 T
operator()(T lhs, T rhs) {
return std::min(lhs, rhs); }
61template <
typename T,
typename Fn>
62inline OMStatus BinaryOp(
const int flat_size,
const T *input1_data,
const T *input2_data,
66 for (
int i = 0; i < flat_size; ++i)
68 output_data[i] = func(input1_data[i], input2_data[i]);
73template <
typename T,
typename Fn>
99 for (
int b = 0; b < extended_output_shape.
dims(0); ++b)
101 for (
int y = 0; y < extended_output_shape.
dims(1); ++y)
103 for (
int x = 0; x < extended_output_shape.
dims(2); ++x)
105 for (
int c = 0; c < extended_output_shape.
dims(3); ++c)
107 const int output_data_offset =
108 ((b * extended_output_shape.
dims(1) + y) * extended_output_shape.
dims(2) + x) *
109 extended_output_shape.
dims(3) +
static OMRuntimeShape extendedShape(int new_shape_size, const OMRuntimeShape &shape)
int32_t dims(int i) const
const luci_interpreter::RuntimeShape output_shape
OMStatus BinaryOp(const int flat_size, const T *input1_data, const T *input2_data, T *output_data)
void NdArrayDescsForElementwiseBroadcast(const core::OMRuntimeShape &input0_shape, const core::OMRuntimeShape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
OMStatus BroadcastBinaryOp4DSlow(const core::OMRuntimeShape &input1_shape, const float *input1_data, const core::OMRuntimeShape &input2_shape, const float *input2_data, const core::OMRuntimeShape &output_shape, float *output_data)
int subscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)
T operator()(T lhs, T rhs)