18#include "kernels/Utils.h"
20#include "PALL2Normalize.h"
39 output()->element_type() == DataType::U8);
41 if (
output()->element_type() == DataType::U8)
52 switch (
output()->element_type())
54 case DataType::FLOAT32:
58 eval<uint8_t>(
input()->zero_point());
61 throw std::runtime_error(
"luci-intp L2Normalize Unsupported type.");
65template <
typename T>
void L2Normalize::eval(int32_t zero_point)
const
67 tflite::L2NormalizationParams op_params{};
68 op_params.input_zero_point = zero_point;
71 getTensorData<T>(
output()));
const L2NormParams & params() const
void resize(const Shape &new_shape)
L2Normalize(const Tensor *input, Tensor *output, const L2NormParams ¶ms)
void configure() override
const Tensor * input() const
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)