ONE - On-device Neural Engine
Loading...
Searching...
No Matches
While.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18#include "OMStatus.h"
20#include "core/OMUtils.h"
21#include "core/OMDataType.h"
23
24using namespace onert_micro;
25using namespace onert_micro::execute;
26
27// NOTE: doesnt currently support dynamic shapes
28OMStatus onert_micro::execute::execute_kernel_CircleWhile(const OMExecuteArgs &execute_args)
29{
30 core::OMRuntimeModule &runtime_module = execute_args.runtime_module;
31 core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
32 core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
33 uint16_t op_index = execute_args.kernel_index;
34
35 OMRuntimeKernel runtime_kernel;
36 runtime_kernel.readKernel(op_index, runtime_context);
37 auto options = runtime_kernel.first_operator->builtin_options_as_WhileOptions();
38
39 // Obtain conditional and body runtime subgraphs
40 const auto body_subgraph_index = options->body_subgraph_index();
41 const auto cond_subgraph_index = options->cond_subgraph_index();
42 core::OMRuntimeGraph *cond_runtime_graph = nullptr;
43 core::OMRuntimeGraph *body_runtime_graph = nullptr;
44 runtime_module.getRuntimeGraphAt(cond_subgraph_index, &cond_runtime_graph);
45 runtime_module.getRuntimeGraphAt(body_subgraph_index, &body_runtime_graph);
46
47 core::OMRuntimeContext &cond_runtime_context = cond_runtime_graph->getRuntimeContext();
48 core::OMRuntimeStorage &cond_runtime_storage = cond_runtime_graph->getRuntimeStorage();
49 core::memory::OMRuntimeAllocator &cond_runtime_allocator =
50 cond_runtime_graph->getRuntimeAllocator();
51
52 core::OMRuntimeContext &body_runtime_context = body_runtime_graph->getRuntimeContext();
53 core::OMRuntimeStorage &body_runtime_storage = body_runtime_graph->getRuntimeStorage();
54 core::memory::OMRuntimeAllocator &body_runtime_allocator =
55 body_runtime_graph->getRuntimeAllocator();
56
57 OMStatus status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
58 if (status != Ok)
59 return status;
60
61 // Copy input data to the output
62 assert(runtime_kernel.inputs_num == runtime_kernel.outputs_num);
63 for (uint32_t i = 0; i < runtime_kernel.inputs_num; ++i)
64 {
65 const auto cur_input_tensor = runtime_kernel.inputs[i];
66 const auto input_data_size = sizeof(core::OMDataType(cur_input_tensor->type())) *
67 core::OMRuntimeShape(cur_input_tensor).flatSize();
68 std::memcpy(runtime_kernel.outputs_data[i], runtime_kernel.inputs_data[i], input_data_size);
69 }
70
71 do
72 {
73 // Handle conditional graph
74 {
75 // Allocate cond graph inputs
76 cond_runtime_graph->allocateGraphInputs();
77 auto cond_graphs_inputs = cond_runtime_graph->getNumberOfInputs();
78 for (uint32_t i = 0; i < cond_graphs_inputs; ++i)
79 {
80 auto *cur_cond_input_data =
81 reinterpret_cast<uint8_t *>(cond_runtime_graph->getInputDataAt(i));
82 uint8_t *cur_main_input_data = runtime_kernel.outputs_data[i];
83 assert(cur_main_input_data != nullptr);
84 assert(cur_cond_input_data != nullptr);
85 const auto cur_input_tensor = runtime_kernel.inputs[i];
86 const auto input_data_size = sizeof(core::OMDataType(cur_input_tensor->type())) *
87 core::OMRuntimeShape(cur_input_tensor).flatSize();
88 std::memcpy(cur_cond_input_data, cur_main_input_data, input_data_size);
89 }
90 // Run cond graph
91 execute::OMExecuteArgs cond_execute_args = {cond_runtime_storage, cond_runtime_context, 0,
92 runtime_module};
93 status = execute::OMKernelExecute::runForward(cond_execute_args, cond_runtime_allocator);
94 if (status != Ok)
95 return status;
96
97 // Check cond graph result
98 bool cond_result_value = reinterpret_cast<bool *>(cond_runtime_graph->getOutputDataAt(0))[0];
99 // Reset cond graph values
100 cond_runtime_graph->reset();
101 // If false - then finish while loop
102 if (cond_result_value == false)
103 break;
104 }
105
106 // Handle body graph
107 {
108 // Allocate body graph inputs
109 body_runtime_graph->allocateGraphInputs();
110 // Copy data
111 auto body_graphs_inputs = body_runtime_graph->getNumberOfInputs();
112 for (uint32_t i = 0; i < body_graphs_inputs; ++i)
113 {
114 auto *cur_body_input_data =
115 reinterpret_cast<uint8_t *>(body_runtime_graph->getInputDataAt(i));
116 uint8_t *cur_main_input_data = runtime_kernel.outputs_data[i];
117 assert(cur_main_input_data != nullptr);
118 assert(cur_body_input_data != nullptr);
119 const auto cur_input_tensor = runtime_kernel.inputs[i];
120 const auto input_data_size = sizeof(core::OMDataType(cur_input_tensor->type())) *
121 core::OMRuntimeShape(cur_input_tensor).flatSize();
122 std::memcpy(cur_body_input_data, cur_main_input_data, input_data_size);
123 }
124 // Run body graph
125 execute::OMExecuteArgs body_execute_args = {body_runtime_storage, body_runtime_context, 0,
126 runtime_module};
127 status = execute::OMKernelExecute::runForward(body_execute_args, body_runtime_allocator);
128 if (status != Ok)
129 return status;
130
131 // Copy body calculated data to the main output
132 for (uint32_t i = 0; i < runtime_kernel.inputs_num; ++i)
133 {
134 auto cur_calculated_data = body_runtime_graph->getOutputDataAt(i);
135 const auto cur_tensor = runtime_kernel.outputs[i];
136 const auto data_size = sizeof(core::OMDataType(cur_tensor->type())) *
137 core::OMRuntimeShape(cur_tensor).flatSize();
138 std::memcpy(runtime_kernel.outputs_data[i], cur_calculated_data, data_size);
139 }
140
141 body_runtime_graph->reset();
142 }
143 } while (true);
144
145 return status;
146}
memory::OMRuntimeAllocator & getRuntimeAllocator()
OMRuntimeContext & getRuntimeContext()
void * getInputDataAt(uint32_t position)
OMRuntimeStorage & getRuntimeStorage()
void * getOutputDataAt(uint32_t position)
OMStatus getRuntimeGraphAt(uint32_t pos, OMRuntimeGraph **runtime_graph)
uint8_t * outputs_data[maxOutputSize]
const circle::Operator * first_operator
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
OMDataType
"scalar" value type
Definition OMDataType.h:35
core::OMRuntimeContext & runtime_context
core::OMRuntimeStorage & runtime_storage
core::OMRuntimeModule & runtime_module
static OMStatus runForward(OMExecuteArgs &, core::memory::OMRuntimeAllocator &allocator)