ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::DepthwiseConv2dNativeGraphBuilder Class Reference (final)

GraphBuilder for DepthwiseConv2dNative node. More...

#include <DepthwiseConv2dNative.h>

Collaboration diagram for moco::DepthwiseConv2dNativeGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const override
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const override
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

GraphBuilder for DepthwiseConv2dNative node.

Definition at line 28 of file DepthwiseConv2dNative.h.

Member Function Documentation

◆ build()

void moco::DepthwiseConv2dNativeGraphBuilder::build ( const tensorflow::NodeDef &  node,
GraphBuilderContext *  context 
) const
override virtual

Implements moco::GraphBuilder.

Definition at line 110 of file DepthwiseConv2dNative.cpp.

112{
113 assert(context != nullptr);
114
115 loco::Graph *graph = context->graph();
116 SymbolTable *tensor_names = context->tensor_names();
117 UpdateQueue *updates = context->updates();
118
119 auto depthwiseconv2d_native_node = graph->nodes()->create<TFDepthwiseConv2dNative>();
120 depthwiseconv2d_native_node->name(node.name());
121
122 // read attributes
123 auto data_layout = get_string_attr(node, "data_format");
124 depthwiseconv2d_native_node->data_layout(data_layout);
125
126 auto tf_strides = get_list_attr(node, "strides");
127 auto strides = as_int64_list(tf_strides);
128 depthwiseconv2d_native_node->strides(strides);
129
130 auto padding = moco::str_toupper(get_string_attr(node, "padding"));
131 depthwiseconv2d_native_node->padding(padding);
132
133 // save the name for graph link updates
134 TensorName output_name(node.name(), 0);
135 tensor_names->enroll(output_name, depthwiseconv2d_native_node);
136
137 std::vector<TensorName> input_names;
138 input_names.push_back(TensorName(node.input(0))); // input
139 input_names.push_back(TensorName(node.input(1))); // kernel
140
141 // Record ifm inputs to featureEncode_node
142 auto tfdepthwiseconv2dnative_update =
143 std::make_unique<TFDepthwiseConv2dNativeGraphUpdate>(depthwiseconv2d_native_node, input_names);
144
145 updates->enroll(std::move(tfdepthwiseconv2dnative_update));
146}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
std::string str_toupper(std::string s)
Definition Convert.cpp:27
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:111
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:70
NodeName name(void) const
Definition TFNodeDecl.h:50

References plier::tf::as_int64_list(), moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), plier::tf::get_list_attr(), plier::tf::get_string_attr(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::str_toupper(), moco::GraphBuilderContext::tensor_names(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::DepthwiseConv2dNativeGraphBuilder::validate ( const tensorflow::NodeDef &  node) const
override virtual

Implements moco::GraphBuilder.

Definition at line 71 of file DepthwiseConv2dNative.cpp.

72{
73 if (node.input_size() != 2)
74 return false;
75
76 // note: even though "data_format" and "dilations" are not entered when a model is written,
77 // TF seems to generate those field into a pb file.
78 if (!has_attrs(node, {"T", "data_format", "dilations", "padding", "strides"}))
79 return false;
80
81 auto data_layout = plier::tf::get_string_attr(node, "data_format");
82 if (!(data_layout == "NHWC" || data_layout == "NCHW"))
83 {
84 throw oops::UserExn("DepthwiseConv2dNative Unsupported data_format", node.name());
85 }
86
87 auto padding = moco::str_toupper(get_string_attr(node, "padding"));
88 if (!(padding == "VALID" || padding == "SAME"))
89 return false;
90
91 auto tf_strides = get_list_attr(node, "strides");
92 auto strides = as_int64_list(tf_strides);
93 if (!(strides.size() == 4))
94 {
95 throw oops::UserExn("DepthwiseConv2dNative strides requires rank 4", node.name());
96 }
97 auto stride_n = strides.at(0);
98 auto stride_h = strides.at(1);
99 auto stride_w = strides.at(2);
100 auto stride_c = strides.at(3);
101 if (!(stride_n == 1 && stride_c == 1) || !(stride_h == stride_w))
102 {
103 // TODO this message may need to be refined
104 throw oops::UserExn("DepthwiseConv2dNative strides requires N=C=1, H=W", node.name());
105 }
106
107 return true;
108}
Exception to user.
Definition UserExn.h:42
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35

References plier::tf::as_int64_list(), plier::tf::get_list_attr(), plier::tf::get_string_attr(), plier::tf::has_attrs(), and moco::str_toupper().


The documentation for this class was generated from the following files: