ONE - On-device Neural Engine
Loading...
Searching...
No Matches
While.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "Builders.h"
19#include "kernels/Utils.h"
20
21#include <cstring>
22
23namespace luci_interpreter
24{
25
26void configure_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
27{
28 auto *main_runtime_graph = runtime_graph;
29
30 auto *runtime_module = runtime_graph->getRuntimeModule();
31
32 const auto *options = cur_op->builtin_options_as_WhileOptions();
33 const auto body_subgraph_index = options->body_subgraph_index();
34 const auto cond_subgraph_index = options->cond_subgraph_index();
35
36 auto *cond_runtime_graph = runtime_module->getRuntimeGraphAt(cond_subgraph_index);
37 auto *body_runtime_graph = runtime_module->getRuntimeGraphAt(body_subgraph_index);
38
39 body_runtime_graph->selectOwnSubgraph();
40 const auto body_input_size = body_runtime_graph->getNumOfInputTensors();
41 const auto body_output_size = body_runtime_graph->getNumOfOutputTensors();
42 LUCI_INTERPRETER_CHECK(body_input_size == cur_op->inputs()->size());
43 LUCI_INTERPRETER_CHECK(body_output_size == cur_op->outputs()->size());
44 LUCI_INTERPRETER_CHECK(body_output_size == cur_op->inputs()->size());
45 body_runtime_graph->invalidate();
46 body_runtime_graph->configure(false);
47
48 cond_runtime_graph->selectOwnSubgraph();
49 const auto cond_input_size = cond_runtime_graph->getNumOfInputTensors();
50 const auto cond_output_size = cond_runtime_graph->getNumOfOutputTensors();
51 LUCI_INTERPRETER_CHECK(cond_input_size == cur_op->inputs()->size());
52 LUCI_INTERPRETER_CHECK(cond_output_size == 1);
53 const circle::Tensor *cond_output_tensor = cond_runtime_graph->getOutputTensorByIndex(0);
54 LUCI_INTERPRETER_CHECK(Tensor::element_type(cond_output_tensor) == DataType::BOOL);
55 cond_runtime_graph->invalidate();
56 cond_runtime_graph->configure(false);
57
58 main_runtime_graph->selectOwnSubgraph();
59}
60
61void execute_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
62{
63 auto *main_runtime_graph = runtime_graph;
64 auto *runtime_module = runtime_graph->getRuntimeModule();
65
66 const auto input_size = cur_op->inputs()->size();
67
68 std::vector<uint8_t *> operation_inputs_data(input_size);
69 std::vector<uint8_t *> operation_outputs_data;
70
71 std::vector<int32_t> input_sizes(input_size);
72
73 bool is_inplace = runtime_graph->is_inplace_op(cur_op);
74
75 for (int32_t i = 0; i < input_size; ++i)
76 {
77 const auto op_input_index = cur_op->inputs()->operator[](i);
78 const auto op_output_index = cur_op->outputs()->operator[](i);
79 assert(op_input_index != -1);
80 assert(op_output_index != -1);
81 const auto input = main_runtime_graph->getCircleTensorByIndex(op_input_index);
82 const auto output = main_runtime_graph->getCircleTensorByIndex(op_output_index);
83
84 input_sizes[i] = Tensor::num_elements(input) * size(Tensor::element_type(input));
85
86 auto *input_data = main_runtime_graph->getDataByTensor(input);
87
88 uint8_t *tensor_data = nullptr;
89 if (is_inplace)
90 {
91 if (input_data == nullptr)
92 {
93 tensor_data = new uint8_t[input_sizes[i]];
94 input_data = main_runtime_graph->getConstDataByTensor(input);
95 assert(input_data != nullptr);
96 std::memcpy(tensor_data, input_data, input_sizes[i]);
97 }
98 else
99 {
100 tensor_data = input_data;
101 }
102 }
103 else
104 {
105 if (input_data == nullptr)
106 input_data = main_runtime_graph->getConstDataByTensor(input);
107 assert(input_data != nullptr);
108 tensor_data = main_runtime_graph->getDataByTensor(output);
109 assert(tensor_data != nullptr);
110 std::memcpy(tensor_data, input_data, input_sizes[i]);
111 }
112 assert(tensor_data != nullptr);
113
114 operation_inputs_data[i] = tensor_data;
115 }
116
117 const auto *options = cur_op->builtin_options_as_WhileOptions();
118 const auto body_subgraph_index = options->body_subgraph_index();
119 const auto cond_subgraph_index = options->cond_subgraph_index();
120
121 auto *cond_runtime_graph = runtime_module->getRuntimeGraphAt(cond_subgraph_index);
122 auto *body_runtime_graph = runtime_module->getRuntimeGraphAt(body_subgraph_index);
123
124 do
125 {
126 cond_runtime_graph->selectOwnSubgraph();
127
128 for (int32_t i = 0; i < input_size; ++i)
129 cond_runtime_graph->configureGraphInput(i, operation_inputs_data[i]);
130
131 cond_runtime_graph->execute();
132
133 bool cond_value = (cond_runtime_graph->getOutputDataByIndex(0))[0];
134 if (!cond_value)
135 break;
136
137 body_runtime_graph->selectOwnSubgraph();
138 for (int32_t i = 0; i < input_size; ++i)
139 body_runtime_graph->configureGraphInput(i, operation_inputs_data[i]);
140
141 body_runtime_graph->execute();
142
143 for (int32_t i = 0; i < input_size; ++i)
144 {
145 auto cur_output_body_data = body_runtime_graph->getOutputDataByIndex(i);
146 if (cur_output_body_data == nullptr)
147 continue;
148 std::memcpy(operation_inputs_data[i], cur_output_body_data, input_sizes[i]);
149 }
150 } while (true);
151
152 cond_runtime_graph->resetOutputTensorsData();
153 cond_runtime_graph->clearTensors();
154
155 body_runtime_graph->selectOwnSubgraph();
156 body_runtime_graph->resetOutputTensorsData();
157 body_runtime_graph->clearTensors();
158
159 main_runtime_graph->selectOwnSubgraph();
160
161 if (is_inplace)
162 {
163 for (int32_t i = 0; i < input_size; ++i)
164 {
165 const auto op_input_index = cur_op->inputs()->operator[](i);
166 const auto op_output_index = cur_op->outputs()->operator[](i);
167 assert(op_input_index != -1);
168 assert(op_output_index != -1);
169 const auto input = main_runtime_graph->getCircleTensorByIndex(op_input_index);
170 const auto output = main_runtime_graph->getCircleTensorByIndex(op_output_index);
171
172 if (main_runtime_graph->getDataByTensor(input))
173 {
174 main_runtime_graph->makeInplaceOperation(input, output);
175 }
176 else
177 {
178 main_runtime_graph->setDataToTensor(output, operation_inputs_data[i]);
179 }
180 }
181 }
182}
183
184} // namespace luci_interpreter
bool is_inplace_op(const circle::Operator *op)
RuntimeModule * getRuntimeModule()
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
void configure_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition While.cpp:26
void execute_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Definition While.cpp:61
int32_t size[5]
Definition Slice.cpp:35