ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Conv2D Class Reference

#include <Conv2D.h>

Collaboration diagram for luci_interpreter::kernels::Conv2D:

Public Member Functions

 Conv2D (const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output, Tensor *scratchpad, const Conv2DParams &params)
 
const Tensor * input () const
 
const Tensor * filter () const
 
const Tensor * bias () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< Conv2DParams >
const Conv2DParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< Conv2DParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const Conv2DParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< Conv2DParams >
const Conv2DParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 30 of file Conv2D.h.

Constructor & Destructor Documentation

◆ Conv2D()

luci_interpreter::kernels::Conv2D::Conv2D ( const Tensor * input,
const Tensor * filter,
const Tensor * bias,
Tensor * output,
Tensor * scratchpad,
const Conv2DParams & params 
)

Definition at line 32 of file Conv2D.cpp.

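32Conv2D::Conv2D(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output,
33 Tensor *scratchpad, const Conv2DParams &params)
34 : KernelWithParams<Conv2DParams>({input, filter, bias}, {output, scratchpad}, params)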
35{
36}
const Conv2DParams & params() const
Definition Kernel.h:67
const Tensor * input() const
Definition Conv2D.h:36
const Tensor * bias() const
Definition Conv2D.h:38
const Tensor * filter() const
Definition Conv2D.h:37

References bias(), filter(), and input().

Member Function Documentation

◆ bias()

const Tensor * luci_interpreter::kernels::Conv2D::bias ( ) const
inline

Definition at line 38 of file Conv2D.h.

38{ return _inputs[2]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and Conv2D().

◆ configure()

void luci_interpreter::kernels::Conv2D::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 38 of file Conv2D.cpp.

39{
40 // TensorFlow Lite (as of v2.2.0) supports the following combinations of types:
41 // | input filter bias output |
42 // ----+---------------------------+
43 // (1) | float float float float |
44 // (2) | float int8 float float | hybrid
45 // (3) | uint8 uint8 int32 uint8 | quantized
46 // (4) | int8 int8 int32 int8 | quantized per channel
47 //
48 // We only support (1), (3) and (4) for now, and additionally the following:
49 // | input filter bias output |
50 // ----+---------------------------+
51 // (5) | int16 int16 int64 int16 |
52 //
53 if (input()->element_type() == DataType::FLOAT32 && filter()->element_type() == DataType::FLOAT32)
54 {
55 LUCI_INTERPRETER_CHECK(bias() == nullptr || bias()->element_type() == DataType::FLOAT32);
56 }
57 else if (input()->element_type() == DataType::U8 && filter()->element_type() == DataType::U8)
58 {
59 LUCI_INTERPRETER_CHECK(bias() == nullptr || bias()->element_type() == DataType::S32);
60 }
61 else if (input()->element_type() == DataType::S8 && filter()->element_type() == DataType::S8)
62 {
63 LUCI_INTERPRETER_CHECK(bias() == nullptr || bias()->element_type() == DataType::S32);
64 LUCI_INTERPRETER_CHECK(filter()->shape().num_dims() == 4);
65 LUCI_INTERPRETER_CHECK(filter()->scales().size() ==
66 static_cast<size_t>(filter()->shape().dim(0)));
67 for (auto zerop : filter()->zero_points())
68 {
69 LUCI_INTERPRETER_CHECK(zerop == 0);
70 }
71 }
72 else if (input()->element_type() == DataType::S16 && filter()->element_type() == DataType::S16)
73 {
74 LUCI_INTERPRETER_CHECK(bias() == nullptr || bias()->element_type() == DataType::S64);
75 }
76 else
77 {
78 throw std::runtime_error("luci-intp Conv2D(1) Unsupported type.");
79 }
80 LUCI_INTERPRETER_CHECK(output()->element_type() == input()->element_type());
81
82 const Shape &input_shape = input()->shape();
83 const Shape &filter_shape = filter()->shape();
84 LUCI_INTERPRETER_CHECK(input_shape.num_dims() == 4 && filter_shape.num_dims() == 4);
85
86 const int32_t batches = input_shape.dim(0);
87 const int32_t input_height = input_shape.dim(1);
88 const int32_t input_width = input_shape.dim(2);
89 const int32_t output_depth = filter_shape.dim(0);
90 const int32_t filter_height = filter_shape.dim(1);
91 const int32_t filter_width = filter_shape.dim(2);
92 LUCI_INTERPRETER_CHECK(filter_shape.dim(3) == input_shape.dim(3));
93
94 LUCI_INTERPRETER_CHECK(bias() == nullptr || (bias()->shape().num_dims() == 1 &&
95 bias()->shape().dim(0) == output_depth));
96
97 const int32_t output_height =
98 computeOutputSize(_params.padding, input_height, filter_height, _params.stride_height,
99 _params.dilation_height_factor);
100 const int32_t output_width =
101 computeOutputSize(_params.padding, input_width, filter_width, _params.stride_width,
102 _params.dilation_width_factor);
103
104 _padding_height = computePadding(_params.stride_height, _params.dilation_height_factor,
105 input_height, filter_height, output_height);
106 _padding_width = computePadding(_params.stride_width, _params.dilation_width_factor, input_width,
107 filter_width, output_width);
108
109 output()->resize({batches, output_height, output_width, output_depth});
110
111 // Allocate tensor for scratchpad, if needed.
112 tflite::ConvParams params{};
113 params.padding_values.height = _padding_height;
114 params.padding_values.width = _padding_width;
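115 params.stride_height = _params.stride_height;
116 params.stride_width = _params.stride_width;
117 params.dilation_height_factor = _params.dilation_height_factor;
118 params.dilation_width_factor = _params.dilation_width_factor;
119 auto scratchpad = getOutputTensors()[1];
120 luci_interpreter_pal::SetupScratchpadTensor(scratchpad, input()->element_type(), params,
121 getTensorShape(input()), getTensorShape(filter()),
122 getTensorShape(output()));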
123
124 switch (_params.activation)
125 {
126 case Activation::NONE:
127 case Activation::RELU:
128 case Activation::RELU6:
129 case Activation::RELU_N1_TO_1:
130 break;
131 default:
132 throw std::runtime_error("Unsupported fused activation");
133 }
134}
const std::vector< Tensor * > & getOutputTensors() const
Definition Kernel.h:40
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
int32_t computePadding(int32_t stride, int32_t dilation_rate, int32_t in_size, int32_t filter_size, int32_t out_size)
Definition Utils.h:41
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
int32_t computeOutputSize(Padding padding, int32_t image_size, int32_t filter_size, int32_t stride, int32_t dilation_rate=1)
Definition Utils.h:59
int32_t size[5]
Definition Slice.cpp:35
Definition Shape.h:28

References luci_interpreter::KernelWithParams< Conv2DParams >::_params, luci_interpreter::Conv2DParams::activation, bias(), luci_interpreter::kernels::computeOutputSize(), luci_interpreter::kernels::computePadding(), luci_interpreter::Conv2DParams::dilation_height_factor, luci_interpreter::Conv2DParams::dilation_width_factor, luci_interpreter::Shape::dim(), filter(), luci_interpreter::Kernel::getOutputTensors(), luci_interpreter::kernels::getTensorShape(), input(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), output(), luci_interpreter::Conv2DParams::padding, luci_interpreter::KernelWithParams< Conv2DParams >::params(), luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::shape(), size, luci_interpreter::Conv2DParams::stride_height, luci_interpreter::Conv2DParams::stride_width, and luci_interpreter::Tensor::zero_points().

◆ execute()

void luci_interpreter::kernels::Conv2D::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 136 of file Conv2D.cpp.

137{
138 switch (input()->element_type())
139 {
140 case DataType::FLOAT32:
141 if (filter()->element_type() == DataType::FLOAT32)
142 {
143 evalFloat();
144 break;
145 }
146 throw std::runtime_error("luci-intp Conv2D(2) Unsupported type.");
147 case DataType::U8:
148 if (filter()->scales().size() == 1)
149 {
150 evalQuantized();
151 }
152 else if (filter()->scales().size() > 1)
153 {
154 LUCI_INTERPRETER_CHECK(filter()->shape().num_dims() == 4);
155 LUCI_INTERPRETER_CHECK(filter()->scales().size() ==
156 static_cast<size_t>(filter()->shape().dim(0)));
157 evalQuantizedPerChannel();
158 }
159 break;
160 case DataType::S8:
161 evalQuantizedS8PerChannel();
162 break;
163 case DataType::S16:
164 evalQuantizedS16();
165 break;
166 default:
167 throw std::runtime_error("luci-intp Conv2D(3) Unsupported type.");
168 }
169}

References filter(), input(), LUCI_INTERPRETER_CHECK, and size.

◆ filter()

const Tensor * luci_interpreter::kernels::Conv2D::filter ( ) const
inline

Definition at line 37 of file Conv2D.h.

37{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), Conv2D(), and execute().

◆ input()

const Tensor * luci_interpreter::kernels::Conv2D::input ( ) const
inline

Definition at line 36 of file Conv2D.h.

36{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), Conv2D(), and execute().

◆ output()

Tensor * luci_interpreter::kernels::Conv2D::output ( ) const
inline

Definition at line 39 of file Conv2D.h.

39{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().


The documentation for this class was generated from the following files: