19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/softmax.h>
22#include "PALSoftmax.h"
41 if (
input()->element_type() == DataType::U8 ||
input()->element_type() == DataType::S8)
45 output()->zero_point() == std::numeric_limits<int8_t>::min());
46 tflite::SoftmaxParams op_params{};
47 op_params.table = _table;
48 luci_interpreter_pal::PopulateSoftmaxLookupTable(&op_params,
input()->scale(),
params().beta);
55 switch (
input()->element_type())
57 case DataType::FLOAT32:
61 evalQuantized<int8_t>();
64 evalQuantized<uint8_t>();
67 throw std::runtime_error(
"luci-intp Softmax Unsupported type.");
71void Softmax::evalFloat()
const
73 tflite::SoftmaxParams op_params{};
80template <
typename T>
void Softmax::evalQuantized()
const
82 tflite::SoftmaxParams op_params{};
83 op_params.table =
const_cast<float *
>(_table);
86 luci_interpreter_pal::InitializeParams(&op_params,
input()->
scale(),
params().beta);
const SoftmaxParams & params() const
void resize(const Shape &new_shape)
int32_t zero_point() const
void configure() override
const Tensor * input() const
Softmax(const Tensor *input, Tensor *output, const SoftmaxParams ¶ms)
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)