ONE - On-device Neural Engine
Loading...
Searching...
No Matches
KernelGenerator.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "KernelGenerator.h"
18
19#include "kernel/CallLayer.h"
20#include "kernel/IfLayer.h"
21#include "kernel/PermuteLayer.h"
22#include "kernel/WhileLayer.h"
23
25
27{
28
30 const std::shared_ptr<TensorRegistry> &tensor_reg,
31 const std::shared_ptr<ExternalContext> &external_context)
32 : basic::KernelGeneratorBase{graph}, _dyn_tensor_manager{dyn_tensor_manager},
33 _tensor_reg{tensor_reg}, _tensor_registries{}, _executors{nullptr}, _model_index{},
34 _external_context{external_context}
35{
36 // DO NOTHING
37}
38
39std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationIndex ind)
40{
41 assert(_dyn_tensor_manager);
42 assert(_tensor_reg);
43
44 auto ret = std::make_unique<exec::FunctionSequence>();
45
46 // Prepare to handle dynamic tensors later
47 auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>();
48 {
49 dyn_ctx->op = &_graph.operations().at(ind);
50 dyn_ctx->dynamic_shape_inferer = std::make_unique<exec::DynamicShapeInferer>(_tensor_reg);
51 }
52 ret->dynamic_tensor_ctx(dyn_ctx);
53
54 auto &op = _graph.operations().at(ind);
55 op.accept(*this);
56 assert(_return_fn); // _return_fn must have been generated
57 ret->append(std::move(_return_fn));
58
59 return ret;
60}
61
62void KernelGenerator::visit(const ir::operation::Custom &node)
63{
64 auto fill_op_info = [&](const ir::OperandIndexSequence &opSeq,
65 std::vector<custom::TypeInfo> &types,
66 std::vector<IPortableTensor *> &tensors) {
67 for (const auto &idx : opSeq)
68 {
69 const auto &operand = _graph.operands().at(idx);
70 types.emplace_back(custom::TypeInfo{operand.shape(), operand.typeInfo().type()});
71 auto in_tensor = _tensor_reg->getPortableTensor(idx);
72 tensors.emplace_back(in_tensor);
73 }
74 };
75
76 backend::custom::CustomKernelConfigParams params{};
77
78 fill_op_info(node.getInputs(), params.input_types, params.input_tensors);
79 fill_op_info(node.getOutputs(), params.output_types, params.output_tensors);
80
81 params.userdata = node.userdata().data;
82 params.userdata_size = node.userdata().size;
83
84 auto fn = _custom_kernel_builder->buildKernel(node.id(), std::move(params));
85
86 _return_fn = std::move(fn);
87}
88
89void KernelGenerator::visit(const ir::operation::Call &node)
90{
91 const auto callee_subg_index = node.param().callee_subg_index;
92
93 std::vector<backend::IPortableTensor *> input_tensors;
94 for (const auto &input_index : node.getInputs())
95 {
96 auto input_tensor = getPortableTensor(input_index);
97 input_tensors.emplace_back(input_tensor);
98 }
99
100 std::vector<backend::IPortableTensor *> output_tensors;
101 for (const auto &output_index : node.getOutputs())
102 {
103 auto output_tensor = getPortableTensor(output_index);
104 output_tensors.emplace_back(output_tensor);
105 }
106
107 auto fn = std::make_unique<::onert::backend::builtin::kernel::CallLayer>(
108 input_tensors, output_tensors, callee_subg_index, _executors, _model_index, _external_context);
109
110 _return_fn = std::move(fn);
111}
112
113void KernelGenerator::visit(const ir::operation::If &node)
114{
115 const auto then_subg_index = node.param().then_subg_index;
116 const auto else_subg_index = node.param().else_subg_index;
117
118 std::vector<backend::IPortableTensor *> input_tensors;
119 for (const auto &input_index : node.getInputs())
120 {
121 auto input_tensor = getPortableTensor(input_index);
122 input_tensors.emplace_back(input_tensor);
123 }
124
125 std::vector<backend::IPortableTensor *> output_tensors;
126 for (const auto &output_index : node.getOutputs())
127 {
128 auto output_tensor = getPortableTensor(output_index);
129 output_tensors.emplace_back(output_tensor);
130 }
131
132 // IfLayer just set Executors instead of then and else executor to avoid complexity of
133 // creating executor recusively
134 const auto cond_tensor = input_tensors.front();
135 input_tensors.erase(input_tensors.begin());
136 auto fn = std::make_unique<::onert::backend::builtin::kernel::IfLayer>(
137 cond_tensor, input_tensors, output_tensors, then_subg_index, else_subg_index, _executors,
138 _model_index, _external_context);
139
140 _return_fn = std::move(fn);
141}
142
143void KernelGenerator::visit(const ir::operation::Permute &node)
144{
145 const auto output_index{node.getOutputs().at(0)};
146 const auto input_index{node.getInputs().at(0)};
147
148 // Add PermuteLayer
149 std::vector<ITensor *> output_tensors{getTensor(output_index)};
150 std::vector<ITensor *> input_tensors{getTensor(input_index)};
151 std::vector<ir::PermuteType> permute_types{node.getPermuteType()};
152
153 auto fn = std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, permute_types,
154 _external_context);
155 _return_fn = std::move(fn);
156}
157
158void KernelGenerator::visit(const ir::operation::While &node)
159{
160 const auto cond_subg_index = node.param().cond_subg_index;
161 const auto body_subg_index = node.param().body_subg_index;
162
163 // This op does not support input as a constant, because builtin backend does not have
164 // TensorBuilder
165 std::vector<backend::IPortableTensor *> input_tensors;
166 for (const auto &input_index : node.getInputs())
167 {
168 auto input_tensor = getPortableTensor(input_index);
169 input_tensors.emplace_back(input_tensor);
170 }
171
172 std::vector<backend::IPortableTensor *> output_tensors;
173 for (const auto &output_index : node.getOutputs())
174 {
175 auto output_tensor = getPortableTensor(output_index);
176 output_tensors.emplace_back(output_tensor);
177 }
178
179 // WhileLayer just set Executors instead of cond and body executor to avoid complexity of
180 // creating executor recusively
181 auto fn = std::make_unique<::onert::backend::builtin::kernel::WhileLayer>(
182 input_tensors, output_tensors, cond_subg_index, body_subg_index, _executors, _model_index,
183 _dyn_tensor_manager->dynamic_mem_mgr().get(), _external_context);
184
185 _return_fn = std::move(fn);
186}
187
188backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index)
189{
190 // get Tensor from all tensor registries (for Permute op)
191 auto ret = _tensor_registries.getITensor(index);
192 assert(ret != nullptr);
193 return ret;
194}
195
196backend::IPortableTensor *KernelGenerator::getPortableTensor(const ir::OperandIndex &index)
197{
198 auto ret = _tensor_reg->getPortableTensor(index);
199 assert(ret != nullptr);
200 return ret;
201}
202
203} // namespace onert::backend::builtin
Class to manage dynamic tensors and their memory.
std::shared_ptr< DynamicMemoryManager > dynamic_mem_mgr()
std::unique_ptr< exec::IFunction > _return_fn
KernelGenerator(const ir::Graph &graph, DynamicTensorManager *dyn_tensor_manager, const std::shared_ptr< TensorRegistry > &tensor_reg, const std::shared_ptr< ExternalContext > &external_context)
std::unique_ptr< exec::FunctionSequence > generate(ir::OperationIndex ind) override
backend::ITensor * getITensor(ir::OperandIndex ind) const
const Operands & operands() const override
Definition Graph.h:103
const Operations & operations() const override
Definition Graph.h:105
const OperandIndexSequence & getOutputs() const override
Definition Operation.h:54
OperandIndexSequence & getInputs()
Definition Operation.h:51
const Userdata & userdata() const
Definition Custom.cc:34
const std::string & id() const
Definition Custom.cc:32
const Object & at(const Index &index) const
Get the object that is associated with the given index.
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:33