ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::Conv2DBackpropInputGraphBuilder Class Reference (final)

GraphBuilder for Conv2DBackpropInput node. More...

#include <Conv2DBackpropInput.h>

Collaboration diagram for moco::Conv2DBackpropInputGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const override
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const override
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

GraphBuilder for Conv2DBackpropInput node.

Definition at line 28 of file Conv2DBackpropInput.h.

Member Function Documentation

◆ build()

void moco::Conv2DBackpropInputGraphBuilder::build ( const tensorflow::NodeDef &  node,
GraphBuilderContext *  context 
) const
override virtual

Implements moco::GraphBuilder.

Definition at line 99 of file Conv2DBackpropInput.cpp.

101{
102 loco::Graph *graph = context->graph();
103 SymbolTable *tensor_names = context->tensor_names();
104 UpdateQueue *updates = context->updates();
105
106 // name of loco nodes
107 std::string conv2d_backprop_name = node.name();
108
109 auto conv2d_backprop = graph->nodes()->create<TFConv2DBackpropInput>();
110 conv2d_backprop->name(node.name());
111
112 // read attributes
113 auto data_layout = plier::tf::get_string_attr(node, "data_format");
114 assert(data_layout == "NHWC" || data_layout == "NCHW");
115 conv2d_backprop->data_layout(data_layout);
116
117 auto tf_strides = plier::tf::get_list_attr(node, "strides");
118 auto strides = plier::tf::as_int64_list(tf_strides);
119 conv2d_backprop->strides(strides);
120
121 auto padding = moco::str_toupper(plier::tf::get_string_attr(node, "padding"));
122 assert(padding == "VALID" || padding == "SAME");
123 conv2d_backprop->padding(padding);
124
125 // save the name for graph link updates
126 TensorName output_name(conv2d_backprop_name, 0);
127 tensor_names->enroll(output_name, conv2d_backprop);
128
129 std::vector<TensorName> input_names;
130 input_names.push_back(TensorName(node.input(0))); // input_sizes
131 input_names.push_back(TensorName(node.input(1))); // filter
132 input_names.push_back(TensorName(node.input(2))); // out_backprop
133
134 // update
135 auto conv2d_backprop_update =
136 std::make_unique<Conv2DBackpropInputGraphUpdate>(conv2d_backprop, input_names);
137
138 updates->enroll(std::move(conv2d_backprop_update));
139}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
std::string str_toupper(std::string s)
Definition Convert.cpp:27
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:111
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:70
NodeName name(void) const
Definition TFNodeDecl.h:50

References plier::tf::as_int64_list(), moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), plier::tf::get_list_attr(), plier::tf::get_string_attr(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::str_toupper(), moco::GraphBuilderContext::tensor_names(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::Conv2DBackpropInputGraphBuilder::validate ( const tensorflow::NodeDef &  node) const
override virtual

Implements moco::GraphBuilder.

Definition at line 72 of file Conv2DBackpropInput.cpp.

73{
74 if (node.input_size() != 3)
75 return false;
76
77 if (!plier::tf::has_attrs(node, {"T", "data_format", "padding", "strides"}))
78 return false;
79
80 auto data_layout = plier::tf::get_string_attr(node, "data_format");
81 if (!(data_layout == "NHWC" || data_layout == "NCHW"))
82 {
83 throw oops::UserExn("Conv2DBackprop Unsupported data_format", node.name());
84 }
85
86 // dilation attribute is not fully supported
87 if (plier::tf::has_attr(node, "dilations"))
88 {
89 // TODO Support non-default dilations
90 auto dilation = plier::tf::get_list_attr(node, "dilations").i();
91 if (!std::all_of(dilation.begin(), dilation.end(), [](std::int64_t dil) { return dil == 1; }))
92 return false;
93 }
94 // Else, dilations are automatically set to default [1,1,1,1] which we assumes now
95
96 return true;
97}
Exception to user.
Definition UserExn.h:42
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35
bool has_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:30

References plier::tf::get_list_attr(), plier::tf::get_string_attr(), plier::tf::has_attr(), and plier::tf::has_attrs().


The documentation for this class was generated from the following files: