ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Mul.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/Mul.h"
19
20#include "kernels/BinaryOpCommon.h"
21#include "kernels/Utils.h"
22
23#include "PALMul.h"
24
25#include <tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h>
26
27#include <stdexcept>
28
29namespace luci_interpreter
30{
31namespace kernels
32{
33
34Mul::Mul(const Tensor *input1, const Tensor *input2, Tensor *output, const MulParams &params)
35 : KernelWithParams<MulParams>({input1, input2}, {output}, params)
36{
37}
38
40{
41 LUCI_INTERPRETER_CHECK(input1()->element_type() == input2()->element_type());
42 LUCI_INTERPRETER_CHECK(output()->element_type() == input1()->element_type());
43 if (input1()->element_type() == DataType::S16)
44 {
45 LUCI_INTERPRETER_CHECK(input1()->zero_points().size() == 1 &&
46 input2()->zero_points().size() == 1)
47 LUCI_INTERPRETER_CHECK(input1()->zero_point() == 0 && input2()->zero_point() == 0 &&
48 output()->zero_point() == 0);
49 }
50
51 output()->resize(calculateShapeForBroadcast(input1()->shape(), input2()->shape()));
52}
53
54void Mul::execute() const
55{
56 switch (input1()->element_type())
57 {
58 case DataType::FLOAT32:
59 evalFloat();
60 break;
61 case DataType::S64:
62 evalInteger<int64_t>();
63 break;
64 case DataType::S32:
65 evalInteger<int32_t>();
66 break;
67 case DataType::S16:
68 evalQuantizedS16();
69 break;
70 default:
71 throw std::runtime_error("luci-intp Mul Unsupported type.");
72 }
73}
74
75void Mul::evalFloat() const
76{
77 tflite::ArithmeticParams params{};
78 fillArithmeticActivationRange<float>(params, _params.activation);
79
80 const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
82
83 if (need_broadcast)
84 {
85 luci_interpreter_pal::BroadcastMul4DSlow(
86 params, getTensorShape(input1()), getTensorData<float>(input1()), getTensorShape(input2()),
87 getTensorData<float>(input2()), getTensorShape(output()), getTensorData<float>(output()));
88 }
89 else
90 {
91 luci_interpreter_pal::Mul(params, getTensorShape(input1()), getTensorData<float>(input1()),
92 getTensorShape(input2()), getTensorData<float>(input2()),
93 getTensorShape(output()), getTensorData<float>(output()));
94 }
95}
96
97template <typename T> void Mul::evalInteger() const
98{
99 tflite::ArithmeticParams params{};
100 fillArithmeticActivationRange<T>(params, _params.activation);
101
102 const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
104
105 if (need_broadcast)
106 {
107 luci_interpreter_pal::BroadcastMul4DSlow(
108 params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
109 getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
110 }
111 else
112 {
113 luci_interpreter_pal::Mul(params, getTensorShape(input1()), getTensorData<T>(input1()),
114 getTensorShape(input2()), getTensorData<T>(input2()),
115 getTensorShape(output()), getTensorData<T>(output()));
116 }
117}
118
119void Mul::evalQuantizedS16() const
120{
121 const auto input1_scale = static_cast<double>(input1()->scale());
122 const auto input2_scale = static_cast<double>(input2()->scale());
123 const auto output_scale = static_cast<double>(output()->scale());
124
125 const double real_multiplier = input1_scale * input2_scale / output_scale;
126
127 int32_t output_multiplier;
128 int output_shift;
129 quantizeMultiplier(real_multiplier, &output_multiplier, &output_shift);
130
131 int32_t activation_min{};
132 int32_t activation_max{};
133 calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
134
135 auto fn = [output_multiplier, output_shift, activation_min, activation_max](int16_t input1_val,
136 int16_t input2_val) {
137 int32_t output = static_cast<int32_t>(input1_val) * static_cast<int32_t>(input2_val);
138 output = tflite::MultiplyByQuantizedMultiplier(output, output_multiplier, output_shift);
139 output = std::max(output, activation_min);
140 output = std::min(output, activation_max);
141 return static_cast<int16_t>(output);
142 };
143
144 BinaryOpBroadcastSlow(getTensorShape(input1()), getTensorData<int16_t>(input1()),
145 getTensorShape(input2()), getTensorData<int16_t>(input2()),
146 getTensorShape(output()), getTensorData<int16_t>(output()), fn);
147}
148
149} // namespace kernels
150} // namespace luci_interpreter
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
float scale() const
Definition Tensor.h:109
Tensor * output() const
Definition Mul.h:38
void configure() override
Definition Mul.cpp:39
const Tensor * input1() const
Definition Mul.h:36
void execute() const override
Definition Mul.cpp:54
const Tensor * input2() const
Definition Mul.h:37
Mul(const Tensor *input1, const Tensor *input2, Tensor *output, const MulParams &params)
Definition Mul.cpp:34
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
Shape calculateShapeForBroadcast(const Shape &input1_shape, const Shape &input2_shape)
Definition Utils.cpp:204
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void calculateActivationRangeQuantized(Activation activation, const Tensor *output, int32_t *activation_min, int32_t *activation_max)
Definition Utils.cpp:119
void quantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift)
Definition Utils.cpp:157
void BinaryOpBroadcastSlow(const tflite::RuntimeShape &unextended_input1_shape, const T *input1_data, const tflite::RuntimeShape &unextended_input2_shape, const T *input2_data, const tflite::RuntimeShape &unextended_output_shape, T *output_data, Op op)
int32_t size[5]
Definition Slice.cpp:35