ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::SVDF Class Reference

#include <SVDF.h>

Collaboration diagram for luci_interpreter::kernels::SVDF:

Public Member Functions

 SVDF (const Tensor *input, const Tensor *weight_feature, const Tensor *weight_time, const Tensor *bias, const Tensor *input_activation_state, Tensor *output, Tensor *scratchpad_activation_state, Tensor *scratchpad_1, Tensor *scratchpad_2, Tensor *scratchpad_3, Tensor *scratchpad_4, Tensor *scratchpad_5, Tensor *scratchpad_6, const SVDFParams &params)
 
const Tensor * input () const
 
const Tensor * weight_feature () const
 
const Tensor * weight_time () const
 
const Tensor * bias () const
 
const Tensor * input_activation_state () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< SVDFParams >
const SVDFParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< SVDFParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const SVDFParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< SVDFParams >
const SVDFParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file SVDF.h.

Constructor & Destructor Documentation

◆ SVDF()

luci_interpreter::kernels::SVDF::SVDF ( const Tensor * input,
const Tensor * weight_feature,
const Tensor * weight_time,
const Tensor * bias,
const Tensor * input_activation_state,
Tensor * output,
Tensor * scratchpad_activation_state,
Tensor * scratchpad_1,
Tensor * scratchpad_2,
Tensor * scratchpad_3,
Tensor * scratchpad_4,
Tensor * scratchpad_5,
Tensor * scratchpad_6,
const SVDFParams & params
)

Definition at line 29 of file SVDF.cpp.

34 : KernelWithParams<SVDFParams>({input, weight_feature, weight_time, bias, input_activation_state},
35 {output, scratchpad_activation_state, scratchpad_1, scratchpad_2,
36 scratchpad_3, scratchpad_4, scratchpad_5, scratchpad_6},
37 params)
38{
39 // Do nothing
40}
const Tensor * input() const
Definition SVDF.h:37
const Tensor * input_activation_state() const
Definition SVDF.h:41
const Tensor * bias() const
Definition SVDF.h:40
const Tensor * weight_feature() const
Definition SVDF.h:38
const Tensor * weight_time() const
Definition SVDF.h:39
Tensor * output() const
Definition SVDF.h:43

References bias(), input(), input_activation_state(), weight_feature(), and weight_time().

Member Function Documentation

◆ bias()

const Tensor * luci_interpreter::kernels::SVDF::bias ( ) const
inline

Definition at line 40 of file SVDF.h.

