ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Sub.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "Builders.h"
18#include "kernels/Utils.h"
19
20#include "kernels/BinaryOpCommon.h"
21
22#include "PALSub.h"
23
24namespace luci_interpreter
25{
26
27void configure_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
28{
29 kernels::TISOKernel kernel(cur_op, runtime_graph);
30
31 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
32 Tensor::element_type(kernel.input2()));
33 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
34 Tensor::element_type(kernel.input2()));
35#ifndef DIS_QUANT
36 if (Tensor::element_type(kernel.input1()) == DataType::S16)
37 {
38 LUCI_INTERPRETER_CHECK(Tensor::zero_points(kernel.input1()).size() == 1 &&
39 Tensor::zero_points(kernel.input2()).size() == 1);
40 LUCI_INTERPRETER_CHECK(Tensor::zero_point(kernel.input1()) == 0 &&
41 Tensor::zero_point(kernel.input2()) == 0 &&
42 Tensor::zero_point(kernel.output()) == 0);
43 }
44#endif // DIS_QUANT
45}
46
47void execute_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
48{
49 kernels::TISOKernel kernel(cur_op, runtime_graph);
50
51 const auto *options = cur_op->builtin_options_as_SubOptions();
52
54 kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph);
56 kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
57
59 kernels::getTensorRuntimeShape(kernel.output(), runtime_graph);
60
61 bool is_inplace = runtime_graph->is_inplace_op(cur_op);
62
63 switch (Tensor::element_type(kernel.input1()))
64 {
65#ifndef DIS_FLOAT
66 case DataType::FLOAT32:
67 {
68 auto tiso_func = luci_interpreter_pal::Sub<float>;
69
70 auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<float>;
71 if (is_inplace)
72 {
73 kernels::evalTISOInplaceKernel<float>(tiso_func, broadcast_tiso_func, &kernel, options,
74 std::move(input_shape1), std::move(input_shape2),
75 std::move(output_shape));
76 }
77 else
78 {
79 kernels::TISOData kernel_data = kernel.readData();
80 kernels::evalTISOKernel<float>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
81 options, std::move(input_shape1), std::move(input_shape2),
82 std::move(output_shape));
83 }
84 }
85 break;
86#endif // DIS_FLOAT
87 case DataType::S64:
88 {
89 auto tiso_func = luci_interpreter_pal::Sub<int64_t>;
90
91 auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int64_t>;
92
93 if (is_inplace)
94 {
95 kernels::evalTISOInplaceKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, options,
96 std::move(input_shape1), std::move(input_shape2),
97 std::move(output_shape));
98 }
99 else
100 {
101 kernels::TISOData kernel_data = kernel.readData();
102 kernels::evalTISOKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
103 options, std::move(input_shape1), std::move(input_shape2),
104 std::move(output_shape));
105 }
106 }
107 break;
108 case DataType::S32:
109 {
110 auto tiso_func = luci_interpreter_pal::Sub<int32_t>;
111
112 auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int32_t>;
113
114 if (is_inplace)
115 {
116 kernels::evalTISOInplaceKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, options,
117 std::move(input_shape1), std::move(input_shape2),
118 std::move(output_shape));
119 }
120 else
121 {
122 kernels::TISOData kernel_data = kernel.readData();
123 kernels::evalTISOKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
124 options, std::move(input_shape1), std::move(input_shape2),
125 std::move(output_shape));
126 }
127 }
128 break;
129// TODO: fix it
130#if 0
131#ifndef DIS_QUANT
132 case DataType::U8:
133 {
134 auto tiso_func = [](const tflite::ArithmeticParams &params,
135 const tflite::RuntimeShape &input1_shape, const uint8_t *input1_data,
136 const tflite::RuntimeShape &input2_shape, const uint8_t *input2_data,
137 const tflite::RuntimeShape &output_shape, uint8_t *output_data) {
138 tflite::reference_ops::Sub(params, input1_shape, input1_data, input2_shape, input2_data,
139 output_shape, output_data);
140 };
141 auto broadcast_tiso_func =
142 [](const tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
143 const uint8_t *input1_data, const tflite::RuntimeShape &input2_shape,
144 const uint8_t *input2_data, const tflite::RuntimeShape &output_shape,
145 uint8_t *output_data) {
146 tflite::reference_ops::BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
147 input2_data, output_shape, output_data);
148 };
149 if (is_inplace)
150 {
151 kernels::evalTISOInplaceQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
152 options);
153 }
154 else
155 {
156 kernels::TISOData kernel_data = kernel.readData();
157 kernels::evalTISOQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
158 &kernel_data, options);
159 }
160 }
161 break;
162#endif // DIS_QUANT
163#endif // 0
164 default:
165 assert(false && "Unsupported type.");
166 }
167}
168
169} // namespace luci_interpreter
bool is_inplace_op(const circle::Operator *op)
const circle::Tensor * output() const
Definition TISOKernel.h:62
const circle::Tensor * input2() const
Definition TISOKernel.h:61
const circle::Tensor * input1() const
Definition TISOKernel.h:60
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
Definition Utils.cpp:29
void execute_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Sub.cpp:47
void configure_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition Sub.cpp:27