ONE - On-device Neural Engine
Loading...
Searching...
No Matches
KernelBuilder.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
19
22
#include <cassert>
#include <stdexcept>
#include <string>
24
25namespace
26{
27
28// TODO Extract this helper function
29const std::string toString(luci::CircleOpcode opcode)
30{
31 static const char *names[] = {
32#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) #CIRCLE_CLASS,
33#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) #CIRCLE_CLASS,
34#include <luci/IR/CircleNodes.lst>
35#undef CIRCLE_NODE
36#undef CIRCLE_VNODE
37 };
38
39 auto const node_name = names[static_cast<int>(opcode)];
40
41 assert(std::string(node_name).substr(0, 6) == "Circle"); // FIX_ME_UNLESS
42
43 // Return substring of class name ("Circle" is sliced out)
44 // Ex: Return "Conv2D" for "CircleConv2D" node
45 return std::string(node_name).substr(6);
46}
47
48} // namespace
49
50namespace luci_interpreter
51{
52
53#define CIRCLE_NODE(OPCODE, CLASS) CLASS,
54#define CIRCLE_VNODE(OPCODE, CLASS) CLASS,
55
56// This enum is auxiliary.
57// It is duplicate of luci::CircleOpcode but initialized with CLASS instead of OPCODE,
58// because list of target operators is in format of CLASS names
59enum class BuilderId
60{
61#include <luci/IR/CircleNodes.lst>
62 Size // casts to count of values in BuilderId enum
63};
64
65#undef CIRCLE_VNODE
66#undef CIRCLE_NODE
67
75{
76public:
77 using KernelBuilderFunc = std::unique_ptr<Kernel>(const luci::CircleNode *,
79
80 KernelBuilderRegistry() : _operator_builders(size_t(BuilderId::Size), nullptr)
81 {
82#define REGISTER_KERNEL(name) \
83 register_kernel_builder(BuilderId::Circle##name, build_kernel_Circle##name);
84
85#include "KernelsToBuild.lst"
86
87#undef REGISTER_KERNEL
88 }
89
91 {
92 return _operator_builders.at(size_t(opcode));
93 }
94
95private:
96 std::vector<KernelBuilderFunc *> _operator_builders;
97
98 void register_kernel_builder(BuilderId id, KernelBuilderFunc *func)
99 {
100 // Using BuilderId is a duplicate of luci::CirclreOpcode,
101 // size_t(id) is equal to size_t(corresponding operation opcode).
102 assert(size_t(id) < _operator_builders.size());
103 _operator_builders[size_t(id)] = func;
104 }
105};
106
108 const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
109 const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
110 : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
111{
112 _builder_registry = std::make_unique<KernelBuilderRegistry>();
113}
114
116{
117 // Need to define in this CPP to hide KernelBuilderRegistry internals.
118 // This destructor deletes _builder_registry
119}
120
121std::unique_ptr<Kernel> KernelBuilder::build(const luci::CircleNode *node)
122{
123 auto specific_builder = _builder_registry->get_kernel_builder_func(node->opcode());
124 if (specific_builder != nullptr)
125 return specific_builder(node, *this);
126
127 std::string msg = "Unsupported operator: ";
128 msg += toString(node->opcode()) + " in " + std::string(node->name());
129 throw std::invalid_argument(msg.c_str());
130}
131
132} // namespace luci_interpreter
KernelBuilder(const std::unordered_map< const loco::Graph *, RuntimeGraph * > &graph_to_runtime_graph, const std::unordered_map< const loco::Node *, Tensor * > &node_to_tensor)
std::unique_ptr< Kernel > build(const luci::CircleNode *node)
std::unique_ptr< Kernel >(const luci::CircleNode *, KernelBuilderHelper &) KernelBuilderFunc
KernelBuilderFunc * get_kernel_builder_func(luci::CircleOpcode opcode) const
const std::string toString(luci::CircleOpcode opcode)
NodeName name(void) const
virtual CircleOpcode opcode(void) const =0