ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Conv2D.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18
20
21#include <moco/Names.h>
22
23#include "Convert.h"
24
25#include <loco.h>
27#include <plier/tf/Convert.h>
28#include <oops/UserExn.h>
29
#include <memory>
#include <cassert>
#include <stdexcept>
#include <algorithm>
#include <utility>
34
35namespace
36{
37using namespace moco;
38
39class TFConv2DGraphUpdate final : public GraphUpdate
40{
41public:
42 TFConv2DGraphUpdate(TFConv2D *node, std::vector<TensorName> names) : _node(node), _names(names) {}
43
44 void input(const SymbolTable *) const override;
45
46private:
47 TFConv2D *_node;
48 std::vector<TensorName> _names;
49};
50
51void TFConv2DGraphUpdate::input(const SymbolTable *node_table) const
52{
53 assert(_names.size() == 2);
54
55 auto input_node = node_table->node(_names[0]);
56 auto filter_node = node_table->node(_names[1]);
57 assert(input_node != nullptr);
58 assert(filter_node != nullptr);
59
60 _node->input(input_node);
61 _node->filter(filter_node);
62}
63
64} // namespace
65
66namespace moco
67{
68
69bool Conv2DGraphBuilder::validate(const tensorflow::NodeDef &node) const
70{
71 if (node.input_size() != 2)
72 return false;
73
74 // note: even though "data_format" is not entered when a model is written,
75 // TF seems to generate "data_format" field into a pb file
76 if (!plier::tf::has_attrs(node, {"T", "data_format", "padding", "strides"}))
77 return false;
78
79 auto data_layout = plier::tf::get_string_attr(node, "data_format");
80 if (!(data_layout == "NHWC" || data_layout == "NCHW"))
81 {
82 throw oops::UserExn("Conv2D Unsupported data_format", node.name());
83 }
84
85 // dilation attribute is not fully supported
86 if (plier::tf::has_attr(node, "dilations"))
87 {
88 // TODO Support non-default dilations
89 auto dilation = plier::tf::get_list_attr(node, "dilations").i();
90 if (!std::all_of(dilation.begin(), dilation.end(), [](std::int64_t dil) { return dil == 1; }))
91 return false;
92 }
93 // Else, dilations are automatically set to default [1,1,1,1] which we assumes now
94
95 return true;
96}
97
98void Conv2DGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
99{
100 assert(context != nullptr);
101
102 loco::Graph *graph = context->graph();
103 SymbolTable *tensor_names = context->tensor_names();
104 UpdateQueue *updates = context->updates();
105
106 // name of loco nodes
107 std::string conv2d_name = node.name();
108
109 auto conv2d = graph->nodes()->create<TFConv2D>();
110 conv2d->name(node.name());
111
112 // read attributes
113 auto data_layout = plier::tf::get_string_attr(node, "data_format");
114 assert(data_layout == "NHWC" || data_layout == "NCHW");
115 conv2d->data_layout(data_layout);
116
117 auto tf_strides = plier::tf::get_list_attr(node, "strides");
118 auto strides = plier::tf::as_int64_list(tf_strides);
119 conv2d->strides(strides);
120
121 auto padding = moco::str_toupper(plier::tf::get_string_attr(node, "padding"));
122 assert(padding == "VALID" || padding == "SAME");
123 conv2d->padding(padding);
124
125 // save the name for graph link updates
126 TensorName output_name(conv2d_name, 0);
127 tensor_names->enroll(output_name, conv2d);
128
129 std::vector<TensorName> input_names;
130 input_names.push_back(TensorName(node.input(0))); // input
131 input_names.push_back(TensorName(node.input(1))); // kernel
132
133 // Record ifm inputs to featureEncode_node
134 auto tfconv2d_update = std::make_unique<TFConv2DGraphUpdate>(conv2d, input_names);
135
136 updates->enroll(std::move(tfconv2d_update));
137}
138
139} // namespace moco
A neural network graph.
Definition Graph.h:161
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final
Definition Conv2D.cpp:98
bool validate(const tensorflow::NodeDef &) const final
Definition Conv2D.cpp:69
Class to store context to build loco graph IR from TensorFlow.
Interface to connect the graph.
virtual void input(const SymbolTable *) const =0
Do the graph input connections using the SymbolTable.
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
loco::Node * node(const TensorName &tensor_name) const
Queries the node enrolled (registered) under the given name and returns it if found. Will throw runtime_error if not found...
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
Exception to user.
Definition UserExn.h:42
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.
Definition Log.h:23
std::string str_toupper(std::string s)
Definition Convert.cpp:27
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:30
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:111
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:70
NodeName name(void) const
Definition TFNodeDecl.h:50