ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::TransposeConv Class Reference

#include <TransposeConv.h>

Collaboration diagram for luci_interpreter::kernels::TransposeConv:

Public Member Functions

 TransposeConv (const Tensor *output_shape, const Tensor *filter, const Tensor *input, const Tensor *bias, Tensor *output, Tensor *scratch_tensor, const TransposeConvParams &params)
 
 ~TransposeConv ()
 
const Tensor * output_shape () const
 
const Tensor * filter () const
 
const Tensor * input () const
 
const Tensor * bias () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< TransposeConvParams >
const TransposeConvParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< TransposeConvParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const TransposeConvParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< TransposeConvParams >
const TransposeConvParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 30 of file TransposeConv.h.

Constructor & Destructor Documentation

◆ TransposeConv()

luci_interpreter::kernels::TransposeConv::TransposeConv ( const Tensor * output_shape,
const Tensor * filter,
const Tensor * input,
const Tensor * bias,
Tensor * output,
Tensor * scratch_tensor,
const TransposeConvParams & params 
)

Definition at line 33 of file TransposeConv.cpp.

36 : KernelWithParams<TransposeConvParams>({output_shape, filter, input, bias},
37 {output, scratch_tensor}, params)
38{
39}
const TransposeConvParams & params() const
Definition Kernel.h:67

References bias(), filter(), input(), and output_shape().

◆ ~TransposeConv()

luci_interpreter::kernels::TransposeConv::~TransposeConv ( )

Definition at line 41 of file TransposeConv.cpp.

42{
43 // Define destructor here, to delete vector of quantized multipliers properly
44}

Member Function Documentation

◆ bias()

const Tensor * luci_interpreter::kernels::TransposeConv::bias ( ) const
inline

Definition at line 42 of file TransposeConv.h.

42{ return _inputs[3]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by TransposeConv().

◆ configure()

void luci_interpreter::kernels::TransposeConv::configure ( )
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 46 of file TransposeConv.cpp.

47{
48 assert(output_shape()->shape().num_dims() == 1);
49 assert(input()->shape().num_dims() == 4);
50 assert(filter()->shape().num_dims() == 4);
51 assert(input()->element_type() == DataType::FLOAT32 || input()->element_type() == DataType::U8 ||
52 input()->element_type() == DataType::S16);
53 assert(input()->element_type() == output()->element_type());
54 assert(input()->shape().dim(3) == filter()->shape().dim(3));
55
56 const int num_dims = output_shape()->shape().dim(0);
57 Shape out_shape(num_dims);
58 const auto *shape_data = getTensorData<int32_t>(output_shape());
59 for (int i = 0; i < num_dims; i++)
60 out_shape.dim(i) = shape_data[i];
61 output()->resize(out_shape);
62
63 const int32_t filter_height = filter()->shape().dim(1);
64 const int32_t filter_width = filter()->shape().dim(2);
65 const int32_t output_height = out_shape.dim(1);
66 const int32_t output_width = out_shape.dim(2);
67
68 const int32_t unused_output_height =
69 computeOutputSize(params().padding, output_height, filter_height, params().stride_height, 1);
70 const int32_t unused_output_width =
71 computeOutputSize(params().padding, output_width, filter_width, params().stride_width, 1);
72
73 _padding_height =
74 computePadding(params().stride_height, 1, output_height, filter_height, unused_output_height);
75 _padding_width =
76 computePadding(params().stride_width, 1, output_width, filter_width, unused_output_width);
77
78 if (input()->element_type() == DataType::U8 || input()->element_type() == DataType::S16)
79 {
80 auto scratch_tensor = getOutputTensors()[1];
81 scratch_tensor->resize(output()->shape());
82 const std::vector<double> real_multipliers =
83 getQuantizedConvolutionMultiplers(input()->scale(), filter()->scales(), output()->scale());
84
85 _quant_multipliers = quantizeMultipliers(real_multipliers);
86 }
87 else
88 {
89 auto scratch_tensor = getOutputTensors()[1];
90 scratch_tensor->set_allocatable(false);
91 }
92}
const std::vector< Tensor * > & getOutputTensors() const
Definition Kernel.h:40
int32_t dim(int i) const
Definition Tensor.h:41
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
int32_t computePadding(int32_t stride, int32_t dilation_rate, int32_t in_size, int32_t filter_size, int32_t out_size)
Definition Utils.h:41
std::vector< ChannelQuantMultipliers > quantizeMultipliers(const std::vector< double > &effective_scale)
Definition Utils.h:170
int32_t computeOutputSize(Padding padding, int32_t image_size, int32_t filter_size, int32_t stride, int32_t dilation_rate=1)
Definition Utils.h:59
std::vector< double > getQuantizedConvolutionMultiplers(float input_scale, const std::vector< float > &filter_scale, float output_scale)
Definition Utils.h:147
Definition Shape.h:28

References luci_interpreter::kernels::computeOutputSize(), luci_interpreter::kernels::computePadding(), luci_interpreter::Shape::dim(), filter(), luci_interpreter::Kernel::getOutputTensors(), luci_interpreter::kernels::getQuantizedConvolutionMultiplers(), input(), output(), output_shape(), luci_interpreter::KernelWithParams< TransposeConvParams >::params(), luci_interpreter::kernels::quantizeMultipliers(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::TransposeConv::execute ( ) const
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 94 of file TransposeConv.cpp.

95{
96 switch (input()->element_type())
97 {
98 case DataType::FLOAT32:
99 evalFloat();
100 break;
101 case DataType::U8:
102 if (filter()->scales().size() == 1)
103 {
104 evalQuantized();
105 }
106 else if (filter()->scales().size() > 1)
107 {
108 LUCI_INTERPRETER_CHECK(filter()->shape().num_dims() == 4);
109 LUCI_INTERPRETER_CHECK(filter()->scales().size() ==
110 static_cast<size_t>(filter()->shape().dim(0)));
111 evalQuantizedPerChannel();
112 }
113 break;
114 case DataType::S16:
115 evalQuantizedS16();
116 break;
117 default:
118 throw std::runtime_error("luci-intp TransposeConv Unsupported type.");
119 }
120}
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
int32_t size[5]
Definition Slice.cpp:35

References filter(), input(), LUCI_INTERPRETER_CHECK, and size.

◆ filter()

const Tensor * luci_interpreter::kernels::TransposeConv::filter ( ) const
inline

Definition at line 40 of file TransposeConv.h.

40{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and TransposeConv().

◆ input()

const Tensor * luci_interpreter::kernels::TransposeConv::input ( ) const
inline

Definition at line 41 of file TransposeConv.h.

41{ return _inputs[2]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and TransposeConv().

◆ output()

Tensor * luci_interpreter::kernels::TransposeConv::output ( ) const
inline

Definition at line 43 of file TransposeConv.h.

43{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().

◆ output_shape()

const Tensor * luci_interpreter::kernels::TransposeConv::output_shape ( ) const
inline

Definition at line 39 of file TransposeConv.h.

39{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and TransposeConv().


The documentation for this class was generated from the following files: