19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/binary_function.h>
/**
 * @brief Floor-modulo of input1 by input2 (the result takes the sign of the divisor).
 *
 * Despite its name, this computes the *remainder* used by FloorMod, not a quotient:
 * it is Python-style `input1 % input2`, i.e. input1 - floor(input1 / input2) * input2.
 * The truncated remainder produced by `%` (integral T, via std::modulus) or
 * std::fmod (floating T) is shifted by input2 whenever its sign disagrees with the
 * divisor's sign.
 *
 * @param input1 Dividend.
 * @param input2 Divisor. For integral T it must be non-zero: integer modulo by zero
 *               is undefined behavior. For floating T a zero divisor yields NaN,
 *               as std::fmod does.
 * @return input1 floor-mod input2, in T.
 */
template <typename T> T FloorDivFunc(T input1, T input2)
{
  // Floating-point truncated remainder. Computed in T (not hard-wired to float) so
  // wider floating types are not narrowed. The cast keeps the return type exact even
  // though this functor is only selected when T is not integral.
  struct FloatMod
  {
    T operator()(const T lhs, const T rhs) const { return static_cast<T>(std::fmod(lhs, rhs)); }
  };
  using ModFunc =
    typename std::conditional<std::is_integral<T>::value, std::modulus<T>, FloatMod>::type;

  // x mod -1 is always 0, but evaluating `min % -1` overflows for int32/int64
  // (undefined behavior; SIGFPE on x86), so short-circuit it for signed integers.
  if (std::is_integral<T>::value && std::is_signed<T>::value && input2 == static_cast<T>(-1))
  {
    return static_cast<T>(0);
  }

  const ModFunc mod_func{};
  const T trunc_mod = mod_func(input1, input2);
  // Truncated remainder has the dividend's sign; move it into the divisor's sign.
  // |trunc_mod| < |input2| with opposite signs, so the sum cannot overflow.
  return (trunc_mod != 0) && ((input2 < 0) != (trunc_mod < 0))
           ? static_cast<T>(trunc_mod + input2)
           : trunc_mod;
}
60 switch (
x()->element_type())
62 case DataType::FLOAT32:
66 evalInteger<int8_t>();
69 evalInteger<int16_t>();
72 evalInteger<int32_t>();
75 evalInteger<int64_t>();
78 throw std::runtime_error(
"luci-intp FloorMod Unsupported type.");
82void FloorMod::evalFloat()
const
84 const auto x_data = getTensorData<float>(
x());
85 const auto y_data = getTensorData<float>(
y());
87 if (
x()->shape() !=
y()->shape())
89 tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
91 getTensorData<float>(
output()), FloorDivFunc<float>);
95 tflite::reference_ops::BinaryFunction<float, float, float>(
97 getTensorData<float>(
output()), FloorDivFunc<float>);
101template <
typename T>
void FloorMod::evalInteger()
const
103 const auto x_data = getTensorData<T>(
x());
104 const auto y_data = getTensorData<T>(
y());
108 if (y_data_type == DataType::S8 || y_data_type == DataType::S16 || y_data_type == DataType::S32 ||
109 y_data_type == DataType::S64)
117 if (
x()->shape() !=
y()->shape())
119 tflite::reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
121 getTensorData<T>(
output()), FloorDivFunc<T>);
127 getTensorData<T>(
output()), FloorDivFunc<T>);
void resize(const Shape &new_shape)
DataType element_type() const
FloorMod(const Tensor *x, const Tensor *y, Tensor *output)
void configure() override
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)