ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FullyConnected.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "OMStatus.h"
18#include "core/OMUtils.h"
19#include "core/OMDataType.h"
24#include "PALReluInputGrad.h"
25
26using namespace onert_micro;
27using namespace onert_micro::core;
28using namespace onert_micro::train;
29
namespace
{

// Operator input slot indices for circle FullyConnected:
// inputs[0] = activation tensor, inputs[1] = weight matrix,
// inputs[2] = optional bias vector.
constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t weightTensorIdx = 1;
constexpr uint32_t biasTensorIdx = 2;

// Single output slot.
constexpr uint32_t outputTensorIdx = 0;

} // namespace
40
41/*
42 * - Calculate weight gradient (with bias)
43 * - Calculate input gradient - Optional (not required if it is last op)
44 */
45OMStatus onert_micro::train::train_kernel_CircleFullyConnected(const OMBackpropExecuteArgs &args)
46{
47 core::OMRuntimeStorage &forward_storage = args.forward_storage;
48 core::OMRuntimeStorage &backward_storage = args.backward_storage;
49 core::OMRuntimeContext &context = args.backward_context;
50 uint16_t op_index = args.kernel_index;
51
52 const circle::Tensor *input;
53 const circle::Tensor *weight;
54 const circle::Tensor *output;
55
56 int32_t weight_tensor_index = -1;
57
58 uint8_t *input_data;
59 uint8_t *dloss_dinput_data;
60
61 uint8_t *weight_data;
62 uint8_t *dloss_dweight_data;
63
64 uint8_t *bias_data;
65 uint8_t *dloss_dbias_data;
66
67 uint8_t *output_data;
68 uint8_t *dloss_doutput_data;
69
70 const circle::FullyConnectedOptions *options;
71 // Read kernel
72 {
73 execute::OMRuntimeKernel runtime_kernel;
74 runtime_kernel.readKernel(op_index, context);
75
76 input = runtime_kernel.inputs[inputTensorIdx];
77 weight = runtime_kernel.inputs[weightTensorIdx];
78 output = runtime_kernel.outputs[outputTensorIdx];
79 assert(input != nullptr);
80 assert(weight != nullptr);
81 // Bias can be nullptr
82 assert(output != nullptr);
83
84 weight_tensor_index = runtime_kernel.inputs_index[weightTensorIdx];
85 assert(weight_tensor_index != -1);
86
87 // Read forward storage
88 {
89 runtime_kernel.getDataFromStorage(op_index, forward_storage, context);
90
91 input_data = runtime_kernel.inputs_data[inputTensorIdx];
92 weight_data = runtime_kernel.inputs_data[weightTensorIdx];
93 bias_data = runtime_kernel.inputs_data[biasTensorIdx];
95 // Bias_data can be nullptr
96 // Output_data can be nullptr
97 // Input_data can be nullptr
98 assert(weight_data != nullptr);
99 }
100
101 // Read backward storage
102 {
103 runtime_kernel.getDataFromStorage(op_index, backward_storage, context);
104
105 dloss_dinput_data = runtime_kernel.inputs_data[inputTensorIdx];
106 dloss_dweight_data = runtime_kernel.inputs_data[weightTensorIdx];
107 dloss_dbias_data = runtime_kernel.inputs_data[biasTensorIdx];
108 dloss_doutput_data = runtime_kernel.outputs_data[outputTensorIdx];
109 // Bias_data and dloss_dinput_data can be nullptr
110 // Note: dloss_dinput_data can be nullptr due to it can be last trainable node
111 assert(dloss_dweight_data != nullptr);
112 assert(dloss_doutput_data != nullptr);
113 }
114
115 options = runtime_kernel.first_operator->builtin_options_as_FullyConnectedOptions();
116 }
117
118 OMRuntimeShape input_shape(input);
120
121 // 1. Handle activation functions
122 switch (options->fused_activation_function())
123 {
124 case circle::ActivationFunctionType_NONE:
125 // Do nothing
126 break;
127 case circle::ActivationFunctionType_RELU:
128 {
129 assert(output_data != nullptr);
130 pal::ReluInputGrad(utils::castInputData<float>(output_data),
131 utils::castOutputData<float>(dloss_doutput_data), output_shape);
132 break;
133 }
134 default:
135 {
136 assert(false && "Unsupported activation type");
137 return UnsupportedType;
138 }
139 }
140
141 if (args.is_trainable_layer)
142 {
143 // Check is only bias updating
144 if (args.train_rank_type != ONLY_BIAS)
145 {
146 assert(input_data != nullptr); // FIX memory planner then
147
148 // Get weight shape
149 OMRuntimeShape weight_shape(weight);
150 OMRuntimeShape dynamic_shapes = backward_storage.getDynamicRuntimeShape(weight_tensor_index);
151 if (dynamic_shapes.flatSize() != 0)
152 weight_shape = dynamic_shapes;
153
154 // 2. Calculate weight gradient
155 // Init weight grads with zeros
156 const auto kDlossSizeInBytes = output_shape.dims(1) * input_shape.dims(1) * sizeof(float);
157 for (int i = 0; i < kDlossSizeInBytes; i += sizeof(float))
158 *static_cast<float *>(static_cast<void *>(dloss_dweight_data + i)) = 0;
159
161 core::utils::castInputData<float>(dloss_doutput_data), output_shape,
162 core::utils::castInputData<float>(input_data), input_shape,
163 core::utils::castOutputData<float>(dloss_dweight_data), weight_shape, args.train_rank_type);
164 }
165
166 // 3. Calculate bias gradient
167 // Just copy dloss_doutput_data to dloss_dbias_data
168 // TODO: introduce training inplace
169 if (dloss_dbias_data)
170 {
171 assert(bias_data != nullptr);
172 if (bias_data == nullptr)
173 return UnknownError;
174
175 std::memcpy(dloss_dbias_data, dloss_doutput_data,
176 sizeof(OMDataType(output->type())) *
178 }
179 }
180
181 // 4. Calculate (if needed) input grad
182 if (args.is_last_layer == false)
183 {
184 assert(dloss_dinput_data != nullptr);
185
186 pal::FullyConnectedInputGrad(core::utils::castInputData<float>(dloss_doutput_data),
187 output_shape, core::utils::castInputData<float>(weight_data),
188 OMRuntimeShape(weight),
189 core::utils::castOutputData<float>(dloss_dinput_data));
190 }
191
192 return Ok;
193}
int32_t dimensionsCount() const
Definition Tensor.h:106
int32_t dims(int i) const
Definition Tensor.h:108
OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index)
uint8_t * outputs_data[maxOutputSize]
const circle::Operator * first_operator
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
const luci_interpreter::RuntimeShape output_shape
constexpr uint32_t outputTensorIdx
args
Definition infer.py:21
list input_data
Definition infer.py:29
OMDataType
"scalar" value type
Definition OMDataType.h:35
void ReluInputGrad(const float *input_relu_data, float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape)
void FullyConnectedInputGrad(const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *weight_data, const core::OMRuntimeShape &weight_shape, float *dloss_dinput_data)
void FullyConnectedWeightGrad(const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *input_data, const core::OMRuntimeShape &input_shape, float *dloss_dweight_data, const core::OMRuntimeShape &weight_shape, core::OpTrainableRankType rank)
@ UnsupportedType
Definition OMStatus.h:26