ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Relu.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "kernels/Relu.h"
#include "kernels/Utils.h"

#include "PALRelu.h"

#include <algorithm>
#include <limits>
#include <stdexcept>
23
24namespace luci_interpreter
25{
26
27namespace kernels
28{
29
30Relu::Relu(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
31
33{
34 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
35 if (input()->element_type() == DataType::S16)
36 {
37 LUCI_INTERPRETER_CHECK(input()->zero_point() == 0 && output()->zero_point() == 0);
38 }
39
40 if (input()->element_type() == DataType::U8 || input()->element_type() == DataType::S16)
41 {
42 double multiplier = input()->scale() / output()->scale();
43 quantizeMultiplier(multiplier, &_output_multiplier, &_output_shift);
44 }
45 output()->resize(input()->shape());
46}
47
48void Relu::execute() const
49{
50 switch (input()->element_type())
51 {
52 case DataType::FLOAT32:
53 evalFloat();
54 break;
55 case DataType::U8:
56 evalQuantized();
57 break;
58 case DataType::S16:
59 evalQuantizedS16();
60 break;
61 default:
62 throw std::runtime_error("luci-intp Relu Unsupported type.");
63 }
64}
65
66void Relu::evalFloat() const
67{
68 const auto input_data = getTensorData<float>(input());
69 const auto input_shape = getTensorShape(input());
70 auto output_data = getTensorData<float>(output());
72
73 luci_interpreter_pal::Relu(input_shape, input_data, output_shape, output_data);
74}
75
76void Relu::evalQuantized() const
77{
78 tflite::ReluParams params;
79 params.input_offset = input()->zero_point();
80 params.output_offset = output()->zero_point();
81 params.output_multiplier = _output_multiplier;
82 params.output_shift = _output_shift;
83
84 params.quantized_activation_min =
85 std::max(static_cast<int32_t>(std::numeric_limits<uint8_t>::min()), params.output_offset);
86 params.quantized_activation_max = static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
87
88 luci_interpreter_pal::ReluX(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
89 getTensorShape(output()), getTensorData<uint8_t>(output()));
90}
91
92void Relu::evalQuantizedS16() const
93{
94 const auto *input_data = getTensorData<int16_t>(input());
95 auto *output_data = getTensorData<int16_t>(output());
96
97 constexpr int32_t output_min = 0;
98 constexpr int32_t output_max = std::numeric_limits<int16_t>::max();
99
100 const int32_t num_elements = input()->shape().num_elements();
101
102 for (int32_t i = 0; i < num_elements; ++i)
103 {
104 const int32_t input_val = input_data[i];
105 int32_t output_val =
106 tflite::MultiplyByQuantizedMultiplier(input_val, _output_multiplier, _output_shift);
107 output_val = std::max(output_val, output_min);
108 output_val = std::min(output_val, output_max);
109 output_data[i] = static_cast<int16_t>(output_val);
110 }
111}
112
113} // namespace kernels
114} // namespace luci_interpreter
int32_t num_elements() const
Definition Tensor.h:53
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
float scale() const
Definition Tensor.h:109
int32_t zero_point() const
Definition Tensor.h:115
Tensor * output() const
Definition Relu.h:33
Relu(const Tensor *input, Tensor *output)
Definition Relu.cpp:30
void configure() override
Definition Relu.cpp:32
const Tensor * input() const
Definition Relu.h:32
void execute() const override
Definition Relu.cpp:48
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
Definition Utils.cpp:157