19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/softmax.h>
22#include "PALSoftmax.h"
41 if (
input()->element_type() == DataType::U8 ||
input()->element_type() == DataType::S8)
45 output()->zero_point() == std::numeric_limits<int8_t>::min());
55 switch (
input()->element_type())
57 case DataType::FLOAT32:
67 throw std::runtime_error(
"luci-intp Softmax Unsupported type.");
71void Softmax::evalFloat()
const
80template <
typename T>
void Softmax::evalQuantized()
const
83 op_params.table =
const_cast<float *
>(_table);
const SoftmaxParams & params() const
void resize(const Shape &new_shape)
int32_t zero_point() const
void configure() override
const Tensor * input() const
Softmax(const Tensor *input, Tensor *output, const SoftmaxParams ¶ms)
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)