ONE - On-device Neural Engine
KernelGenerator.cc
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "KernelGenerator.h"

#include "kernel/IfLayer.h"
#include "kernel/PermuteLayer.h"
#include "kernel/WhileLayer.h"

namespace onert
{
namespace backend
{
namespace builtin
{

KernelGenerator::KernelGenerator(const ir::Graph &graph, DynamicTensorManager *dyn_tensor_manager,
                                 const std::shared_ptr<TensorRegistry> &tensor_reg,
                                 const std::shared_ptr<ExternalContext> &external_context)
  : basic::KernelGeneratorBase{graph}, _dyn_tensor_manager{dyn_tensor_manager},
    _tensor_reg{tensor_reg}, _tensor_registries{}, _executors{nullptr}, _model_index{},
    _external_context{external_context}
{
  // DO NOTHING
}

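// Builds the function sequence for the operation at `ind`: a dynamic-tensor context is attached
// first so output shapes can be re-inferred at run time, then the visitor below fills _return_fn
// with the generated kernel, which is appended to the returned sequence.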
std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationIndex ind)
{
  assert(_dyn_tensor_manager);
  assert(_tensor_reg);

  auto ret = std::make_unique<exec::FunctionSequence>();

  // Prepare to handle dynamic tensors later
  auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>();
  {
    dyn_ctx->op = &_graph.operations().at(ind);
    dyn_ctx->dynamic_shape_inferer = std::make_unique<exec::DynamicShapeInferer>(_tensor_reg);
  }
  ret->dynamic_tensor_ctx(dyn_ctx);

  auto &op = _graph.operations().at(ind);
  op.accept(*this);
  assert(_return_fn); // _return_fn must have been generated
  ret->append(std::move(_return_fn));

  return ret;
}

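// Generates the kernel for the If operation. The first input of the node is the boolean
// condition tensor; the remaining inputs and the outputs, together with the then/else subgraph
// indices, are handed to IfLayer, which invokes the appropriate subgraph at run time.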
void KernelGenerator::visit(const ir::operation::If &node)
{
  const auto then_subg_index = node.param().then_subg_index;
  const auto else_subg_index = node.param().else_subg_index;

  std::vector<backend::IPortableTensor *> input_tensors;
  for (const auto &input_index : node.getInputs())
  {
    auto input_tensor = getPortableTensor(input_index);
    input_tensors.emplace_back(input_tensor);
  }

  std::vector<backend::IPortableTensor *> output_tensors;
  for (const auto &output_index : node.getOutputs())
  {
    auto output_tensor = getPortableTensor(output_index);
    output_tensors.emplace_back(output_tensor);
  }

  // IfLayer just takes the Executors handle instead of the then and else executors, to avoid
  // the complexity of creating executors recursively
  const auto cond_tensor = input_tensors.front();
  input_tensors.erase(input_tensors.begin());
  auto fn = std::make_unique<::onert::backend::builtin::kernel::IfLayer>(
    cond_tensor, input_tensors, output_tensors, then_subg_index, else_subg_index, _executors,
    _model_index, _external_context);

  _return_fn = std::move(fn);
}

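// Generates the kernel for the Permute operation, which copies data between tensors that may be
// owned by different backends; tensors are therefore looked up through getTensor(), which
// searches every registered tensor registry.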
void KernelGenerator::visit(const ir::operation::Permute &node)
{
  const auto output_index{node.getOutputs().at(0)};
  const auto input_index{node.getInputs().at(0)};

  // Add PermuteLayer
  std::vector<ITensor *> output_tensors{getTensor(output_index)};
  std::vector<ITensor *> input_tensors{getTensor(input_index)};
  std::vector<ir::PermuteType> permute_types;

  // Layout in graph is always NHWC, so layout is not changed
  for (uint32_t i = 0; i < input_tensors.size(); i++)
    permute_types.emplace_back(ir::PermuteType::COPY);

  auto fn = std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, permute_types,
                                                   _external_context);
  _return_fn = std::move(fn);
}

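// Generates the kernel for the While operation. Like IfLayer, WhileLayer only keeps the
// Executors handle and the cond/body subgraph indices; the dynamic memory manager is also
// passed through for memory allocated while iterating.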
void KernelGenerator::visit(const ir::operation::While &node)
{
  const auto cond_subg_index = node.param().cond_subg_index;
  const auto body_subg_index = node.param().body_subg_index;

  // This op does not support constant inputs because the builtin backend has no TensorBuilder
  std::vector<backend::IPortableTensor *> input_tensors;
  for (const auto &input_index : node.getInputs())
  {
    auto input_tensor = getPortableTensor(input_index);
    input_tensors.emplace_back(input_tensor);
  }

  std::vector<backend::IPortableTensor *> output_tensors;
  for (const auto &output_index : node.getOutputs())
  {
    auto output_tensor = getPortableTensor(output_index);
    output_tensors.emplace_back(output_tensor);
  }

  // WhileLayer just takes the Executors handle instead of the cond and body executors, to avoid
  // the complexity of creating executors recursively
  auto fn = std::make_unique<::onert::backend::builtin::kernel::WhileLayer>(
    input_tensors, output_tensors, cond_subg_index, body_subg_index, _executors, _model_index,
    _dyn_tensor_manager->dynamic_mem_mgr().get(), _external_context);

  _return_fn = std::move(fn);
}

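// getTensor() searches every registered backend's tensor registry (needed by Permute, which
// bridges backends), whereas getPortableTensor() below only consults the builtin backend's own
// registry.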
backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index)
{
  // get Tensor from all tensor registries (for Permute op)
  auto ret = _tensor_registries.getITensor(index);
  assert(ret != nullptr);
  return ret;
}

backend::IPortableTensor *KernelGenerator::getPortableTensor(const ir::OperandIndex &index)
{
  auto ret = _tensor_reg->getPortableTensor(index);
  assert(ret != nullptr);
  return ret;
}

} // namespace builtin
} // namespace backend
} // namespace onert