ONE - On-device Neural Engine
Loading...
Searching...
No Matches
arm_compute::CLTransposeConvLayer Class Reference

#include <CLTransposeConvLayer.h>

Collaboration diagram for arm_compute::CLTransposeConvLayer:

Public Member Functions

 CLTransposeConvLayer (std::shared_ptr< IMemoryManager > memory_manager=nullptr)
 
void configure (ICLTensor *input, ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const PadStrideInfo &deconv_info, unsigned int invalid_right, unsigned int invalid_bottom, const WeightsInfo &weights_info=WeightsInfo())
 
void configure (const CLCompileContext &compile_context, ICLTensor *input, ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const PadStrideInfo &deconv_info, unsigned int invalid_right, unsigned int invalid_bottom, const WeightsInfo &weights_info=WeightsInfo())
 
void run () override
 
void prepare () override
 

Static Public Member Functions

static Status validate (const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const PadStrideInfo &deconv_info, unsigned int invalid_right, unsigned int invalid_bottom, const WeightsInfo &weights_info=WeightsInfo())
 
static DeconvolutionMethod get_deconvolution_method (const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const PadStrideInfo &deconv_info, unsigned int invalid_right, unsigned int invalid_bottom, const WeightsInfo &weights_info)
 

Detailed Description

Basic function to compute the deconvolution (transpose convolution) layer. Depending on the result of get_deconvolution_method(), it delegates to one of the following OpenCL functions:

  1. CLGEMMDeconvolutionLayer
  2. CLDirectTransposeConvLayer

Definition at line 58 of file CLTransposeConvLayer.h.

Constructor & Destructor Documentation

◆ CLTransposeConvLayer()

CLTransposeConvLayer::CLTransposeConvLayer ( std::shared_ptr< IMemoryManager >  memory_manager = nullptr)

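Default constructor. The optional memory manager is stored and handed to CLGEMMDeconvolutionLayer if the GEMM method is selected in configure().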

Definition at line 56 of file CLTransposeConvLayer.cpp.

57 : _memory_manager(std::move(memory_manager)), _function()
58{
59}

Member Function Documentation

◆ configure() [1/2]

void CLTransposeConvLayer::configure ( const CLCompileContext &  compile_context,
ICLTensor *  input,
ICLTensor *  weights,
const ICLTensor *  bias,
ICLTensor *  output,
const PadStrideInfo &  deconv_info,
unsigned int  invalid_right,
unsigned int  invalid_bottom,
const WeightsInfo &  weights_info = WeightsInfo() 
)

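Set the input, weights, biases and output tensors. The implementation (CLDirectTransposeConvLayer or CLGEMMDeconvolutionLayer) is chosen by get_deconvolution_method(). When the GEMM method is selected, invalid_right, invalid_bottom and weights_info are not forwarded to it.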

Parameters
[in] compile_context: The compile context to be used.
[in,out] input: Input tensor. The 3 lowest dimensions represent a single input; an optional 4th dimension holds a batch of inputs. Data types supported: QASYMM8_SIGNED/QASYMM8/F16/F32.
[in] weights: The 4D weights with dimensions [width, height, IFM, OFM]. Data type supported: same as input.
[in] bias: (Optional) The biases have one dimension. Data type supported: same as input.
[out] output: Output tensor. The output has the same number of dimensions as the input.
[in] deconv_info: Contains the padding and policies to be used in the deconvolution, as described in PadStrideInfo.
[in] invalid_right: The number of zeros added to the right edge of the output.
[in] invalid_bottom: The number of zeros added to the bottom edge of the output.
[in] weights_info: (Optional) Weights information needed for CLConvolutionLayer; specifies whether the weights tensor has been reshaped with CLWeightsReshapeKernel. It is forwarded only to CLDirectTransposeConvLayer; the GEMM path does not use it.

Definition at line 70 of file CLTransposeConvLayer.cpp.

74{
75 ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
76
77 switch (CLTransposeConvLayer::get_deconvolution_method(input->info(), weights->info(), nullptr,
78 output->info(), deconv_info, invalid_right,
79 invalid_bottom, weights_info))
80 {
81 case DeconvolutionMethod::DIRECT:
82 {
83 auto f = std::make_unique<CLDirectTransposeConvLayer>();
84 f->configure(compile_context, input, weights, bias, output, deconv_info, invalid_right,
85 invalid_bottom, weights_info);
86 _function = std::move(f);
87 break;
88 }
89 case DeconvolutionMethod::GEMM:
90 {
91 auto f = std::make_unique<CLGEMMDeconvolutionLayer>(_memory_manager);
92 f->configure(compile_context, input, weights, bias, output, deconv_info);
93 _function = std::move(f);
94 break;
95 }
96 default:
97 ARM_COMPUTE_ERROR("Not supported.");
98 break;
99 }
100}
static DeconvolutionMethod get_deconvolution_method(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const PadStrideInfo &deconv_info, unsigned int invalid_right, unsigned int invalid_bottom, const WeightsInfo &weights_info)

References get_deconvolution_method().

◆ configure() [2/2]

void CLTransposeConvLayer::configure ( ICLTensor *  input,
ICLTensor *  weights,
const ICLTensor *  bias,
ICLTensor *  output,
const PadStrideInfo &  deconv_info,
unsigned int  invalid_right,
unsigned int  invalid_bottom,
const WeightsInfo &  weights_info = WeightsInfo() 
)

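Set the input, weights, biases and output tensors. This overload forwards all arguments to the compile-context overload, using the default compile context obtained from CLKernelLibrary::get().get_compile_context().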

Parameters
[in,out] input: Input tensor. The 3 lowest dimensions represent a single input; an optional 4th dimension holds a batch of inputs. Data types supported: QASYMM8_SIGNED/QASYMM8/F16/F32.
[in] weights: The 4D weights with dimensions [width, height, IFM, OFM]. Data type supported: same as input.
[in] bias: (Optional) The biases have one dimension. Data type supported: same as input.
[out] output: Output tensor. The output has the same number of dimensions as the input.
[in] deconv_info: Contains the padding and policies to be used in the deconvolution, as described in PadStrideInfo.
[in] invalid_right: The number of zeros added to the right edge of the output.
[in] invalid_bottom: The number of zeros added to the bottom edge of the output.
[in] weights_info: (Optional) Weights information needed for CLConvolutionLayer; specifies whether the weights tensor has been reshaped with CLWeightsReshapeKernel. It is forwarded only to CLDirectTransposeConvLayer; the GEMM path does not use it.

Definition at line 61 of file CLTransposeConvLayer.cpp.

65{
66 configure(CLKernelLibrary::get().get_compile_context(), input, weights, bias, output, deconv_info,
67 invalid_right, invalid_bottom, weights_info);
68}
void configure(ICLTensor *input, ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const PadStrideInfo &deconv_info, unsigned int invalid_right, unsigned int invalid_bottom, const WeightsInfo &weights_info=WeightsInfo())

References configure().

Referenced by configure().

◆ get_deconvolution_method()

DeconvolutionMethod CLTransposeConvLayer::get_deconvolution_method ( const ITensorInfo *  input,
const ITensorInfo *  weights,
const ITensorInfo *  bias,
ITensorInfo *  output,
const PadStrideInfo &  deconv_info,
unsigned int  invalid_right,
unsigned int  invalid_bottom,
const WeightsInfo &  weights_info 
)
static

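Selects the implementation that configure() and validate() dispatch to. The width and height indices are resolved from the input's data layout. The function returns DeconvolutionMethod::GEMM only when all of the following hold:
- the weights' width equals deconv_info.stride().first;
- the weights' height equals deconv_info.stride().second;
- invalid_right and invalid_bottom are both zero.

In every other case it returns DeconvolutionMethod::DIRECT. The bias, output and weights_info arguments are not used.

Definition at line 133 of file CLTransposeConvLayer.cpp.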

137{
138 ARM_COMPUTE_UNUSED(output, bias, weights_info);
139
140 const DataLayout data_layout = input->data_layout();
141
142 const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
143 const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
144
145 if (weights->dimension(idx_w) != deconv_info.stride().first ||
146 weights->dimension(idx_h) != deconv_info.stride().second || invalid_right != 0 ||
147 invalid_bottom != 0)
148 {
149 return DeconvolutionMethod::DIRECT;
150 }
151
152 return DeconvolutionMethod::GEMM;
153}

Referenced by configure(), and validate().

◆ prepare()

void CLTransposeConvLayer::prepare ( )
override

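Forwards to prepare() of the implementation selected during configure(). configure() must be called first, because the selected function is only created there.

Definition at line 161 of file CLTransposeConvLayer.cpp.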

161{ _function->prepare(); }

Referenced by run().

◆ run()

void CLTransposeConvLayer::run ( )
override

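Calls prepare() and then run() of the implementation selected during configure(). configure() must be called first.

Definition at line 155 of file CLTransposeConvLayer.cpp.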

156{
157 prepare();
158 _function->run();
159}

References prepare().

Referenced by package.infer.session::inference().

◆ validate()

Status CLTransposeConvLayer::validate ( const ITensorInfo *  input,
const ITensorInfo *  weights,
const ITensorInfo *  bias,
ITensorInfo *  output,
const PadStrideInfo &  deconv_info,
unsigned int  invalid_right,
unsigned int  invalid_bottom,
const WeightsInfo &  weights_info = WeightsInfo() 
)
static

Static function to check if given info will lead to a valid configuration of CLTransposeConvLayer

Parameters
[in] input: Input tensor info. The 3 lowest dimensions represent a single input; an optional 4th dimension holds a batch of inputs. Data types supported: QASYMM8_SIGNED/QASYMM8/F16/F32.
[in] weights: The 4D weights info with dimensions [width, height, IFM, OFM]. Data type supported: same as input.
[in] bias: (Optional) The biases have one dimension. Data type supported: same as input.
[in] output: Output tensor info. The output has the same number of dimensions as the input.
[in] deconv_info: Contains the padding and policies to be used in the deconvolution, as described in PadStrideInfo.
[in] invalid_right: The number of zeros added to the right edge of the output.
[in] invalid_bottom: The number of zeros added to the bottom edge of the output.
[in] weights_info: (Optional) Weights information needed for CLConvolutionLayer; specifies whether the weights tensor has been reshaped with CLWeightsReshapeKernel. It is checked only by CLDirectTransposeConvLayer::validate(); the GEMM path does not use it.
Returns
a Status describing whether the configuration is valid. Any error reported by the selected implementation's validate() is propagated.

Definition at line 102 of file CLTransposeConvLayer.cpp.

106{
107 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
108 switch (CLTransposeConvLayer::get_deconvolution_method(
109 input, weights, bias, output, deconv_info, invalid_right, invalid_bottom, weights_info))
110 {
111 case DeconvolutionMethod::DIRECT:
112 {
113 // Validate direct convolution layer
114 ARM_COMPUTE_RETURN_ON_ERROR(CLDirectTransposeConvLayer::validate(
115 input, weights, bias, output, deconv_info, invalid_right, invalid_bottom, weights_info));
116 break;
117 }
118 case DeconvolutionMethod::GEMM:
119 {
120 // Validate gemm-based convolution layer
121 ARM_COMPUTE_RETURN_ON_ERROR(
122 CLGEMMDeconvolutionLayer::validate(input, weights, bias, output, deconv_info));
123 break;
124 }
125 default:
126 ARM_COMPUTE_ERROR("Not supported.");
127 break;
128 }
129
130 return Status{};
131}
static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const PadStrideInfo &info, unsigned int invalid_right, unsigned int invalid_bottom, const WeightsInfo &weights_info=WeightsInfo())

References get_deconvolution_method(), and arm_compute::CLDirectTransposeConvLayer::validate().


The documentation for this class was generated from the following files: