ONE - On-device Neural Engine
ConvolutionLayer.cc
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ConvolutionLayer.h"

#include "OperationUtils.h"

#include <cker/operation/Conv.h>
#include <cker/operation/Transpose.h>
#include <cker/train/operation/Conv.h>
#include <cker/train/operation/ReLU.h>

namespace
{

using namespace onert;

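// Creates an empty tensor whose OperandInfo carries the HWIO-transposed shape
// of the given OHWI weights. Only the shape metadata is transposed here; the
// caller attaches a buffer, and the element data is transposed later in
// backwardFloat32().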
template <typename Tensor>
std::unique_ptr<Tensor> createTransposedWeights(const backend::IPortableTensor *origin_weights)
{
  const auto &origin_shape = origin_weights->getShape();
  assert(origin_shape.rank() == 4);

  auto transposed_info = origin_weights->get_info();
  // OHWI to HWIO
  auto transposed_shape =
    ir::Shape{origin_shape.dim(1), origin_shape.dim(2), origin_shape.dim(3), origin_shape.dim(0)};
  transposed_info.shape(transposed_shape);

  return std::make_unique<Tensor>(transposed_info);
}

} // namespace

namespace onert
{
namespace backend
{
namespace train
{
namespace ops
{

ConvolutionLayer::ConvolutionLayer()
  : cpu::ops::ConvolutionLayer(), _grad_weights{nullptr}, _grad_bias{nullptr},
    _back_prop_input{nullptr}, _back_prop_output{nullptr}, _transposed_weights{nullptr}
{
  // DO NOTHING
}

ConvolutionLayer::~ConvolutionLayer() = default;

void ConvolutionLayer::configureBackward(const IPortableTensor *weights,
                                         IPortableTensor *back_prop_input,
                                         IPortableTensor *grad_weights, IPortableTensor *grad_bias,
                                         const IPortableTensor *back_prop_output,
                                         const ir::Activation activation)
{
  _back_prop_input = back_prop_input;
  _grad_weights = grad_weights;
  _grad_bias = grad_bias;
  _back_prop_output = back_prop_output;

  if (_dilationHeightFactor != 1 || _dilationWidthFactor != 1)
    throw std::runtime_error("train ConvolutionLayer: Unsupported dilation yet");

  // TODO Optimize transposed tensors
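  // The cker training kernels invoked in backwardFloat32() work with filters in
  // HWIO layout, while onert stores convolution weights in OHWI, so HWIO-shaped
  // scratch tensors are prepared once here at configure time.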
  _transposed_weights = createTransposedWeights<Tensor>(weights);
  _transposed_weights->setBuffer(
    std::make_shared<basic::Allocator>(_transposed_weights->total_size()));

  _conv_back_prop_output = std::make_unique<BackPropTensor>(back_prop_output->get_info());
  _conv_back_prop_output->setBuffer(
    std::make_shared<basic::Allocator>(_conv_back_prop_output->total_size()));

  _transposed_grad_weights = createTransposedWeights<GradientTensor>(weights);
  _transposed_grad_weights->setBuffer(
    std::make_shared<basic::Allocator>(_transposed_grad_weights->total_size()));

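  // A fused activation gets its own gradient buffer so that backpropActivation()
  // has somewhere to write the activation gradient without clobbering the
  // incoming back-prop tensor.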
  if (activation != ir::Activation::NONE)
  {
    _act_back_prop_output = std::make_unique<BackPropTensor>(_back_prop_output->get_info());
    _act_back_prop_output->setBuffer(
      std::make_shared<basic::Allocator>(_act_back_prop_output->total_size()));
  }
}

void ConvolutionLayer::forward(bool) { cpu::ops::ConvolutionLayer::run(); }

void ConvolutionLayer::backward()
{
  const auto data_type = _back_prop_output->data_type();
  assert(data_type == _input->data_type());
  switch (data_type)
  {
    case OperandType::FLOAT32:
    {
      assert(data_type == _grad_bias->data_type());
      backwardFloat32();
      break;
    }
    default:
      throw std::runtime_error{"train ConvolutionLayer: unsupported data type"};
  }
}

void ConvolutionLayer::backwardFloat32()
{
  // Calculate gradient for activation
  const IPortableTensor *backprop_act;
  try
  {
    backprop_act =
      backpropActivation(_activation, _output, _back_prop_output, _act_back_prop_output.get());
  }
  catch (const std::exception &e)
  {
    throw std::runtime_error{"train ConvolutionLayer: " + std::string(e.what())};
  }
  assert(backprop_act != nullptr);

  // Initialize conv params for training kernels
  nnfw::cker::ConvParams conv_train_params;
  conv_train_params.padding_type = getPaddingType(_paddingType);
  conv_train_params.padding_values.width = _paddingLeft;
  conv_train_params.padding_values.height = _paddingTop;
  conv_train_params.stride_width = _strideWidth;
  conv_train_params.stride_height = _strideHeight;
  conv_train_params.dilation_width_factor = _dilationWidthFactor;
  conv_train_params.dilation_height_factor = _dilationHeightFactor;

  // Transpose weights from OHWI to HWIO
  auto transposed_weights = _transposed_weights.get();
  assert(transposed_weights->getShape().rank() == 4);
  nnfw::cker::TransposeParams transpose_param;
  transpose_param.perm_count = transposed_weights->getShape().rank();
  transpose_param.perm[0] = 1;
  transpose_param.perm[1] = 2;
  transpose_param.perm[2] = 3;
  transpose_param.perm[3] = 0;
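  // Output dimension d of the transpose is read from input dimension perm[d],
  // so {1, 2, 3, 0} maps (O, H, W, I) to (H, W, I, O).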
  nnfw::cker::Transpose(transpose_param, getShape(_kernel), getBuffer<float>(_kernel),
                        getShape(transposed_weights), getBuffer<float>(transposed_weights));

  // Calculate gradient for input
  nnfw::cker::train::ConvInputGrad(
    conv_train_params, getShape(backprop_act), getBuffer<float>(backprop_act),
    getShape(transposed_weights), getBuffer<float>(transposed_weights), _paddingBottom,
    _paddingRight, getShape(_back_prop_input), getBuffer<float>(_back_prop_input));

  // Calculate gradient for weights
  auto transposed_grad_weights = _transposed_grad_weights.get();
  assert(_grad_weights->getShape().rank() == 4);
  assert(transposed_grad_weights->getShape().rank() == 4);
  nnfw::cker::train::ConvFilterGrad(
    conv_train_params, getShape(backprop_act), getBuffer<float>(backprop_act), getShape(_input),
    getBuffer<float>(_input), _paddingBottom, _paddingRight, getShape(transposed_grad_weights),
    getBuffer<float>(transposed_grad_weights));

  // Transpose the weights' gradient from HWIO back to OHWI
  nnfw::cker::TransposeParams transpose_grad_param;
  transpose_grad_param.perm_count = transposed_grad_weights->getShape().rank();
  transpose_grad_param.perm[0] = 3;
  transpose_grad_param.perm[1] = 0;
  transpose_grad_param.perm[2] = 1;
  transpose_grad_param.perm[3] = 2;
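  // {3, 0, 1, 2} is the inverse of the earlier permutation:
  // (H, W, I, O) back to (O, H, W, I).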
  nnfw::cker::Transpose(transpose_grad_param, getShape(transposed_grad_weights),
                        getBuffer<float>(transposed_grad_weights), getShape(_grad_weights),
                        getBuffer<float>(_grad_weights));

  // Calculate gradient for bias
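  // Conceptually this reduces the incoming gradient over every axis except the
  // output-channel axis.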
  if (_bias)
  {
    assert(_grad_bias);
    biasGrad(backprop_act, _grad_bias);
  }
}

} // namespace ops
} // namespace train
} // namespace backend
} // namespace onert
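
The OHWI-to-HWIO transpose before ConvInputGrad and the HWIO-to-OHWI transpose
after ConvFilterGrad use permutations that are inverses of each other. Below is
a minimal, self-contained sketch of that round trip, independent of onert and
cker; transpose4 is a hypothetical helper written only for this illustration,
using the same perm convention as above (output dimension d is read from input
dimension perm[d]).

#include <array>
#include <cassert>
#include <cstdio>
#include <vector>

// Rank-4 transpose: output dimension d is taken from input dimension perm[d].
static std::vector<float> transpose4(const std::vector<float> &in, const std::array<int, 4> &dims,
                                     const std::array<int, 4> &perm)
{
  std::array<int, 4> out_dims;
  for (int d = 0; d < 4; ++d)
    out_dims[d] = dims[perm[d]];

  std::vector<float> out(in.size());
  std::array<int, 4> idx; // coordinate in the *input* tensor
  for (idx[0] = 0; idx[0] < dims[0]; ++idx[0])
    for (idx[1] = 0; idx[1] < dims[1]; ++idx[1])
      for (idx[2] = 0; idx[2] < dims[2]; ++idx[2])
        for (idx[3] = 0; idx[3] < dims[3]; ++idx[3])
        {
          const int in_offset = ((idx[0] * dims[1] + idx[1]) * dims[2] + idx[2]) * dims[3] + idx[3];
          const int out_offset =
            ((idx[perm[0]] * out_dims[1] + idx[perm[1]]) * out_dims[2] + idx[perm[2]]) *
              out_dims[3] +
            idx[perm[3]];
          out[out_offset] = in[in_offset];
        }
  return out;
}

int main()
{
  const std::array<int, 4> ohwi{2, 3, 3, 4}; // O, H, W, I
  std::vector<float> weights(2 * 3 * 3 * 4);
  for (std::size_t i = 0; i < weights.size(); ++i)
    weights[i] = static_cast<float>(i);

  // OHWI -> HWIO, as done before ConvInputGrad above
  const auto hwio = transpose4(weights, ohwi, {1, 2, 3, 0});
  // HWIO -> OHWI, as done after ConvFilterGrad above
  const std::array<int, 4> hwio_dims{3, 3, 4, 2};
  const auto round_trip = transpose4(hwio, hwio_dims, {3, 0, 1, 2});

  assert(round_trip == weights); // the two permutations undo each other
  std::puts("round trip OK");
  return 0;
}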