36class TFStridedSliceGraphUpdate final :
public GraphUpdate
39 TFStridedSliceGraphUpdate(
TFStridedSlice *node, std::vector<TensorName> names)
40 : _node(node), _names(names)
48 std::vector<TensorName> _names;
51void TFStridedSliceGraphUpdate::input(
const SymbolTable *node_table)
const
54 assert(_names.size() == 4);
57 auto begin_node = node_table->
node(_names[1]);
58 auto end_node = node_table->
node(_names[2]);
59 auto strides_node = node_table->
node(_names[3]);
60 assert(input_node !=
nullptr);
61 assert(begin_node !=
nullptr);
62 assert(end_node !=
nullptr);
63 assert(strides_node !=
nullptr);
65 _node->input(input_node);
66 _node->begin(begin_node);
68 _node->strides(strides_node);
74 if (_node->begin_mask() != 0 || _node->end_mask() != 0 || _node->ellipsis_mask() != 0 ||
75 _node->new_axis_mask() != 0 || _node->shrink_axis_mask() != 1)
77 throw oops::UserExn(
"Mask attributes are not supported for now: ", _node->name());
81 auto const_input =
dynamic_cast<moco::TFConst *
>(_node->input());
82 auto const_begin =
dynamic_cast<moco::TFConst *
>(_node->begin());
84 auto const_strides =
dynamic_cast<moco::TFConst *
>(_node->strides());
85 if (const_input ==
nullptr || const_begin ==
nullptr || const_end ==
nullptr ||
86 const_strides ==
nullptr)
88 throw oops::UserExn(
"Only Const inputs are supported for now: ", _node->name());
92 if (const_begin->dtype() != loco::DataType::S32 || const_end->dtype() != loco::DataType::S32 ||
93 const_strides->dtype() != loco::DataType::S32)
95 throw oops::UserExn(
"Only Const types of INT32 are supported for begin/end/strides for now: ",
100 auto rin = const_input->rank();
101 if (rin != const_begin->size<loco::DataType::S32>() ||
102 rin != const_end->size<loco::DataType::S32>() ||
103 rin != const_strides->size<loco::DataType::S32>())
105 throw oops::UserExn(
"Ranks for inputs should be same: ", _node->name());
111 uint32_t elements = const_strides->size<loco::DataType::S32>();
112 for (uint32_t e = 0; e < elements; ++e)
114 if (const_strides->at<loco::DataType::S32>(e) != 1)
116 throw oops::UserExn(
"Only stride 1 is supported for now: ", _node->name());
129 if (node.input_size() != 4)
133 "new_axis_mask",
"shrink_axis_mask"}))
142 assert(context !=
nullptr);
148 std::string node_name = node.name();
151 stridedslice->
name(node_name);
160 stridedslice->begin_mask(begin_mask);
161 stridedslice->end_mask(end_mask);
162 stridedslice->ellipsis_mask(ellipsis_mask);
163 stridedslice->new_axis_mask(new_axis_mask);
164 stridedslice->shrink_axis_mask(shrink_axis_mask);
167 assert(begin_mask == 0);
168 assert(end_mask == 0);
169 assert(ellipsis_mask == 0);
170 assert(new_axis_mask == 0);
171 assert(shrink_axis_mask == 1);
175 tensor_names->
enroll(output_name, stridedslice);
177 std::vector<TensorName> input_names;
178 input_names.push_back(
TensorName(node.input(0)));
179 input_names.push_back(
TensorName(node.input(1)));
180 input_names.push_back(
TensorName(node.input(2)));
181 input_names.push_back(
TensorName(node.input(3)));
183 auto tfconv2d_update = std::make_unique<TFStridedSliceGraphUpdate>(stridedslice, input_names);
185 updates->
enroll(std::move(tfconv2d_update));
Class to store context to build loco graph IR from TensorFlow.
SymbolTable * tensor_names()
Interface to connect the graph.
virtual void input(const SymbolTable *) const =0
Performs the graph input connections using the SymbolTable.
bool validate(const tensorflow::NodeDef &) const final
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with its corresponding loco::Node *.
loco::Node * node(const TensorName &tensor_name) const
Queries the node enrolled (registered) under the given name and returns it if found. Will throw runtime_error if not found...
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Finds the Pull node with the given input index.
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
NodeName name(void) const