ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::Conv2DGraphBuilder Class Reference

#include <Conv2D.h>

Collaboration diagram for moco::Conv2DGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const final
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const final
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

Definition at line 25 of file Conv2D.h.

Member Function Documentation

◆ build()

void moco::Conv2DGraphBuilder::build ( const tensorflow::NodeDef &  node,
GraphBuilderContext *  context 
) const
final virtual

Implements moco::GraphBuilder.

Definition at line 98 of file Conv2D.cpp.

99{
100 assert(context != nullptr);
101
102 loco::Graph *graph = context->graph();
103 SymbolTable *tensor_names = context->tensor_names();
104 UpdateQueue *updates = context->updates();
105
106 // name of loco nodes
107 std::string conv2d_name = node.name();
108
109 auto conv2d = graph->nodes()->create<TFConv2D>();
110 conv2d->name(node.name());
111
112 // read attributes
113 auto data_layout = plier::tf::get_string_attr(node, "data_format");
114 assert(data_layout == "NHWC" || data_layout == "NCHW");
115 conv2d->data_layout(data_layout);
116
117 auto tf_strides = plier::tf::get_list_attr(node, "strides");
118 auto strides = plier::tf::as_int64_list(tf_strides);
119 conv2d->strides(strides);
120
121 auto padding = moco::str_toupper(plier::tf::get_string_attr(node, "padding"));
122 assert(padding == "VALID" || padding == "SAME");
123 conv2d->padding(padding);
124
125 // save the name for graph link updates
126 TensorName output_name(conv2d_name, 0);
127 tensor_names->enroll(output_name, conv2d);
128
129 std::vector<TensorName> input_names;
130 input_names.push_back(TensorName(node.input(0))); // input
131 input_names.push_back(TensorName(node.input(1))); // kernel
132
133 // Record ifm inputs to featureEncode_node
134 auto tfconv2d_update = std::make_unique<TFConv2DGraphUpdate>(conv2d, input_names);
135
136 updates->enroll(std::move(tfconv2d_update));
137}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
std::string str_toupper(std::string s)
Definition Convert.cpp:27
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:111
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:70
NodeName name(void) const
Definition TFNodeDecl.h:50

References plier::tf::as_int64_list(), moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), plier::tf::get_list_attr(), plier::tf::get_string_attr(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::str_toupper(), moco::GraphBuilderContext::tensor_names(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::Conv2DGraphBuilder::validate ( const tensorflow::NodeDef &  node) const
final virtual

Implements moco::GraphBuilder.

Definition at line 69 of file Conv2D.cpp.

70{
71 if (node.input_size() != 2)
72 return false;
73
74 // note: even though "data_format" is not entered when a model is written,
75 // TF seems to generate "data_format" field into a pb file
76 if (!plier::tf::has_attrs(node, {"T", "data_format", "padding", "strides"}))
77 return false;
78
79 auto data_layout = plier::tf::get_string_attr(node, "data_format");
80 if (!(data_layout == "NHWC" || data_layout == "NCHW"))
81 {
82 throw oops::UserExn("Conv2D Unsupported data_format", node.name());
83 }
84
85 // dilation attribute is not fully supported
86 if (plier::tf::has_attr(node, "dilations"))
87 {
88 // TODO Support non-default dilations
89 auto dilation = plier::tf::get_list_attr(node, "dilations").i();
90 if (!std::all_of(dilation.begin(), dilation.end(), [](std::int64_t dil) { return dil == 1; }))
91 return false;
92 }
93 // Else, dilations are automatically set to default [1,1,1,1] which we assumes now
94
95 return true;
96}
Exception to user.
Definition UserExn.h:42
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:30

References plier::tf::get_list_attr(), plier::tf::get_string_attr(), plier::tf::has_attr(), and plier::tf::has_attrs().


The documentation for this class was generated from the following files: