ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Softmax Class Reference

#include <Softmax.h>

Collaboration diagram for luci_interpreter::kernels::Softmax:

Public Member Functions

 Softmax (const Tensor *input, Tensor *output, const SoftmaxParams &params)
 
const Tensor * input () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< SoftmaxParams >
const SoftmaxParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< SoftmaxParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const SoftmaxParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< SoftmaxParams >
const SoftmaxParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file Softmax.h.

Constructor & Destructor Documentation

◆ Softmax()

luci_interpreter::kernels::Softmax::Softmax ( const Tensor * input,
Tensor * output,
const SoftmaxParams & params 
)

Definition at line 32 of file Softmax.cpp.

33 : KernelWithParams<SoftmaxParams>({input}, {output}, params)
34{
35}
const SoftmaxParams & params() const
Definition Kernel.h:67
const Tensor * input() const
Definition Softmax.h:33

References input().

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::Softmax::configure ( )
override, virtual

Implements luci_interpreter::Kernel.

Definition at line 37 of file Softmax.cpp.

38{
39 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
40 LUCI_INTERPRETER_CHECK(input()->shape().num_dims() >= 1);
41 if (input()->element_type() == DataType::U8 || input()->element_type() == DataType::S8)
42 {
43 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::S8 || output()->zero_point() == 0);
44 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::U8 ||
45 output()->zero_point() == std::numeric_limits<int8_t>::min());
46 tflite::SoftmaxParams op_params{};
47 op_params.table = _table;
48 luci_interpreter_pal::PopulateSoftmaxLookupTable(&op_params, input()->scale(), params().beta);
49 }
50 output()->resize(input()->shape());
51}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36

References input(), LUCI_INTERPRETER_CHECK, output(), luci_interpreter::KernelWithParams< SoftmaxParams >::params(), and luci_interpreter::Tensor::resize().

◆ execute()

void luci_interpreter::kernels::Softmax::execute ( ) const
override, virtual

Implements luci_interpreter::Kernel.

Definition at line 53 of file Softmax.cpp.

54{
55 switch (input()->element_type())
56 {
57 case DataType::FLOAT32:
58 evalFloat();
59 break;
60 case DataType::S8:
61 evalQuantized<int8_t>();
62 break;
63 case DataType::U8:
64 evalQuantized<uint8_t>();
65 break;
66 default:
67 throw std::runtime_error("luci-intp Softmax Unsupported type.");
68 }
69}

References input().

◆ input()

const Tensor * luci_interpreter::kernels::Softmax::input ( ) const
inline

Definition at line 33 of file Softmax.h.

33{ return _inputs[0]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Softmax().

◆ output()

Tensor * luci_interpreter::kernels::Softmax::output ( ) const
inline

Definition at line 34 of file Softmax.h.

34{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().


The documentation for this class was generated from the following files: