19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/binary_function.h>
27template <
typename T> T FloorDivFunc(T input1, T input2)
31 float operator()(
const float lhs,
const float rhs)
const {
return std::fmod(lhs, rhs); }
34 typename std::conditional<std::is_integral<T>::value, std::modulus<T>, FloatMod>
::type;
36 T trunc_mod = mod_func(input1, input2);
37 return (trunc_mod != 0) && ((input2 < 0) != (trunc_mod < 0)) ? (trunc_mod + input2) : trunc_mod;
60 switch (
x()->element_type())
62 case DataType::FLOAT32:
78 throw std::runtime_error(
"luci-intp FloorMod Unsupported type.");
82void FloorMod::evalFloat()
const
87 if (
x()->shape() !=
y()->shape())
89 tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
95 tflite::reference_ops::BinaryFunction<float, float, float>(
101template <
typename T>
void FloorMod::evalInteger()
const
117 if (
x()->shape() !=
y()->shape())
119 tflite::reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
void resize(const Shape &new_shape)
DataType element_type() const
FloorMod(const Tensor *x, const Tensor *y, Tensor *output)
void configure() override
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)