19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/log_softmax.h>
23#include "PALLogSoftmax.h"
35 if (
input()->element_type() == DataType::U8)
40 tflite::SoftmaxParams params{};
42 params.table = _table;
44 luci_interpreter_pal::PopulateSoftmaxLookupTable(¶ms,
input()->scale(), params.beta);
51 switch (
input()->element_type())
53 case DataType::FLOAT32:
60 throw std::runtime_error(
"luci-intp LogSoftmax Unsupported type.");
64void LogSoftmax::evalFloat()
const
66 tflite::SoftmaxParams params{};
71void LogSoftmax::evalQuantized()
const
76 uint8_t *output_data = getTensorData<uint8_t>(
output());
77 const uint8_t *input_data = getTensorData<uint8_t>(
input());
78 const float beta = 1.0;
80 tflite::SoftmaxParams params{};
82 params.table =
const_cast<float *
>(_table);
86 luci_interpreter_pal::InitializeParams(¶ms, input_scale, beta);
87 luci_interpreter_pal::LogSoftmax(params, input_scale, input_shape, input_data,
output_shape,
void resize(const Shape &new_shape)
int32_t zero_point() const
LogSoftmax(const Tensor *input, Tensor *output)
void configure() override
void execute() const override
const Tensor * input() const
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)