ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ReLU.h
Go to the documentation of this file.
1
/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17
#ifndef __NNFW_CKER_TRAIN_OPERATION_RELU_H__
18
#define __NNFW_CKER_TRAIN_OPERATION_RELU_H__
19
20
#include "
cker/Shape.h
"
21
#include "
cker/eigen/Utils.h
"
22
23
#include <Eigen/Core>
24
25
namespace
nnfw
26
{
27
namespace
cker
28
{
29
namespace
train
30
{
31
32
inline
void
ReLUGrad
(
const
Shape
&
output_shape
,
const
float
*output_data,
33
const
Shape
&incoming_shape,
const
float
*incoming_data,
34
const
Shape
&grad_shape,
float
*grad_data)
35
{
36
const
auto
output_map =
MapAsVector
(output_data,
output_shape
);
37
const
auto
incoming_map =
MapAsVector
(incoming_data, incoming_shape);
38
auto
grad_map =
MapAsVector
(grad_data, grad_shape);
39
40
if
(
output_shape
== incoming_shape &&
output_shape
== grad_shape)
41
grad_map.array() = incoming_map.array() * (output_map.array() > 0.0f).
template
cast<float>();
42
else
43
throw
std::runtime_error(
"cker::ReLUGrad: Unsupported shape"
);
44
}
45
46
}
// namespace train
47
}
// namespace cker
48
}
// namespace nnfw
49
50
#endif
// __NNFW_CKER_TRAIN_OPERATION_RELU_H__
nnfw::cker::Shape
Definition
Shape.h:32
Shape.h
Utils.h
output_shape
const luci_interpreter::RuntimeShape output_shape
Definition
PALComparisons.h:32
nnfw::cker::train::ReLUGrad
void ReLUGrad(const Shape &output_shape, const float *output_data, const Shape &incoming_shape, const float *incoming_data, const Shape &grad_shape, float *grad_data)
Definition
ReLU.h:32
nnfw::cker::MapAsVector
VectorMap< Scalar > MapAsVector(Scalar *data, const Shape &shape)
Definition
Utils.h:43
nnfw
Definition
topk_v2.h:30
compute
cker
include
cker
train
operation
ReLU.h
Generated by
1.9.8