// Excerpt from the FullyConnected backward (training) kernel.

// Positions of the operator's tensors.
constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t weightTensorIdx = 1;
constexpr uint32_t biasTensorIdx = 2;

// ...

uint16_t op_index = args.kernel_index;

const circle::Tensor *input;
const circle::Tensor *weight;
const circle::Tensor *output;

int32_t weight_tensor_index = -1;

// Gradient buffers (dloss/d<tensor>), held as raw bytes.
uint8_t *dloss_dinput_data;
uint8_t *dloss_dweight_data;
uint8_t *dloss_dbias_data;
uint8_t *dloss_doutput_data;

const circle::FullyConnectedOptions *options;
// ...
weight = runtime_kernel.inputs[weightTensorIdx];
// ...
assert(input != nullptr);
assert(weight != nullptr);
assert(output != nullptr);

weight_tensor_index = runtime_kernel.inputs_index[weightTensorIdx];
assert(weight_tensor_index != -1);

// Forward-pass data read from the runtime kernel.
// ...
weight_data = runtime_kernel.inputs_data[weightTensorIdx];
bias_data = runtime_kernel.inputs_data[biasTensorIdx];
// ...
assert(weight_data != nullptr);

// Gradient buffers, read from the same input slots.
// ...
dloss_dinput_data = runtime_kernel.inputs_data[inputTensorIdx];
dloss_dweight_data = runtime_kernel.inputs_data[weightTensorIdx];
dloss_dbias_data = runtime_kernel.inputs_data[biasTensorIdx];
// ...
assert(dloss_dweight_data != nullptr);
assert(dloss_doutput_data != nullptr);
// ...
// Undo the fused activation on the incoming output gradient.
switch (options->fused_activation_function())
{
  case circle::ActivationFunctionType_NONE:
    break;
  case circle::ActivationFunctionType_RELU:
  {
    assert(output_data != nullptr);
    // Zero dloss/doutput wherever the forward ReLU output was not positive.
    ReluInputGrad(utils::castInputData<float>(output_data),
                  utils::castOutputData<float>(dloss_doutput_data), output_shape);
    break;
  }
  default:
    assert(false && "Unsupported activation type");
}
// ...
// Weight gradient is computed only for trainable layers.
if (args.is_trainable_layer)
{
  // ...
  assert(input_data != nullptr);
  // ... (dynamic_shapes obtained in elided lines)
  weight_shape = dynamic_shapes;

  // Zero the weight-gradient buffer before it is filled.
  const auto kDlossSizeInBytes = output_shape.dims(1) * input_shape.dims(1) * sizeof(float);
  for (int i = 0; i < kDlossSizeInBytes; i += sizeof(float))
    *static_cast<float *>(static_cast<void *>(dloss_dweight_data + i)) = 0;

  // Weight gradient from dloss/doutput and the forward input.
  FullyConnectedWeightGrad(
    core::utils::castInputData<float>(dloss_doutput_data), output_shape,
    core::utils::castInputData<float>(input_data), input_shape,
    core::utils::castOutputData<float>(dloss_dweight_data), weight_shape, args.train_rank_type);
}
// Bias gradient: dloss/dbias is a copy of dloss/doutput.
if (dloss_dbias_data)
{
  assert(bias_data != nullptr);
  if (bias_data == nullptr)
  {
    // ...
  }
  std::memcpy(dloss_dbias_data, dloss_doutput_data,
              /* ... */);
}
// Propagate the loss gradient to the layer input (skipped for the last layer).
if (args.is_last_layer == false)
{
  // ...
  assert(dloss_dinput_data != nullptr);
  // Input gradient: dloss/dinput from dloss/doutput and the weights.
  FullyConnectedInputGrad(core::utils::castInputData<float>(dloss_doutput_data),
                          output_shape, core::utils::castInputData<float>(weight_data),
                          weight_shape,
                          core::utils::castOutputData<float>(dloss_dinput_data));
}
Referenced declarations:

  int32_t dimensionsCount() const
  int32_t dims(int i) const
  OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index)
  uint8_t *outputs_data[maxOutputSize]
  const circle::Operator *first_operator
  OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
  uint8_t *inputs_data[maxInputSize]
  OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
  const circle::Tensor *outputs[maxOutputSize]
  int32_t inputs_index[maxInputSize]
  const circle::Tensor *inputs[maxInputSize]
  const luci_interpreter::RuntimeShape output_shape
  constexpr uint32_t outputTensorIdx
  OMDataType: "scalar" value type
  void ReluInputGrad(const float *input_relu_data, float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape)
  void FullyConnectedInputGrad(const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *weight_data, const core::OMRuntimeShape &weight_shape, float *dloss_dinput_data)
  void FullyConnectedWeightGrad(const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *input_data, const core::OMRuntimeShape &input_shape, float *dloss_dweight_data, const core::OMRuntimeShape &weight_shape, core::OpTrainableRankType rank)
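The three gradient helpers above appear here only as signatures. As a minimal, self-contained sketch of the arithmetic those signatures imply, assuming batch size 1, float data, and a weight layout of [output_depth, input_depth]; the function names, std::vector buffers, and loop structure below are illustrative only and are not onert-micro's implementation (the train_rank_type argument is also ignored here):

#include <cassert>
#include <cstddef>
#include <vector>

// ReLU backward: zero dloss/doutput wherever the forward ReLU output was not positive.
void relu_input_grad(const std::vector<float> &relu_output, std::vector<float> &dloss_doutput)
{
  assert(relu_output.size() == dloss_doutput.size());
  for (std::size_t i = 0; i < dloss_doutput.size(); ++i)
    if (relu_output[i] <= 0.0f)
      dloss_doutput[i] = 0.0f;
}

// Weight gradient: dloss/dW[o][i] = dloss/dy[o] * x[i] (outer product).
void fc_weight_grad(const std::vector<float> &dloss_doutput, const std::vector<float> &input,
                    std::vector<float> &dloss_dweight)
{
  assert(dloss_dweight.size() == dloss_doutput.size() * input.size());
  for (std::size_t o = 0; o < dloss_doutput.size(); ++o)
    for (std::size_t i = 0; i < input.size(); ++i)
      dloss_dweight[o * input.size() + i] = dloss_doutput[o] * input[i];
}

// Input gradient: dloss/dx[i] = sum over o of dloss/dy[o] * W[o][i].
void fc_input_grad(const std::vector<float> &dloss_doutput, const std::vector<float> &weight,
                   std::vector<float> &dloss_dinput)
{
  const std::size_t input_depth = dloss_dinput.size();
  assert(weight.size() == dloss_doutput.size() * input_depth);
  for (std::size_t i = 0; i < input_depth; ++i)
  {
    float acc = 0.0f;
    for (std::size_t o = 0; o < dloss_doutput.size(); ++o)
      acc += dloss_doutput[o] * weight[o * input_depth + i];
    dloss_dinput[i] = acc;
  }
}

No helper is needed for the bias: for a fully connected layer dloss/dbias equals dloss/doutput, which is why the kernel above fills dloss_dbias_data with a plain std::memcpy of the output gradient.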