40{ return _inputs[3]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and SVDF().

◆ configure()

void luci_interpreter::kernels::SVDF::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 42 of file SVDF.cpp.

43{
44 const Shape &input_shape = input()->shape();
45 const Shape &weight_features_shape = weight_feature()->shape();
46 const Shape &weight_time_shape = weight_time()->shape();
47
48 // Validate Input Tensor:
49 LUCI_INTERPRETER_CHECK(input()->element_type() == loco::DataType::FLOAT32 ||
50 input()->element_type() == loco::DataType::S8);
51 LUCI_INTERPRETER_CHECK(input_shape.num_dims() == 2);
52
53 // Validate inputs and output types
54 if (input()->element_type() == loco::DataType::S8)
55 {
56 LUCI_INTERPRETER_CHECK(weight_feature()->element_type() == loco::DataType::S8);
57 LUCI_INTERPRETER_CHECK(weight_time()->element_type() == loco::DataType::S16 ||
58 weight_time()->element_type() == loco::DataType::S8);
59 if (bias())
60 LUCI_INTERPRETER_CHECK(bias()->element_type() == loco::DataType::S32);
61
62 LUCI_INTERPRETER_CHECK(input_activation_state()->element_type() == loco::DataType::S16 ||
63 input_activation_state()->element_type() == loco::DataType::S8);
64 LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::S8);
65
66 // Note: now tflite support only ReLU activation for integer SVDF
67 LUCI_INTERPRETER_CHECK(params().activation == luci::FusedActFunc::RELU);
68 }
69 else if (weight_feature()->element_type() == loco::DataType::FLOAT32)
70 {
71 LUCI_INTERPRETER_CHECK(weight_feature()->element_type() == loco::DataType::FLOAT32);
72 LUCI_INTERPRETER_CHECK(weight_time()->element_type() == loco::DataType::FLOAT32);
73 LUCI_INTERPRETER_CHECK(input_activation_state()->element_type() == loco::DataType::FLOAT32);
74 if (bias())
75 LUCI_INTERPRETER_CHECK(bias()->element_type() == loco::DataType::FLOAT32);
76 LUCI_INTERPRETER_CHECK(output()->element_type() == loco::DataType::FLOAT32);
77 }
78 else if ((weight_feature()->element_type() == loco::DataType::U8 ||
79 weight_feature()->element_type() == loco::DataType::S8) &&
80 input()->element_type() == loco::DataType::FLOAT32)
81 {
82 // TODO:: support hybrid SVDF op
83 throw std::runtime_error("Hybrid type is not currently supported");
84 }
85 else
86 {
87 throw std::runtime_error("luci-intp SVDF Unsupported type.");
88 }
89
90 // Check all the parameters of tensor match within themselves and match the
91 // input configuration.
92 const int rank = params().svdf_rank;
93 const int batch_size = input_shape.dim(0);
94 const int num_filters = weight_features_shape.dim(0);
95 LUCI_INTERPRETER_CHECK(rank != 0);
96 LUCI_INTERPRETER_CHECK(num_filters % rank == 0);
97
98 const int num_units = num_filters / rank;
99 const int memory_size = weight_time_shape.dim(1);
100
101 // Validate Weight_Feature Input Tensor:
102 LUCI_INTERPRETER_CHECK(weight_features_shape.num_dims() == 2);
103 LUCI_INTERPRETER_CHECK(weight_features_shape.dim(1) == input_shape.dim(1));
104
105 // Validate Weight_Time Input Tensor:
106 LUCI_INTERPRETER_CHECK(weight_time_shape.num_dims() == 2);
107 LUCI_INTERPRETER_CHECK(weight_time_shape.dim(0) == num_filters);
108
109 // Validate Bias
110 if (bias())
111 LUCI_INTERPRETER_CHECK(bias()->shape().dim(0) == num_units);
112
113 // Validate Input Activation State
114 LUCI_INTERPRETER_CHECK(input_activation_state()->shape().num_dims() == 2);
115 LUCI_INTERPRETER_CHECK(input_activation_state()->shape().dim(0) == batch_size);
116 LUCI_INTERPRETER_CHECK(input_activation_state()->shape().dim(1) == memory_size * num_filters);
117
118 // Resize scratchpad_state to input_activation_state
119 auto scratchpad_activation_state = getOutputTensors()[1];
120 scratchpad_activation_state->resize({batch_size, memory_size * num_filters});
121
122 // Resize output tensor
123 output()->resize({batch_size, num_units});
124
125 luci_interpreter_pal::SetupScratchpadTensor(
126 input()->element_type(), weight_feature()->element_type(), getOutputTensors()[2],
127 getOutputTensors()[3], getOutputTensors()[4], getOutputTensors()[5], getOutputTensors()[6],
128 getOutputTensors()[7], input_shape, weight_time_shape, batch_size, num_filters, num_units);
129}
const std::vector< Tensor * > & getOutputTensors() const
Definition Kernel.h:40
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
luci_interpreter::Shape
Definition Shape.h:28

References bias(), luci_interpreter::Shape::dim(), luci_interpreter::Kernel::getOutputTensors(), input(), input_activation_state(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), luci_interpreter::KernelWithParams< SVDFParams >::params(), luci::RELU, luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::shape(), luci_interpreter::SVDFParams::svdf_rank, weight_feature(), and weight_time().

◆ execute()

void luci_interpreter::kernels::SVDF::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 131 of file SVDF.cpp.

132{
133 switch (weight_feature()->element_type())
134 {
135 case loco::DataType::FLOAT32:
136 evalFloat();
137 break;
138 case loco::DataType::S8:
139 {
140 if (input()->element_type() == loco::DataType::S8)
141 evalInteger();
142 else
143 // TODO:: support hybrid SVDF op
144 throw std::runtime_error("Hybrid type is not currently supported");
145 break;
146 }
147 default:
148 throw std::runtime_error("Unsupported type");
149 }
150}

References input(), and weight_feature().

◆ input()

const Tensor * luci_interpreter::kernels::SVDF::input ( ) const
inline

Definition at line 37 of file SVDF.h.

37{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and SVDF().

◆ input_activation_state()

const Tensor * luci_interpreter::kernels::SVDF::input_activation_state ( ) const
inline

Definition at line 41 of file SVDF.h.

41{ return _inputs[4]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and SVDF().

◆ output()

Tensor * luci_interpreter::kernels::SVDF::output ( ) const
inline

Definition at line 43 of file SVDF.h.

43{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().

◆ weight_feature()

const Tensor * luci_interpreter::kernels::SVDF::weight_feature ( ) const
inline

Definition at line 38 of file SVDF.h.

38{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and SVDF().

◆ weight_time()

const Tensor * luci_interpreter::kernels::SVDF::weight_time ( ) const
inline

Definition at line 39 of file SVDF.h.

39{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and SVDF().


The documentation for this class was generated from the following files: