ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Conv2D.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "OMStatus.h"
18#include "core/OMUtils.h"
21#include "PALConv2DWeightGrad.h"
22#include "PALReluInputGrad.h"
23#include "PALConv2DInputGrad.h"
24
25using namespace onert_micro;
26using namespace onert_micro::core;
27using namespace onert_micro::train;
28
namespace
{

// Positions of the Conv2D operator's tensors inside runtime_kernel.inputs /
// runtime_kernel.inputs_data (input activations, filter weights, bias).
constexpr uint32_t inputTensorIdx{0};
constexpr uint32_t weightTensorIdx{1};
constexpr uint32_t biasTensorIdx{2};

// Position of the single result tensor inside runtime_kernel.outputs /
// runtime_kernel.outputs_data.
constexpr uint32_t outputTensorIdx{0};

} // namespace
39
40/*
41 * - Calculate weight gradient (with bias)
42 * - Calculate input gradient - Optional (not required if it is last op)
43 */
44OMStatus onert_micro::train::train_kernel_CircleConv2D(const OMBackpropExecuteArgs &args)
45{
46 core::OMRuntimeStorage &forward_storage = args.forward_storage;
47 core::OMRuntimeStorage &backward_storage = args.backward_storage;
48 core::OMRuntimeContext &context = args.backward_context;
49 uint16_t op_index = args.kernel_index;
50
51 const circle::Tensor *input;
52 const circle::Tensor *weight;
53 const circle::Tensor *output;
54
55 int32_t weight_tensor_index = -1;
56
57 uint8_t *input_data;
58 uint8_t *dloss_dinput_data;
59
60 uint8_t *weight_data;
61 uint8_t *dloss_dweight_data;
62
63 uint8_t *bias_data;
64 uint8_t *dloss_dbias_data;
65
66 uint8_t *output_data;
67 uint8_t *dloss_doutput_data;
68
69 const circle::Conv2DOptions *options;
70 // Read kernel
71 {
72 execute::OMRuntimeKernel runtime_kernel;
73 runtime_kernel.readKernel(op_index, context);
74
75 input = runtime_kernel.inputs[inputTensorIdx];
76 weight = runtime_kernel.inputs[weightTensorIdx];
77 output = runtime_kernel.outputs[outputTensorIdx];
78 assert(input != nullptr);
79 assert(weight != nullptr);
80 // Bias can be nullptr
81 assert(output != nullptr);
82
83 weight_tensor_index = runtime_kernel.inputs_index[weightTensorIdx];
84 assert(weight_tensor_index != -1);
85
86 // Read forward storage
87 {
88 runtime_kernel.getDataFromStorage(op_index, forward_storage, context);
89
90 input_data = runtime_kernel.inputs_data[inputTensorIdx];
91 weight_data = runtime_kernel.inputs_data[weightTensorIdx];
92 bias_data = runtime_kernel.inputs_data[biasTensorIdx];
94 // Bias_data can be nullptr
95 // Output_data can be nullptr
96 // Input_data can be nullptr if we don't train this layer
97 assert(weight_data != nullptr);
98 }
99
100 // Read backward storage
101 {
102 runtime_kernel.getDataFromStorage(op_index, backward_storage, context);
103
104 dloss_dinput_data = runtime_kernel.inputs_data[inputTensorIdx];
105 dloss_dweight_data = runtime_kernel.inputs_data[weightTensorIdx];
106 dloss_dbias_data = runtime_kernel.inputs_data[biasTensorIdx];
107 dloss_doutput_data = runtime_kernel.outputs_data[outputTensorIdx];
108 // Bias_data and dloss_dinput_data can be nullptr
109 // Note: dloss_dinput_data can be nullptr due to it can be last trainable node
110 assert(dloss_dweight_data != nullptr);
111 assert(dloss_doutput_data != nullptr);
112 }
113
114 options = runtime_kernel.first_operator->builtin_options_as_Conv2DOptions();
115 }
116
117 OMRuntimeShape input_shape(input);
119 OMRuntimeShape weight_shape(weight);
120
121 // 1. Handle activation functions
122 switch (options->fused_activation_function())
123 {
124 case circle::ActivationFunctionType_NONE:
125 // Do nothing
126 break;
127 case circle::ActivationFunctionType_RELU:
128 {
129 assert(output_data != nullptr);
130 pal::ReluInputGrad(utils::castInputData<float>(output_data),
131 utils::castOutputData<float>(dloss_doutput_data), output_shape);
132 break;
133 }
134 default:
135 {
136 assert(false && "Unsupported activation type");
137 return UnsupportedType;
138 }
139 }
140
141 const int input_width = input_shape.dims(3);
142 const int input_height = input_shape.dims(2);
143 const int weight_width = output_shape.dims(3);
144 const int weight_height = output_shape.dims(2);
145
146 FloatConv2D params{};
147
148 params.stride_w = options->stride_w();
149 params.stride_h = options->stride_h();
150 params.dilation_width_factor = options->dilation_w_factor();
151 params.dilation_height_factor = options->dilation_h_factor();
152 params.pad_h = 0;
153 params.pad_w = 0;
154
155 if (args.is_trainable_layer)
156 {
157 // Check is only bias updating
158 if (args.train_rank_type != ONLY_BIAS)
159 {
160 assert(input_data != nullptr); // FIX memory planner then
161
162 // Get weight shape
163 OMRuntimeShape dynamic_shapes = backward_storage.getDynamicRuntimeShape(weight_tensor_index);
164 if (dynamic_shapes.flatSize() != 0)
165 weight_shape = dynamic_shapes;
166
167 // 2. Calculate weight gradient
168 pal::Conv2DWeightGrad(params, input_shape, utils::castInputData<float>(input_data),
169 output_shape, utils::castInputData<float>(dloss_doutput_data),
170 weight_shape, utils::castOutputData<float>(dloss_dweight_data),
171 args.train_rank_type);
172 }
173
174 // 3. Calculate bias gradient
175 if (dloss_dbias_data)
176 {
177 assert(bias_data != nullptr);
178 if (bias_data == nullptr)
179 return UnknownError;
180
181 pal::Conv2DBiasGrad(output_shape, utils::castInputData<float>(dloss_doutput_data),
182 utils::castOutputData<float>(dloss_dbias_data));
183 }
184 }
185
186 // 4. Calculate (if needed) input grad
187 if (args.is_last_layer == false)
188 {
189 assert(dloss_dinput_data != nullptr);
190 pal::Conv2DInputGrad(params, weight_shape, utils::castInputData<float>(weight_data),
191 output_shape, utils::castInputData<float>(dloss_doutput_data), input_shape,
192 utils::castOutputData<float>(dloss_dinput_data));
193 }
194
195 return Ok;
196}
int32_t dims(int i) const
Definition Tensor.h:108
OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index)
uint8_t * outputs_data[maxOutputSize]
const circle::Operator * first_operator
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
const luci_interpreter::RuntimeShape output_shape
constexpr uint32_t outputTensorIdx
args
Definition infer.py:21
list input_data
Definition infer.py:29
void ReluInputGrad(const float *input_relu_data, float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape)
void Conv2DWeightGrad(const core::FloatConv2D &params, const core::OMRuntimeShape &input_shape, const float *input_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_dweight_shape, float *dloss_dweight_data, core::OpTrainableRankType rank)
void Conv2DInputGrad(const core::FloatConv2D &params, const core::OMRuntimeShape &weight_shape, const float *weight_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_dinput_shape, float *dloss_dinput_data)
void Conv2DBiasGrad(const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, float *dloss_dbias_data)
@ UnsupportedType
Definition OMStatus.h:26