ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::AvgPoolGraphBuilder Class Reference

#include <AvgPool.h>

Collaboration diagram for moco::AvgPoolGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const final
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const final
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

Definition at line 25 of file AvgPool.h.

Member Function Documentation

◆ build()

void moco::AvgPoolGraphBuilder::build(const tensorflow::NodeDef &node,
                                      GraphBuilderContext *context) const
[final, virtual]

Implements moco::GraphBuilder.

Definition at line 94 of file AvgPool.cpp.

95{
96 assert(context != nullptr);
97
98 loco::Graph *graph = context->graph();
99 SymbolTable *tensor_names = context->tensor_names();
100 UpdateQueue *updates = context->updates();
101
102 // name of loco nodes
103 ::std::string avgPool2d_name = node.name();
104
105 // tensorflow data_format: one of NHWC or NCHW.
106 auto data_layout = get_string_attr(node, "data_format");
107 auto avgPool_node = graph->nodes()->create<TFAvgPool>();
108 avgPool_node->name(node.name());
109 avgPool_node->data_layout(data_layout);
110
111 // padding
112 auto padding = moco::str_toupper(get_string_attr(node, "padding"));
113 avgPool_node->padding(padding);
114
115 // ksize
116 auto tf_ksize = get_list_attr(node, "ksize");
117 auto ksize = as_int64_list(tf_ksize);
118 avgPool_node->ksize(ksize);
119
120 // strides
121 auto tf_strides = get_list_attr(node, "strides");
122 auto strides = as_int64_list(tf_strides);
123 avgPool_node->strides(strides);
124
125 // To set the input node of encode_node with avgPool2d_name
126 TensorName output_name(avgPool2d_name, 0);
127 tensor_names->enroll(output_name, avgPool_node);
128
129 // Record ifm inputs to featureEncode_node
130 auto update = std::make_unique<TFAvgPoolGraphUpdate>(avgPool_node, TensorName(node.input(0)));
131
132 updates->enroll(std::move(update));
133}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
std::string str_toupper(std::string s)
Definition Convert.cpp:27
FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:111
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:70
NodeName name(void) const
Definition TFNodeDecl.h:50

References plier::tf::as_int64_list(), moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), plier::tf::get_list_attr(), plier::tf::get_string_attr(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::str_toupper(), moco::GraphBuilderContext::tensor_names(), moco::update(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::AvgPoolGraphBuilder::validate(const tensorflow::NodeDef &node) const
[final, virtual]

Implements moco::GraphBuilder.

Definition at line 65 of file AvgPool.cpp.

66{
67 if (node.input_size() != 1)
68 return false;
69
70 // note: even though "data_format" is not entered when a model is written,
71 // TF seems to generate "data_format" field into a pb file
72 if (!plier::tf::has_attrs(node, {"T", "data_format", "ksize", "padding", "strides"}))
73 return false;
74
75 auto tf_ksize = get_list_attr(node, "ksize");
76 auto ksize = as_int64_list(tf_ksize);
77 if (ksize.size() != 4)
78 {
79 // TODO support ksize length for 1 and 2
80 throw oops::UserExn("AvgPool only supports ksize length 4", node.name());
81 }
82
83 auto tf_strides = get_list_attr(node, "strides");
84 auto strides = as_int64_list(tf_strides);
85 if (strides.size() != 4)
86 {
87 // TODO support strides length for 1 and 2
88 throw oops::UserExn("AvgPool only supports strides length 4", node.name());
89 }
90
91 return true;
92}
Exception to user.
Definition UserExn.h:42
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35

References plier::tf::as_int64_list(), plier::tf::get_list_attr(), and plier::tf::has_attrs().


The documentation for this class was generated from the following files:
- AvgPool.h
- AvgPool.cpp