ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::SqueezeGraphBuilder Class Reference [final]

GraphBuilder for Squeeze node. More...

#include <Squeeze.h>

Collaboration diagram for moco::SqueezeGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const override
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const override
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

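GraphBuilder for the TensorFlow Squeeze node. It creates a TF-dialect TFSqueeze node, copies the optional "squeeze_dims" attribute (left empty when absent), and enrolls an update that connects the node's single input. The "axis" attribute is not yet supported.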

Definition at line 28 of file Squeeze.h.

Member Function Documentation

◆ build()

void moco::SqueezeGraphBuilder::build (const tensorflow::NodeDef &node,
                                       GraphBuilderContext *context) const
[override] [virtual]

Implements moco::GraphBuilder.

Definition at line 80 of file Squeeze.cpp.

81{
82 assert(context != nullptr);
83
84 loco::Graph *graph = context->graph();
85 SymbolTable *tensor_names = context->tensor_names();
86 UpdateQueue *updates = context->updates();
87
88 // TODO support 'axis' attribute
89 assert(!plier::tf::has_attrs(node, {"axis"}));
90
91 std::vector<int64_t> squeeze_dims;
92 if (plier::tf::has_attrs(node, {"squeeze_dims"}))
93 {
94 auto squeeze_dim_list = plier::tf::get_list_attr(node, {"squeeze_dims"});
95 // TODO assert squeeze_dims are mutually different?
96 squeeze_dims = plier::tf::as_int64_list(squeeze_dim_list);
97 }
98 // Note that it is possible that NodeDef does not have squeeze_dims attribute.
99 // In that case, TFSqueeze also has empty squeeze_dims,
100
101 // creating TF dialect Squeeze node
102 auto tf_squeeze = graph->nodes()->create<TFSqueeze>();
103 tf_squeeze->name(node.name());
104 tf_squeeze->squeeze_dims(squeeze_dims);
105
106 TensorName output_name(node.name(), 0);
107 tensor_names->enroll(output_name, tf_squeeze);
108
109 auto update = std::make_unique<SqueezeGraphUpdate>(tf_squeeze, TensorName(node.input(0)));
110 updates->enroll(std::move(update));
111}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35
std::vector< int64_t > as_int64_list(const tensorflow::AttrValue_ListValue &lv)
Definition Convert.cpp:111
const tensorflow::AttrValue_ListValue & get_list_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:70
NodeName name(void) const
Definition TFNodeDecl.h:50

References plier::tf::as_int64_list(), moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), plier::tf::get_list_attr(), moco::GraphBuilderContext::graph(), plier::tf::has_attrs(), moco::TFNode::name(), moco::GraphBuilderContext::tensor_names(), moco::update(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::SqueezeGraphBuilder::validate (const tensorflow::NodeDef &node) const
[override] [virtual]

Implements moco::GraphBuilder.

Definition at line 63 of file Squeeze.cpp.

64{
65 if (node.input_size() != 1)
66 return false;
67
68 if (!plier::tf::has_attrs(node, {"T"}))
69 return false;
70
71 if (plier::tf::has_attrs(node, {"axis"}))
72 {
73 // TODO support 'axis' attribute
74 oops::UserExn("Squeeze: Unsupported 'axis' attribute", node.name());
75 }
76
77 return true;
78}
Exception to user.
Definition UserExn.h:42

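References plier::tf::has_attrs().

Note: line 74 of the listing constructs oops::UserExn without a `throw`, so the exception object is discarded immediately. validate() therefore still returns true for a node that carries an "axis" attribute. In that case, only the assert at line 89 of build() rejects such nodes, and asserts are compiled out when NDEBUG is defined. The intended behavior is presumably `throw oops::UserExn(...)`.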


The documentation for this class was generated from the following files: