ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::StridedSliceGraphBuilder Class Reference

#include <StridedSlice.h>

Collaboration diagram for moco::StridedSliceGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const final
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const final
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

Definition at line 25 of file StridedSlice.h.

Member Function Documentation

◆ build()

void moco::StridedSliceGraphBuilder::build ( const tensorflow::NodeDef &  node,
GraphBuilderContext *  context
) const
final virtual

Implements moco::GraphBuilder.

Definition at line 139 of file StridedSlice.cpp.

141{
142 assert(context != nullptr);
143
144 loco::Graph *graph = context->graph();
145 SymbolTable *tensor_names = context->tensor_names();
146 UpdateQueue *updates = context->updates();
147
148 std::string node_name = node.name();
149
150 auto stridedslice = graph->nodes()->create<TFStridedSlice>();
151 stridedslice->name(node_name);
152
153 // read attributes
154 auto begin_mask = plier::tf::get_int_attr(node, "begin_mask");
155 auto end_mask = plier::tf::get_int_attr(node, "end_mask");
156 auto ellipsis_mask = plier::tf::get_int_attr(node, "ellipsis_mask");
157 auto new_axis_mask = plier::tf::get_int_attr(node, "new_axis_mask");
158 auto shrink_axis_mask = plier::tf::get_int_attr(node, "shrink_axis_mask");
159
160 stridedslice->begin_mask(begin_mask);
161 stridedslice->end_mask(end_mask);
162 stridedslice->ellipsis_mask(ellipsis_mask);
163 stridedslice->new_axis_mask(new_axis_mask);
164 stridedslice->shrink_axis_mask(shrink_axis_mask);
165
166 // TODO support general mask values: we support only this limited case for now
167 assert(begin_mask == 0);
168 assert(end_mask == 0);
169 assert(ellipsis_mask == 0);
170 assert(new_axis_mask == 0);
171 assert(shrink_axis_mask == 1);
172
173 // save the name for graph link updates
174 TensorName output_name(node_name, 0);
175 tensor_names->enroll(output_name, stridedslice);
176
177 std::vector<TensorName> input_names;
178 input_names.push_back(TensorName(node.input(0))); // input
179 input_names.push_back(TensorName(node.input(1))); // begin
180 input_names.push_back(TensorName(node.input(2))); // end
181 input_names.push_back(TensorName(node.input(3))); // strides
182
183 auto tfconv2d_update = std::make_unique<TFStridedSliceGraphUpdate>(stridedslice, input_names);
184
185 updates->enroll(std::move(tfconv2d_update));
186}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:87
NodeName name(void) const
Definition TFNodeDecl.h:50

References moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), plier::tf::get_int_attr(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::GraphBuilderContext::tensor_names(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::StridedSliceGraphBuilder::validate ( const tensorflow::NodeDef &  node) const
final virtual

Implements moco::GraphBuilder.

Definition at line 126 of file StridedSlice.cpp.

127{
128 // TODO support node.input_size() == 3 where strides is None
129 if (node.input_size() != 4)
130 return false;
131
132 if (!plier::tf::has_attrs(node, {"T", "Index", "begin_mask", "end_mask", "ellipsis_mask",
133 "new_axis_mask", "shrink_axis_mask"}))
134 return false;
135
136 return true;
137}
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35

References plier::tf::has_attrs().


The documentation for this class was generated from the following files: