ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SelectV2.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2023 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "Builders.h"
19#include "kernels/Utils.h"
20
21#include "PALSelectV2.h"
22
23namespace luci_interpreter
24{
25
26namespace
27{
28
// Operand positions used to index cur_op->inputs() / cur_op->outputs() for SelectV2:
// inputs are (condition, x, y); there is a single output.
constexpr int kInputTensorCondition{0};
constexpr int kInputTensorX{1};
constexpr int kInputTensorY{2};
constexpr int kOutputTensor{0};
33
34template <typename T>
35void CallSelect(const circle::Tensor *input_condition, const circle::Tensor *input_x,
36 const circle::Tensor *input_y, const circle::Tensor *output, bool need_broadcast,
37 RuntimeGraph *runtime_graph)
38{
39 using Func = decltype(luci_interpreter_pal::Select<bool, T>) *;
40 Func select_func;
41 if (need_broadcast)
42 {
43 assert(false && "Broadcast not supported now");
44 }
45 else
46 {
47 select_func = luci_interpreter_pal::Select<bool, T>;
48 }
49
50 select_func(kernels::getTensorRuntimeShape(input_condition, runtime_graph),
51 kernels::getTensorData<bool>(runtime_graph->getDataByTensor(input_condition)),
52 kernels::getTensorRuntimeShape(input_x, runtime_graph),
53 kernels::getTensorData<T>(runtime_graph->getDataByTensor(input_x)),
54 kernels::getTensorRuntimeShape(input_y, runtime_graph),
55 kernels::getTensorData<T>(runtime_graph->getDataByTensor(input_y)),
56 kernels::getTensorRuntimeShape(output, runtime_graph),
57 kernels::getTensorData<T>(runtime_graph->getDataByTensor(output)));
58}
59
60} // namespace
61
62void configure_kernel_CircleSelectV2(const circle::Operator *cur_op,
63 BaseRuntimeGraph *runtime_graph)
64{
65 const auto input_cond_index = cur_op->inputs()->operator[](kInputTensorCondition);
66 const auto input_x_index = cur_op->inputs()->operator[](kInputTensorX);
67 const auto input_y_index = cur_op->inputs()->operator[](kInputTensorY);
68 const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
69
70 assert(input_cond_index != -1);
71 assert(input_x_index != -1);
72 assert(input_y_index != -1);
73 assert(output_index != -1);
74
75 const auto input_cond = runtime_graph->getCircleTensorByIndex(input_cond_index);
76 const auto input_x = runtime_graph->getCircleTensorByIndex(input_x_index);
77 const auto input_y = runtime_graph->getCircleTensorByIndex(input_y_index);
78 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
79
80 assert(input_cond != nullptr);
81 assert(input_x != nullptr);
82 assert(input_y != nullptr);
83
84 // Input condition should be bool
85 LUCI_INTERPRETER_CHECK(Tensor::element_type(input_cond) == DataType::BOOL);
86
87 // X, Y and Output should be the same type
88 LUCI_INTERPRETER_CHECK(Tensor::element_type(input_x) == Tensor::element_type(input_y));
89 LUCI_INTERPRETER_CHECK(Tensor::element_type(input_x) == Tensor::element_type(output));
90
91 bool possible_mixed_scaler =
92 Tensor::num_elements(input_cond) == 1 && Tensor::num_elements(input_x) == 1 &&
93 Tensor::num_elements(input_y) == 1 && Tensor::num_elements(output) == 1;
94
95 bool same_shape = Tensor::num_elements(input_cond) == Tensor::num_elements(input_x) &&
96 Tensor::num_elements(input_x) == Tensor::num_elements(input_y);
97
98 // Broadcast not supported now
99 if (not same_shape and not possible_mixed_scaler)
100 {
102 }
103}
104
105void execute_kernel_CircleSelectV2(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
106{
107 const auto input_cond_index = cur_op->inputs()->operator[](kInputTensorCondition);
108 const auto input_x_index = cur_op->inputs()->operator[](kInputTensorX);
109 const auto input_y_index = cur_op->inputs()->operator[](kInputTensorY);
110 const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
111
112 assert(input_cond_index != -1);
113 assert(input_x_index != -1);
114 assert(input_y_index != -1);
115 assert(output_index != -1);
116
117 const auto input_cond = runtime_graph->getCircleTensorByIndex(input_cond_index);
118 const auto input_x = runtime_graph->getCircleTensorByIndex(input_x_index);
119 const auto input_y = runtime_graph->getCircleTensorByIndex(input_y_index);
120 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
121
122 assert(input_cond != nullptr);
123 assert(input_x != nullptr);
124 assert(input_y != nullptr);
125
126 bool possible_mixed_scaler =
127 Tensor::num_elements(input_cond) == 1 && Tensor::num_elements(input_x) == 1 &&
128 Tensor::num_elements(input_y) == 1 && Tensor::num_elements(output) == 1;
129
130 bool same_shape = Tensor::num_elements(input_cond) == Tensor::num_elements(input_x) &&
131 Tensor::num_elements(input_x) == Tensor::num_elements(input_y);
132 bool is_broadcast = false;
133 if (not possible_mixed_scaler and not same_shape)
134 is_broadcast = true;
135
136 const auto type = Tensor::element_type(input_x);
137 switch (type)
138 {
139#ifndef DIS_FLOAT
140 case DataType::FLOAT32:
141 CallSelect<float>(input_cond, input_x, input_y, output, is_broadcast, runtime_graph);
142 break;
143#endif // DIS_FLOAT
144 default:
145 assert(false && "Unsupported type.");
146 }
147}
148
149} // namespace luci_interpreter
const circle::Tensor * getCircleTensorByIndex(int32_t index)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
Definition Utils.cpp:29
void configure_kernel_CircleSelectV2(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition SelectV2.cpp:62
void execute_kernel_CircleSelectV2(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition SelectV2.cpp:105