17#ifndef LUCI_INTERPRETER_KERNEL_KERNELBUILDER_H
18#define LUCI_INTERPRETER_KERNEL_KERNELBUILDER_H
20#include "core/RuntimeModule.h"
25#include <unordered_map>
29#define REGISTER_KERNEL(builtin_operator, name) BuiltinOperator_##builtin_operator,
34#include "GeneratedKernelsToBuild.lst"
36#include "KernelsToBuild.lst"
47#define REGISTER_KERNEL(builtin_operator, name) \
48 case circle::BuiltinOperator_##builtin_operator: \
49 return BuilderID::BuiltinOperator_##builtin_operator;
52#include "GeneratedKernelsToBuild.lst"
54#include "KernelsToBuild.lst"
59 assert(
false &&
"Unsupported operation");
70#define REGISTER_KERNEL(builtin_operator, name) \
71 register_kernel_configure(BuilderID::BuiltinOperator_##builtin_operator, \
72 configure_kernel_Circle##name);
75#include "GeneratedKernelsToBuild.lst"
77#include "KernelsToBuild.lst"
83 void configure_kernel(
const circle::Operator *cur_op, circle::BuiltinOperator opcode,
87 constexpr KernelConfigureFunc *get_kernel_configure_func(circle::BuiltinOperator opcode)
const
91 return _operator_configure[builder_id_opcode];
97 _operator_configure[size_t(
id)] = func;
111#define REGISTER_KERNEL(builtin_operator, name) \
112 register_kernel_execute(BuilderID::BuiltinOperator_##builtin_operator, \
113 execute_kernel_Circle##name);
115#if USE_GENERATED_LIST
116#include "GeneratedKernelsToBuild.lst"
118#include "KernelsToBuild.lst"
121#undef REGISTER_KERNEL
124 void execute_kernel(
const circle::Operator *cur_op, circle::BuiltinOperator opcode,
128 constexpr KernelExecuteFunc *get_kernel_execute_func(circle::BuiltinOperator opcode)
const
132 return _operator_execute[tmp];
138 _operator_execute[size_t(
id)] = func;
145#ifdef ENABLE_TRAINING
149class KernelTrainRegistry
152 using KernelTrainFunc = Status(
const circle::Operator *, CircleReader *,
153 GradientCalculationStorage *,
const TrainingSettings &,
154 TrainableWeightStorage *,
bool);
156 constexpr KernelTrainRegistry() : _operator_train()
158#define REGISTER_TRAIN_KERNEL(builtin_operator, name) \
159 register_kernel_train(BuilderID::BuiltinOperator_##builtin_operator, train_kernel_Circle##name);
161#if USE_GENERATED_LIST
162#include "GeneratedKernelsToBuild.lst"
164#include "KernelsToTrain.lst"
167#undef REGISTER_TRAIN_KERNEL
170 Status train_kernel(
const circle::Operator *cur_op, circle::BuiltinOperator opcode,
171 CircleReader *reader,
172 GradientCalculationStorage *gradient_calculation_storage,
173 const TrainingSettings &settings, TrainableWeightStorage *weight_storage,
174 bool is_compute_gradient)
const;
177 constexpr KernelTrainFunc *get_kernel_train_func(circle::BuiltinOperator opcode)
const
180 assert(tmp <
size_t(BuilderID::Size));
181 return _operator_train[tmp];
184 constexpr void register_kernel_train(BuilderID
id, KernelTrainFunc *func)
186 assert(
size_t(
id) <
size_t(BuilderID::Size));
187 _operator_train[size_t(
id)] = func;
194constexpr KernelTrainRegistry kernel_train;
void(const circle::Operator *, BaseRuntimeGraph *) KernelExecuteFunc
void execute_kernel(const circle::Operator *cur_op, circle::BuiltinOperator opcode, BaseRuntimeGraph *runtime_graph) const
constexpr KernelExecuteRegistry()
constexpr BuilderID get_builder_id(circle::BuiltinOperator opcode)
constexpr KernelConfigureRegistry kernel_configure
constexpr KernelExecuteRegistry kernel_executor
OMStatus(const OMBackpropExecuteArgs &) KernelTrainFunc