ONE - On-device Neural Engine
Loading...
Searching...
No Matches
If.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "Builders.h"
17#include "kernels/Utils.h"
18
19#include <cstring>
20
21namespace luci_interpreter
22{
23
24void configure_kernel_CircleIf(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
25{
26 auto *main_runtime_graph = runtime_graph;
27
28 auto *runtime_module = runtime_graph->getRuntimeModule();
29
30 const auto *options = cur_op->builtin_options_as_IfOptions();
31
32 const auto cond_index = cur_op->inputs()->operator[](0);
33 const auto output_index = cur_op->outputs()->operator[](0);
34
35 const auto then_subgraph_index = options->then_subgraph_index();
36 const auto else_subgraph_index = options->else_subgraph_index();
37
38 assert(cond_index != -1);
39 assert(output_index != -1);
40
41 assert(then_subgraph_index != -1);
42 assert(else_subgraph_index != -1);
43
44 const auto cond = runtime_graph->getCircleTensorByIndex(cond_index);
45 LUCI_INTERPRETER_CHECK(Tensor::element_type(cond) == DataType::BOOL);
46 LUCI_INTERPRETER_CHECK(Tensor::num_elements(cond) == 1);
47
48 const auto output = runtime_graph->getCircleTensorByIndex(output_index);
49 auto *then_subgraph = runtime_module->getRuntimeGraphAt(then_subgraph_index);
50 auto *else_subgraph = runtime_module->getRuntimeGraphAt(else_subgraph_index);
51 for (RuntimeGraph *graph : {then_subgraph, else_subgraph})
52 {
53 graph->selectOwnSubgraph();
54 const auto graph_input_size = graph->getNumOfInputTensors();
55 const auto graph_output_size = graph->getNumOfOutputTensors();
56 LUCI_INTERPRETER_CHECK(graph_input_size == cur_op->inputs()->size() - 1);
57 LUCI_INTERPRETER_CHECK(graph_output_size == cur_op->outputs()->size());
58 graph->invalidate();
59 graph->configure(false);
60 }
61 main_runtime_graph->selectOwnSubgraph();
62}
63
64void execute_kernel_CircleIf(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
65{
66 auto *main_runtime_graph = runtime_graph;
67 auto *runtime_module = runtime_graph->getRuntimeModule();
68
69 const auto input_size = cur_op->inputs()->size() - 1;
70 const auto output_size = cur_op->outputs()->size();
71
72 std::vector<uint8_t *> operation_inputs_data(input_size);
73 std::vector<uint8_t *> operation_outputs_data(output_size);
74
75 std::vector<int32_t> input_sizes(input_size);
76 std::vector<int32_t> output_sizes(output_size);
77
78 const auto *options = cur_op->builtin_options_as_IfOptions();
79 const auto cond_index = cur_op->inputs()->operator[](0);
80
81 const auto then_subgraph_index = options->then_subgraph_index();
82 const auto else_subgraph_index = options->else_subgraph_index();
83
84 auto *then_subgraph = runtime_module->getRuntimeGraphAt(then_subgraph_index);
85 auto *else_subgraph = runtime_module->getRuntimeGraphAt(else_subgraph_index);
86
87 const auto cond = runtime_graph->getCircleTensorByIndex(cond_index);
88
89 const uint8_t *cond_data = runtime_graph->getDataByTensor(cond);
90 const bool cond_value = kernels::getTensorData<bool>(cond_data)[0];
91
92 RuntimeGraph *active_graph = cond_value ? then_subgraph : else_subgraph;
93
94 for (int32_t i = 0; i < input_size; ++i)
95 {
96 const auto op_input_index = cur_op->inputs()->operator[](i + 1);
97 assert(op_input_index != -1);
98 const auto input = main_runtime_graph->getCircleTensorByIndex(op_input_index);
99 input_sizes[i] = Tensor::num_elements(input) * size(Tensor::element_type(input));
100
101 auto *input_data = main_runtime_graph->getDataByTensor(input);
102
103 uint8_t *tensor_data = nullptr;
104 if (input_data == nullptr)
105 input_data = main_runtime_graph->getConstDataByTensor(input);
106 assert(input_data != nullptr);
107 tensor_data = main_runtime_graph->getDataByTensor(input);
108 assert(tensor_data != nullptr);
109
110 operation_inputs_data[i] = tensor_data;
111 }
112 for (int32_t i = 0; i < output_size; ++i)
113 {
114 const auto op_output_index = cur_op->outputs()->operator[](i);
115 assert(op_output_index != -1);
116 const auto output = main_runtime_graph->getCircleTensorByIndex(op_output_index);
117 output_sizes[i] = Tensor::num_elements(output) * size(Tensor::element_type(output));
118
119 auto *output_data = main_runtime_graph->getDataByTensor(output);
120
121 uint8_t *tensor_data = nullptr;
122 if (output_data == nullptr)
123 output_data = main_runtime_graph->getConstDataByTensor(output);
124 assert(output_data != nullptr);
125 tensor_data = main_runtime_graph->getDataByTensor(output);
126 assert(tensor_data != nullptr);
127
128 operation_outputs_data[i] = tensor_data;
129 }
130 active_graph->selectOwnSubgraph();
131 for (int32_t i = 0; i < input_size; ++i)
132 active_graph->configureGraphInput(i, operation_inputs_data[i]);
133 active_graph->execute();
134
135 for (int32_t i = 0; i < output_size; ++i)
136 {
137 auto cur_output_active_data = active_graph->getOutputDataByIndex(i);
138 if (cur_output_active_data == nullptr)
139 continue;
140 std::memcpy(operation_outputs_data[i], cur_output_active_data, output_sizes[i]);
141 }
142 active_graph->resetOutputTensorsData();
143 active_graph->clearTensors();
144 main_runtime_graph->selectOwnSubgraph();
145}
146
147} // namespace luci_interpreter
const circle::Tensor * getCircleTensorByIndex(int32_t index)
uint8_t * getOutputDataByIndex(int32_t output_index)
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
RuntimeModule * getRuntimeModule()
uint8_t * configureGraphInput(int32_t input_index)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
void execute_kernel_CircleIf(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition If.cpp:64
void configure_kernel_CircleIf(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition If.cpp:24
int32_t size[5]
Definition Slice.cpp:35