20#include "kernels/Utils.h"
22#include "kernels/BinaryOpCommon.h"
// Dispatches the computation on the element type of input1().
// Only a FLOAT32 case is visible here; it runs evalSquaredDifference<float>().
// NOTE(review): the case/default labels between these lines are not visible in
// this view — presumably every other element type reaches the
// std::runtime_error below; confirm against the full switch.
43 switch (
input1()->element_type())
45 case DataType::FLOAT32:
46 evalSquaredDifference<float>();
49 throw std::runtime_error(
"luci-intp SquaredDifference Unsupported type.");
// Evaluates the squared difference (x - y) * (x - y) for element type T.
// NOTE(review): the code that applies this per-element operation to the tensors
// is not visible in this view — presumably the lambda below is passed as `op`
// to BinaryOpBroadcastSlow together with input1()/input2()/output(); confirm.
53template <
typename T>
inline void SquaredDifference::evalSquaredDifference()
const
// Compute the difference once and square it by multiplication.
58 const T difference = x - y;
59 return difference * difference;
void resize(const Shape &new_shape)
const Tensor * input1() const
void configure() override
const Tensor * input2() const
void execute() const override
SquaredDifference(const Tensor *input1, const Tensor *input2, Tensor *output)
#define LUCI_INTERPRETER_CHECK(cond)
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void BinaryOpBroadcastSlow(const tflite::RuntimeShape &unextended_input1_shape, const T *input1_data, const tflite::RuntimeShape &unextended_input2_shape, const T *input2_data, const tflite::RuntimeShape &unextended_output_shape, T *output_data, Op op)