ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
While.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "OMStatus.h"
20#include "core/OMUtils.h"
21#include "core/OMDataType.h"
23
24using namespace onert_micro;
25using namespace onert_micro::execute;
26
// NOTE: doesn't currently support dynamic shapes
28namespace onert_micro
29{
30namespace execute
31{
32
34{
35 core::OMRuntimeModule &runtime_module = execute_args.runtime_module;
36 core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
37 core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
38 uint16_t op_index = execute_args.kernel_index;
39
40 OMRuntimeKernel runtime_kernel;
41 runtime_kernel.readKernel(op_index, runtime_context);
42 auto options = runtime_kernel.first_operator->builtin_options_as_WhileOptions();
43
44 // Obtain conditional and body runtime subgraphs
45 const auto body_subgraph_index = options->body_subgraph_index();
46 const auto cond_subgraph_index = options->cond_subgraph_index();
47 core::OMRuntimeGraph *cond_runtime_graph = nullptr;
48 core::OMRuntimeGraph *body_runtime_graph = nullptr;
49 runtime_module.getRuntimeGraphAt(cond_subgraph_index, &cond_runtime_graph);
50 runtime_module.getRuntimeGraphAt(body_subgraph_index, &body_runtime_graph);
51
52 core::OMRuntimeContext &cond_runtime_context = cond_runtime_graph->getRuntimeContext();
53 core::OMRuntimeStorage &cond_runtime_storage = cond_runtime_graph->getRuntimeStorage();
54 core::memory::OMRuntimeAllocator &cond_runtime_allocator =
55 cond_runtime_graph->getRuntimeAllocator();
56
57 core::OMRuntimeContext &body_runtime_context = body_runtime_graph->getRuntimeContext();
58 core::OMRuntimeStorage &body_runtime_storage = body_runtime_graph->getRuntimeStorage();
59 core::memory::OMRuntimeAllocator &body_runtime_allocator =
60 body_runtime_graph->getRuntimeAllocator();
61
62 OMStatus status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
63 if (status != Ok)
64 return status;
65
66 // Copy input data to the output
67 assert(runtime_kernel.inputs_num == runtime_kernel.outputs_num);
68 for (uint32_t i = 0; i < runtime_kernel.inputs_num; ++i)
69 {
70 const auto cur_input_tensor = runtime_kernel.inputs[i];
71 const auto input_data_size = sizeof(core::OMDataType(cur_input_tensor->type())) *
72 core::OMRuntimeShape(cur_input_tensor).flatSize();
73 std::memcpy(runtime_kernel.outputs_data[i], runtime_kernel.inputs_data[i], input_data_size);
74 }
75
76 do
77 {
78 // Handle conditional graph
79 {
80 // Allocate cond graph inputs
81 cond_runtime_graph->allocateGraphInputs();
82 auto cond_graphs_inputs = cond_runtime_graph->getNumberOfInputs();
83 for (uint32_t i = 0; i < cond_graphs_inputs; ++i)
84 {
85 auto *cur_cond_input_data =
86 reinterpret_cast<uint8_t *>(cond_runtime_graph->getInputDataAt(i));
87 uint8_t *cur_main_input_data = runtime_kernel.outputs_data[i];
88 assert(cur_main_input_data != nullptr);
89 assert(cur_cond_input_data != nullptr);
90 const auto cur_input_tensor = runtime_kernel.inputs[i];
91 const auto input_data_size = sizeof(core::OMDataType(cur_input_tensor->type())) *
92 core::OMRuntimeShape(cur_input_tensor).flatSize();
93 std::memcpy(cur_cond_input_data, cur_main_input_data, input_data_size);
94 }
95 // Run cond graph
96 execute::OMExecuteArgs cond_execute_args = {cond_runtime_storage, cond_runtime_context, 0,
97 runtime_module};
98 status = execute::OMKernelExecute::runForward(cond_execute_args, cond_runtime_allocator);
99 if (status != Ok)
100 return status;
101
102 // Check cond graph result
103 bool cond_result_value = reinterpret_cast<bool *>(cond_runtime_graph->getOutputDataAt(0))[0];
104 // Reset cond graph values
105 cond_runtime_graph->reset();
106 // If false - then finish while loop
107 if (cond_result_value == false)
108 break;
109 }
110
111 // Handle body graph
112 {
113 // Allocate body graph inputs
114 body_runtime_graph->allocateGraphInputs();
115 // Copy data
116 auto body_graphs_inputs = body_runtime_graph->getNumberOfInputs();
117 for (uint32_t i = 0; i < body_graphs_inputs; ++i)
118 {
119 auto *cur_body_input_data =
120 reinterpret_cast<uint8_t *>(body_runtime_graph->getInputDataAt(i));
121 uint8_t *cur_main_input_data = runtime_kernel.outputs_data[i];
122 assert(cur_main_input_data != nullptr);
123 assert(cur_body_input_data != nullptr);
124 const auto cur_input_tensor = runtime_kernel.inputs[i];
125 const auto input_data_size = sizeof(core::OMDataType(cur_input_tensor->type())) *
126 core::OMRuntimeShape(cur_input_tensor).flatSize();
127 std::memcpy(cur_body_input_data, cur_main_input_data, input_data_size);
128 }
129 // Run body graph
130 execute::OMExecuteArgs body_execute_args = {body_runtime_storage, body_runtime_context, 0,
131 runtime_module};
132 status = execute::OMKernelExecute::runForward(body_execute_args, body_runtime_allocator);
133 if (status != Ok)
134 return status;
135
136 // Copy body calculated data to the main output
137 for (uint32_t i = 0; i < runtime_kernel.inputs_num; ++i)
138 {
139 auto cur_calculated_data = body_runtime_graph->getOutputDataAt(i);
140 const auto cur_tensor = runtime_kernel.outputs[i];
141 const auto data_size = sizeof(core::OMDataType(cur_tensor->type())) *
142 core::OMRuntimeShape(cur_tensor).flatSize();
143 std::memcpy(runtime_kernel.outputs_data[i], cur_calculated_data, data_size);
144 }
145
146 body_runtime_graph->reset();
147 }
148 } while (true);
149
150 return status;
151}
152
153} // namespace execute
154} // namespace onert_micro
memory::OMRuntimeAllocator & getRuntimeAllocator()
OMRuntimeContext & getRuntimeContext()
void * getInputDataAt(uint32_t position)
OMRuntimeStorage & getRuntimeStorage()
void * getOutputDataAt(uint32_t position)
OMStatus getRuntimeGraphAt(uint32_t pos, OMRuntimeGraph **runtime_graph)
uint8_t * outputs_data[maxOutputSize]
const circle::Operator * first_operator
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
OMDataType
"scalar" value type
Definition OMDataType.h:35
OMStatus execute_kernel_CircleWhile(const OMExecuteArgs &execute_args)
Definition While.cpp:33
core::OMRuntimeContext & runtime_context
core::OMRuntimeStorage & runtime_storage
core::OMRuntimeModule & runtime_module
static OMStatus runForward(OMExecuteArgs &, core::memory::OMRuntimeAllocator &allocator)