ONE - On-device Neural Engine
ConvolutionLayer.cc
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ConvolutionLayer.h"

#include "OperationUtils.h"

#include <cker/operation/Conv.h>
#include <cker/operation/Transpose.h>
#include <cker/train/operation/Conv.h>

namespace
{

using namespace onert;

template <typename Tensor>
std::unique_ptr<Tensor> createTransposedWeights(const backend::IPortableTensor *origin_weights)
{
  const auto &origin_shape = origin_weights->getShape();
  assert(origin_shape.rank() == 4);

  auto transposed_info = origin_weights->get_info();
  // OHWI to HWIO
  auto transposed_shape =
    ir::Shape{origin_shape.dim(1), origin_shape.dim(2), origin_shape.dim(3), origin_shape.dim(0)};
  transposed_info.shape(transposed_shape);

  return std::make_unique<Tensor>(transposed_info);
}
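// For example, an OHWI kernel of shape {64, 3, 3, 16} yields an HWIO tensor of
// shape {3, 3, 16, 64}. Only the tensor metadata is set up here; the actual
// element copy is performed later by nnfw::cker::Transpose in backwardFloat32().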

} // namespace

namespace onert::backend::train::ops
{

ConvolutionLayer::ConvolutionLayer()
  : cpu::ops::ConvolutionLayer(), _grad_weights{nullptr}, _grad_bias{nullptr},
    _back_prop_input{nullptr}, _back_prop_output{nullptr}, _transposed_weights{nullptr}
{
  // DO NOTHING
}

ConvolutionLayer::~ConvolutionLayer() = default;

void ConvolutionLayer::configureBackward(const IPortableTensor *weights,
                                         IPortableTensor *back_prop_input,
                                         IPortableTensor *grad_weights, IPortableTensor *grad_bias,
                                         const IPortableTensor *back_prop_output,
                                         const ir::Activation activation)
{
  _back_prop_input = back_prop_input;
  _grad_weights = grad_weights;
  _grad_bias = grad_bias;
  _back_prop_output = back_prop_output;

  if (_dilationWidthFactor != 1 || _dilationHeightFactor != 1)
    throw std::runtime_error("train ConvolutionLayer: Unsupported dilation yet");

  // TODO Optimize transposed tensors
  _transposed_weights = createTransposedWeights<Tensor>(weights);
  _transposed_weights->setBuffer(
    std::make_shared<basic::Allocator>(_transposed_weights->total_size()));

  _conv_back_prop_output = std::make_unique<BackPropTensor>(back_prop_output->get_info());
  _conv_back_prop_output->setBuffer(
    std::make_shared<basic::Allocator>(_conv_back_prop_output->total_size()));

  _transposed_grad_weights = createTransposedWeights<GradientTensor>(weights);
  _transposed_grad_weights->setBuffer(
    std::make_shared<basic::Allocator>(_transposed_grad_weights->total_size()));
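  // The three tensors above are persistent scratch buffers: they are allocated
  // once at configure time and rewritten on every backward() call.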

  if (activation != ir::Activation::NONE)
  {
    _act_back_prop_output = std::make_unique<BackPropTensor>(_back_prop_output->get_info());
    _act_back_prop_output->setBuffer(
      std::make_shared<basic::Allocator>(_act_back_prop_output->total_size()));
  }
}

void ConvolutionLayer::backward()
{
  const auto data_type = _back_prop_output->data_type();
  assert(data_type == _input->data_type());
  switch (data_type)
  {
    case OperandType::FLOAT32:
    {
      assert(data_type == _grad_bias->data_type());
      backwardFloat32();
      break;
    }
    default:
      throw std::runtime_error{"train ConvolutionLayer: unsupported data type"};
  }
}

void ConvolutionLayer::backwardFloat32()
{
  // Calculate gradient for activation
  const IPortableTensor *backprop_act;
  try
  {
    backprop_act =
      backpropActivation(_activation, _output, _back_prop_output, _act_back_prop_output.get());
  }
  catch (const std::exception &e)
  {
    throw std::runtime_error{"train ConvolutionLayer: " + std::string(e.what())};
  }
  assert(backprop_act != nullptr);
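  // backprop_act now points at the incoming gradient with the activation
  // derivative applied (for ir::Activation::NONE no scratch buffer was
  // allocated in configureBackward, so the gradient flows through unchanged).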

  // Initialize conv params for training kernels
  nnfw::cker::ConvParams conv_train_params;
  conv_train_params.padding_type = getPaddingType(_paddingType);
  conv_train_params.padding_values.width = _paddingLeft;
  conv_train_params.padding_values.height = _paddingTop;
  conv_train_params.stride_width = _strideWidth;
  conv_train_params.stride_height = _strideHeight;
  conv_train_params.dilation_width_factor = _dilationWidthFactor;
  conv_train_params.dilation_height_factor = _dilationHeightFactor;
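  // Both dilation factors are guaranteed to be 1 here; configureBackward
  // rejects dilated convolutions for training.
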
  // Transpose weights from OHWI to HWIO
  auto transposed_weights = _transposed_weights.get();
  assert(transposed_weights->getShape().rank() == 4);
  nnfw::cker::TransposeParams transpose_param;
  transpose_param.perm_count = transposed_weights->getShape().rank();
  transpose_param.perm[0] = 1;
  transpose_param.perm[1] = 2;
  transpose_param.perm[2] = 3;
  transpose_param.perm[3] = 0;
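  // perm = {1, 2, 3, 0}: output dimension d is taken from input dimension
  // perm[d], so the OHWI kernel layout becomes HWIO.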
  nnfw::cker::Transpose(transpose_param, getShape(_kernel), getBuffer<float>(_kernel),
                        getShape(transposed_weights), getBuffer<float>(transposed_weights));
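  // The HWIO copy feeds ConvInputGrad below: the input gradient is, in effect,
  // the incoming gradient convolved with the spatially mirrored filter.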

  // Calculate gradient for input
  nnfw::cker::train::ConvInputGrad(
    conv_train_params, getShape(backprop_act), getBuffer<float>(backprop_act),
    getShape(transposed_weights), getBuffer<float>(transposed_weights), _paddingBottom,
    _paddingRight, getShape(_back_prop_input), getBuffer<float>(_back_prop_input));

  // Calculate gradient for weights
  auto transposed_grad_weights = _transposed_grad_weights.get();
  assert(_grad_weights->getShape().rank() == 4);
  assert(transposed_grad_weights->getShape().rank() == 4);
  nnfw::cker::train::ConvFilterGrad(
    conv_train_params, getShape(backprop_act), getBuffer<float>(backprop_act), getShape(_input),
    getBuffer<float>(_input), _paddingBottom, _paddingRight, getShape(transposed_grad_weights),
    getBuffer<float>(transposed_grad_weights));
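  // Mathematically, the filter gradient is the correlation of the layer input
  // with the incoming gradient, accumulated over the batch; it is produced in
  // HWIO order and transposed back to OHWI below.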

  // Transpose weights' gradient from HWIO to OHWI
  nnfw::cker::TransposeParams transpose_grad_param;
  transpose_grad_param.perm_count = transposed_grad_weights->getShape().rank();
  transpose_grad_param.perm[0] = 3;
  transpose_grad_param.perm[1] = 0;
  transpose_grad_param.perm[2] = 1;
  transpose_grad_param.perm[3] = 2;
  nnfw::cker::Transpose(transpose_grad_param, getShape(transposed_grad_weights),
                        getBuffer<float>(transposed_grad_weights), getShape(_grad_weights),
                        getBuffer<float>(_grad_weights));

  // Calculate gradient for bias
  if (_bias)
  {
    assert(_grad_bias);
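    // biasGrad reduces the incoming gradient over the batch and spatial axes,
    // leaving one accumulated value per output channel.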
    biasGrad(backprop_act, _grad_bias);
  }
}

} // namespace onert::backend::train::ops