ONE - On-device Neural Engine
Loading...
Searching...
No Matches
GraphBuilderRegistry.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __GRAPH_BUILDER_REGISTRY_H__
18#define __GRAPH_BUILDER_REGISTRY_H__
19
20#include "Op/Conv2D.h"
21#include "Op/DepthwiseConv2D.h"
22#include "Op/AveragePool2D.h"
23#include "Op/MaxPool2D.h"
24#include "Op/Concatenation.h"
25#include "Op/ReLU.h"
26#include "Op/ReLU6.h"
27#include "Op/Reshape.h"
28#include "Op/Sub.h"
29#include "Op/Div.h"
30
31#include <schema_generated.h>
32
33#include <memory>
34#include <map>
35
36using std::make_unique;
37
38namespace tflimport
39{
40
45{
46public:
51 const GraphBuilder *lookup(tflite::BuiltinOperator op) const
52 {
53 if (_builder_map.find(op) == _builder_map.end())
54 return nullptr;
55
56 return _builder_map.at(op).get();
57 }
58
60 {
  // The registry is a function-local static: it is constructed on the first call and the
  // same instance is returned on every later call. Since C++11 the initialization of such
  // a static is guaranteed to be thread-safe.
61 static GraphBuilderRegistry me;
62 return me;
63 }
64
65private:
67 {
  // Populates _builder_map with one GraphBuilder per supported operator. Only the
  // operators registered below can be found through lookup(), which returns nullptr
  // for any other op.
68 // add a GraphBuilder for each supported tflite operation.
69 _builder_map[tflite::BuiltinOperator_CONV_2D] = make_unique<Conv2DGraphBuilder>();
70 _builder_map[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] =
71 make_unique<DepthwiseConv2DGraphBuilder>();
72 _builder_map[tflite::BuiltinOperator_AVERAGE_POOL_2D] = make_unique<AvgPool2DGraphBuilder>();
73 _builder_map[tflite::BuiltinOperator_MAX_POOL_2D] = make_unique<MaxPool2DGraphBuilder>();
74 _builder_map[tflite::BuiltinOperator_CONCATENATION] = make_unique<ConcatenationGraphBuilder>();
75 _builder_map[tflite::BuiltinOperator_RELU] = make_unique<ReLUGraphBuilder>();
76 _builder_map[tflite::BuiltinOperator_RELU6] = make_unique<ReLU6GraphBuilder>();
77 _builder_map[tflite::BuiltinOperator_RESHAPE] = make_unique<ReshapeGraphBuilder>();
78 _builder_map[tflite::BuiltinOperator_SUB] = make_unique<SubGraphBuilder>();
79 _builder_map[tflite::BuiltinOperator_DIV] = make_unique<DivGraphBuilder>();
80 }
81
82private:
  // Owns one GraphBuilder per registered operator (held by std::unique_ptr). lookup()
  // hands out non-owning raw pointers obtained from these unique_ptrs via get().
83 std::map<tflite::BuiltinOperator, std::unique_ptr<GraphBuilder>> _builder_map;
84};
85
86} // namespace tflimport
87
88#endif // __GRAPH_BUILDER_REGISTRY_H__
Parent class of tflite operation graph builders (e.g., Conv2DGraphBuilder)
Class that returns the graph builder for a given tflite::BuiltinOperator.
const GraphBuilder * lookup(tflite::BuiltinOperator op) const
Returns the registered GraphBuilder pointer for the given BuiltinOperator, or nullptr if none is registered.
static GraphBuilderRegistry & get()