ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Pad.cpp
Go to the documentation of this file.
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
18
#include <moco/IR/Nodes/TFPad.h>

#include <loco.h>
#include <plier/tf/Convert.h>

#include <cassert>
#include <memory>
#include <utility>
#include <vector>
25
26namespace
27{
28
29using namespace moco;
30
34class TFPadGraphUpdate final : public GraphUpdate
35{
36public:
37 TFPadGraphUpdate(TFPad *node, std::vector<TensorName> names) : _node(node), _names(names) {}
38
39 void input(const SymbolTable *) const override;
40
41private:
42 TFPad *_node;
43 std::vector<TensorName> _names;
44};
45
46void TFPadGraphUpdate::input(const SymbolTable *table) const
47{
48 assert(_names.size() == 2);
49
50 _node->input(table->node(_names[0]));
51 _node->paddings(table->node(_names[1]));
52}
53
54} // namespace
55
56namespace moco
57{
58
59bool PadGraphBuilder::validate(const tensorflow::NodeDef &node) const
60{
61 if (node.input_size() != 2)
62 return false;
63
64 return plier::tf::has_attrs(node, {"T", "Tpaddings"});
65}
66
67void PadGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
68{
69 assert(context != nullptr);
70
71 loco::Graph *graph = context->graph();
72 SymbolTable *tensor_names = context->tensor_names();
73 UpdateQueue *updates = context->updates();
74
75 // creating TF dialect Pad node
76 auto tf_pad = graph->nodes()->create<TFPad>();
77 tf_pad->name(node.name());
78
79 // register string-name to node
80 TensorName output_name(node.name(), 0);
81 tensor_names->enroll(output_name, tf_pad);
82
83 std::vector<TensorName> add_input_names;
84 add_input_names.push_back(TensorName(node.input(0))); // input
85 add_input_names.push_back(TensorName(node.input(1))); // paddings
86
87 // Queue node input update
88 auto tf_pad_update = std::make_unique<TFPadGraphUpdate>(tf_pad, add_input_names);
89 updates->enroll(std::move(tf_pad_update));
90}
91
92} // namespace moco
A neural network graph.
Definition Graph.h:161
Class to store context to build loco graph IR from TensorFlow.
Interface to connect the graph.
virtual void input(const SymbolTable *) const =0
Do the graph input connections using the SymbolTable.
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override
Definition Pad.cpp:67
bool validate(const tensorflow::NodeDef &) const override
Definition Pad.cpp:59
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
loco::Node * node(const TensorName &tensor_name) const
Queries the node enrolled (registered) under the given name and returns it if found. Will throw runtime_error if not found...
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
Definition Log.h:23
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35
NodeName name(void) const
Definition TFNodeDecl.h:50