ONE - On-device Neural Engine
Loading...
Searching...
No Matches
BinaryArithmetic.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __NNFW_CKER_TRAIN_OPERATION_BINARYARITHMETIC_H__
18#define __NNFW_CKER_TRAIN_OPERATION_BINARYARITHMETIC_H__
19
20#include "cker/Shape.h"
21#include "cker/eigen/Utils.h"
23
24namespace nnfw
25{
26namespace cker
27{
28namespace train
29{
31{
32 kAdd,
33 kSub,
34 kMul,
35 kDiv,
36};
37
38template <typename T>
39void BinaryArithmeticGrad(const Shape &lhs_shape, const T *lhs_data, const Shape &rhs_shape,
40 const T *rhs_data, const Shape &incoming_shape, const T *incoming_data,
41 const Shape &lhs_grad_shape, T *lhs_grad_data,
42 const Shape &rhs_grad_shape, T *rhs_grad_data,
43 ArithmeticType arithmetic_type)
44{
45 if (!(lhs_shape == rhs_shape && rhs_shape == incoming_shape && incoming_shape == lhs_grad_shape &&
46 lhs_grad_shape == rhs_grad_shape))
47 throw std::runtime_error{"Shape of lhs, rhs, incoming, lhs_grad, and rhs_grad must match"};
48
49 switch (arithmetic_type)
50 {
52 {
53 BroadcastTo(incoming_shape, const_cast<T *>(incoming_data), lhs_grad_shape, lhs_grad_data);
54 BroadcastTo(incoming_shape, const_cast<T *>(incoming_data), rhs_grad_shape, rhs_grad_data);
55 }
56 break;
57
59 {
60 BroadcastTo(incoming_shape, const_cast<T *>(incoming_data), lhs_grad_shape, lhs_grad_data);
61
62 auto const in_map = MapAsVector(incoming_data, incoming_shape);
63 auto rhs_grad_map = MapAsVector(rhs_grad_data, rhs_grad_shape);
64 rhs_grad_map = -in_map;
65 }
66 break;
67
69 {
70 auto const in_map = MapAsVector(incoming_data, incoming_shape);
71 auto const lhs_map = MapAsVector(lhs_data, lhs_shape);
72 auto const rhs_map = MapAsVector(rhs_data, rhs_shape);
73 auto lhs_grad_map = MapAsVector(lhs_grad_data, lhs_grad_shape);
74 auto rhs_grad_map = MapAsVector(rhs_grad_data, rhs_grad_shape);
75
76 lhs_grad_map = in_map.array() * rhs_map.array();
77 rhs_grad_map = in_map.array() * lhs_map.array();
78 }
79 break;
80
82 default:
83 throw std::runtime_error{"Unsupported Binary Arithmetic Operation"};
84 }
85}
86
87} // namespace train
88} // namespace cker
89} // namespace nnfw
90
91#endif // __NNFW_CKER_TRAIN_OPERATION_BINARYARITHMETIC_H__
void BinaryArithmeticGrad(const Shape &lhs_shape, const T *lhs_data, const Shape &rhs_shape, const T *rhs_data, const Shape &incoming_shape, const T *incoming_data, const Shape &lhs_grad_shape, T *lhs_grad_data, const Shape &rhs_grad_shape, T *rhs_grad_data, ArithmeticType arithmetic_type)
void BroadcastTo(const Shape &input_shape, T *input_data, const Shape &output_shape, T *output_data)
VectorMap< Scalar > MapAsVector(Scalar *data, const Shape &shape)
Definition Utils.h:43
Definition topk_v2.h:30