ONE - On-device Neural Engine
Loading...
Searching...
No Matches
KernelBuilder.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LUCI_INTERPRETER_KERNEL_KERNELBUILDER_H
18#define LUCI_INTERPRETER_KERNEL_KERNELBUILDER_H
19
20#include "core/RuntimeModule.h"
22#include "Builders.h"
23
24#include <memory>
25#include <unordered_map>
26
27namespace luci_interpreter
28{
29#define REGISTER_KERNEL(builtin_operator, name) BuiltinOperator_##builtin_operator,
30
31enum class BuilderID
32{
33#if USE_GENERATED_LIST
34#include "GeneratedKernelsToBuild.lst"
35#else
36#include "KernelsToBuild.lst"
37#endif
38 Size // casts to count of values in BuilderId enum
39};
40
41#undef REGISTER_KERNEL
42
43constexpr BuilderID get_builder_id(circle::BuiltinOperator opcode)
44{
45 switch (opcode)
46 {
47#define REGISTER_KERNEL(builtin_operator, name) \
48 case circle::BuiltinOperator_##builtin_operator: \
49 return BuilderID::BuiltinOperator_##builtin_operator;
50
51#if USE_GENERATED_LIST
52#include "GeneratedKernelsToBuild.lst"
53#else
54#include "KernelsToBuild.lst"
55#endif
56
57#undef REGISTER_KERNEL
58 default:
59 assert(false && "Unsupported operation");
60 }
61}
62
64{
65public:
66 using KernelConfigureFunc = void(const circle::Operator *, BaseRuntimeGraph *);
67
68 constexpr KernelConfigureRegistry() : _operator_configure()
69 {
70#define REGISTER_KERNEL(builtin_operator, name) \
71 register_kernel_configure(BuilderID::BuiltinOperator_##builtin_operator, \
72 configure_kernel_Circle##name);
73
74#if USE_GENERATED_LIST
75#include "GeneratedKernelsToBuild.lst"
76#else
77#include "KernelsToBuild.lst"
78#endif
79
80#undef REGISTER_KERNEL
81 }
82
83 void configure_kernel(const circle::Operator *cur_op, circle::BuiltinOperator opcode,
84 BaseRuntimeGraph *runtime_graph) const;
85
86private:
87 constexpr KernelConfigureFunc *get_kernel_configure_func(circle::BuiltinOperator opcode) const
88 {
89 const auto builder_id_opcode = size_t(get_builder_id(opcode));
90 assert(builder_id_opcode < size_t(BuilderID::Size));
91 return _operator_configure[builder_id_opcode];
92 }
93
94 constexpr void register_kernel_configure(BuilderID id, KernelConfigureFunc *func)
95 {
96 assert(size_t(id) < size_t(BuilderID::Size));
97 _operator_configure[size_t(id)] = func;
98 }
99
100private:
101 KernelConfigureFunc *_operator_configure[size_t(BuilderID::Size)];
102};
103
105{
106public:
107 using KernelExecuteFunc = void(const circle::Operator *, BaseRuntimeGraph *);
108
109 constexpr KernelExecuteRegistry() : _operator_execute()
110 {
111#define REGISTER_KERNEL(builtin_operator, name) \
112 register_kernel_execute(BuilderID::BuiltinOperator_##builtin_operator, \
113 execute_kernel_Circle##name);
114
115#if USE_GENERATED_LIST
116#include "GeneratedKernelsToBuild.lst"
117#else
118#include "KernelsToBuild.lst"
119#endif
120
121#undef REGISTER_KERNEL
122 }
123
124 void execute_kernel(const circle::Operator *cur_op, circle::BuiltinOperator opcode,
125 BaseRuntimeGraph *runtime_graph) const;
126
127private:
128 constexpr KernelExecuteFunc *get_kernel_execute_func(circle::BuiltinOperator opcode) const
129 {
130 const auto tmp = size_t(get_builder_id(opcode));
131 assert(tmp < size_t(BuilderID::Size));
132 return _operator_execute[tmp];
133 }
134
135 constexpr void register_kernel_execute(BuilderID id, KernelExecuteFunc *func)
136 {
137 assert(size_t(id) < size_t(BuilderID::Size));
138 _operator_execute[size_t(id)] = func;
139 }
140
141private:
142 KernelExecuteFunc *_operator_execute[size_t(BuilderID::Size)];
143};
144
145#ifdef ENABLE_TRAINING
146
147namespace training
148{
149class KernelTrainRegistry
150{
151public:
152 using KernelTrainFunc = Status(const circle::Operator *, CircleReader *,
153 GradientCalculationStorage *, const TrainingSettings &,
154 TrainableWeightStorage *, bool);
155
156 constexpr KernelTrainRegistry() : _operator_train()
157 {
158#define REGISTER_TRAIN_KERNEL(builtin_operator, name) \
159 register_kernel_train(BuilderID::BuiltinOperator_##builtin_operator, train_kernel_Circle##name);
160
161#if USE_GENERATED_LIST
162#include "GeneratedKernelsToBuild.lst"
163#else
164#include "KernelsToTrain.lst"
165#endif
166
167#undef REGISTER_TRAIN_KERNEL
168 }
169
170 Status train_kernel(const circle::Operator *cur_op, circle::BuiltinOperator opcode,
171 CircleReader *reader,
172 GradientCalculationStorage *gradient_calculation_storage,
173 const TrainingSettings &settings, TrainableWeightStorage *weight_storage,
174 bool is_compute_gradient) const;
175
176private:
177 constexpr KernelTrainFunc *get_kernel_train_func(circle::BuiltinOperator opcode) const
178 {
179 const auto tmp = size_t(get_builder_id(opcode));
180 assert(tmp < size_t(BuilderID::Size));
181 return _operator_train[tmp];
182 }
183
184 constexpr void register_kernel_train(BuilderID id, KernelTrainFunc *func)
185 {
186 assert(size_t(id) < size_t(BuilderID::Size));
187 _operator_train[size_t(id)] = func;
188 }
189
190private:
191 KernelTrainFunc *_operator_train[size_t(BuilderID::Size)];
192};
193
194constexpr KernelTrainRegistry kernel_train;
195} // namespace training
196#endif // ENABLE_TRAINING
197
198// Global constexpr kernel configure and kernel executor
201
202} // namespace luci_interpreter
203
204#endif // LUCI_INTERPRETER_KERNEL_KERNELBUILDER_H
void(const circle::Operator *, BaseRuntimeGraph *) KernelConfigureFunc
void configure_kernel(const circle::Operator *cur_op, circle::BuiltinOperator opcode, BaseRuntimeGraph *runtime_graph) const
void(const circle::Operator *, BaseRuntimeGraph *) KernelExecuteFunc
void execute_kernel(const circle::Operator *cur_op, circle::BuiltinOperator opcode, BaseRuntimeGraph *runtime_graph) const
constexpr BuilderID get_builder_id(circle::BuiltinOperator opcode)
constexpr KernelConfigureRegistry kernel_configure
constexpr KernelExecuteRegistry kernel_executor
OMStatus(const OMBackpropExecuteArgs &) KernelTrainFunc