ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Sub.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18
19#include <moco/IR/Nodes/TFSub.h>
20
21#include <loco.h>
22
23#include <memory>
24
25namespace
26{
27
28using namespace moco;
29
33class TFSubGraphUpdate final : public GraphUpdate
34{
35public:
36 TFSubGraphUpdate(TFSub *node, std::vector<TensorName> names) : _node(node), _names(names) {}
37
38 void input(const SymbolTable *) const override;
39
40private:
41 TFSub *_node;
42 std::vector<TensorName> _names;
43};
44
45void TFSubGraphUpdate::input(const SymbolTable *tensor_names) const
46{
47 assert(_names.size() == 2);
48
49 _node->x(tensor_names->node(_names[0]));
50 _node->y(tensor_names->node(_names[1]));
51}
52
53} // namespace
54
55namespace moco
56{
57
58bool SubGraphBuilder::validate(const tensorflow::NodeDef &node) const
59{
60 return node.input_size() == 2;
61}
62
63void SubGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
64{
65 assert(context != nullptr);
66
67 loco::Graph *graph = context->graph();
68 SymbolTable *tensor_names = context->tensor_names();
69 UpdateQueue *updates = context->updates();
70
71 // creating TF dialect Sub node
72 auto tf_sub = graph->nodes()->create<TFSub>();
73 tf_sub->name(node.name());
74
75 TensorName output_name(node.name(), 0);
76 tensor_names->enroll(output_name, tf_sub);
77
78 std::vector<TensorName> sub_input_names;
79 sub_input_names.push_back(TensorName(node.input(0))); // x
80 sub_input_names.push_back(TensorName(node.input(1))); // y
81
82 auto tf_sub_update = std::make_unique<TFSubGraphUpdate>(tf_sub, sub_input_names);
83 updates->enroll(std::move(tf_sub_update));
84}
85
86} // namespace moco
A neural network graph.
Definition Graph.h:161
Class to store context to build loco graph IR from TensorFlow.
Interface to connect the graph.
virtual void input(const SymbolTable *) const =0
Do the graph input connections using the SymbolTable.
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override
Definition Sub.cpp:63
bool validate(const tensorflow::NodeDef &) const override
Definition Sub.cpp:58
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
loco::Node * node(const TensorName &tensor_name) const
Queries the node enrolled (registered) under the given name and returns it if found. Will throw runtime_error if not found...
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
Definition Log.h:23
NodeName name(void) const
Definition TFNodeDecl.h:50