ONE - On-device Neural Engine
Loading...
Searching...
No Matches
OperationUtils.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "OperationUtils.h"
18
22
23namespace onert
24{
25namespace backend
26{
27namespace train
28{
29namespace ops
30{
31
33{
34 if (tensor == nullptr)
35 return nnfw::cker::Shape();
36
37 assert(!tensor->is_dynamic() && "Dynamic tensor is not supported yet");
38
39 const ir::Shape &shape = tensor->get_info().shape();
40 auto rank = shape.rank();
41 nnfw::cker::Shape ret(rank);
42 auto data = ret.DimsData();
43 for (int i = 0; i < rank; ++i)
44 {
45 data[i] = shape.dim(i);
46 }
47 return ret;
48}
49
51 const IPortableTensor *output,
52 const IPortableTensor *input_backprop,
53 IPortableTensor *output_backprop)
54{
55 assert(output != nullptr);
56 assert(input_backprop != nullptr);
57
58 // handle NONE - just propagate incoming gradient
59 if (activation == ir::Activation::NONE)
60 {
61 return input_backprop;
62 }
63
64 assert(output_backprop != nullptr);
65
66 // handle other activation
67 switch (activation)
68 {
70 nnfw::cker::train::ReLUGrad(getShape(output), getBuffer<float>(output),
71 getShape(input_backprop), getBuffer<float>(input_backprop),
72 getShape(output_backprop), getBuffer<float>(output_backprop));
73 break;
75 nnfw::cker::train::ReLU6Grad(getShape(output), getBuffer<float>(output),
76 getShape(input_backprop), getBuffer<float>(input_backprop),
77 getShape(output_backprop), getBuffer<float>(output_backprop));
78 break;
79 // TODO: Add other activation backpropagation here
80 default:
81 throw std::runtime_error("Unsupported activation type yet");
82 }
83 return output_backprop;
84}
85
86void biasGrad(const IPortableTensor *input_backprop, IPortableTensor *bias_grad)
87{
88 assert(bias_grad);
89
90 nnfw::cker::Shape input_backprop_shape = getShape(input_backprop);
91 float *input_backprop_buffer = reinterpret_cast<float *>(input_backprop->buffer());
92
93 nnfw::cker::Shape bias_grad_shape = getShape(bias_grad);
94 float *bias_grad_buffer = getBuffer<float>(bias_grad);
95
96 nnfw::cker::functor::biasReductionHelper(input_backprop_buffer, input_backprop_shape,
97 bias_grad_buffer, bias_grad_shape);
98}
99
112
113} // namespace ops
114} // namespace train
115} // namespace backend
116} // namespace onert
int32_t * DimsData()
Definition Shape.h:112
A tensor class that is portable for other backends.
virtual uint8_t * buffer() const =0
void biasReductionHelper(float *input_backprop_buffer, const Shape &input_backprop_shape, float *bias_grad_buffer, const Shape &bias_grad_shape)
void ReLUGrad(const Shape &output_shape, const float *output_data, const Shape &incoming_shape, const float *incoming_data, const Shape &grad_shape, float *grad_data)
Definition ReLU.h:32
void ReLU6Grad(const Shape &output_shape, const float *output_data, const Shape &incoming_shape, const float *incoming_data, const Shape &grad_shape, float *grad_data)
Definition ReLU6.h:31
void biasGrad(const IPortableTensor *input_backprop, IPortableTensor *bias_grad)
backpropagate bias
const IPortableTensor * backpropActivation(const ir::Activation &activation, const IPortableTensor *output, const IPortableTensor *input_backprop, IPortableTensor *output_backprop)
backpropagate activation
nnfw::cker::Shape getShape(const IPortableTensor *tensor)
Get shape of tensor.
nnfw::cker::train::LossReductionType convertLossReductionType(ir::train::LossReductionType type)
convert loss reduction type