ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Reshape.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
18
20
#include <moco/Names.h>
#include <plier/tf/Convert.h>
#include <loco.h>

#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
28
29namespace
30{
31using namespace moco;
32
33class ReshapeGraphUpdate final : public GraphUpdate
34{
35public:
36 ReshapeGraphUpdate(TFReshape *node, std::vector<TensorName> names) : _node(node), _names(names) {}
37
38 void input(const SymbolTable *) const override;
39
40private:
41 TFReshape *_node;
42 std::vector<TensorName> _names;
43};
44
45void ReshapeGraphUpdate::input(const SymbolTable *node_table) const
46{
47 assert(_names.size() == 2);
48
49 auto tensor_node = node_table->node(_names[0]);
50 auto shape_node = node_table->node(_names[1]);
51
52 assert(tensor_node != nullptr);
53 assert(shape_node != nullptr);
54
55 _node->tensor(tensor_node);
56 _node->shape(shape_node);
57}
58
59} // namespace
60
61namespace moco
62{
63
64bool ReshapeGraphBuilder::validate(const tensorflow::NodeDef &node) const
65{
66 // Tensorflow Reshape has 2 inputs: tensor & shape
67 if (node.input_size() != 2)
68 return false;
69
70 // TODO Assert Tshape value is DT_INT32?
71 return plier::tf::has_attrs(node, {"T", "Tshape"});
72}
73
74void ReshapeGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
75{
76 assert(context != nullptr);
77
78 loco::Graph *graph = context->graph();
79 SymbolTable *tensor_names = context->tensor_names();
80 UpdateQueue *updates = context->updates();
81
82 // name of loco nodes
83 std::string reshape_name = node.name();
84
85 auto reshape = graph->nodes()->create<TFReshape>();
86 reshape->name(node.name());
87
88 // save the name for graph link updates
89 TensorName output_name(reshape_name, 0);
90 tensor_names->enroll(output_name, reshape);
91
92 std::vector<TensorName> input_names;
93 input_names.push_back(TensorName(node.input(0))); // tensor
94 input_names.push_back(TensorName(node.input(1))); // shape
95
96 // Queue node input update
97 auto update = std::make_unique<ReshapeGraphUpdate>(reshape, input_names);
98
99 updates->enroll(std::move(update));
100}
101
102} // namespace moco
A neural network graph.
Definition Graph.h:161
Class to store context to build loco graph IR from TensorFlow.
Interface to connect the graph.
virtual void input(const SymbolTable *) const =0
Do the graph input connections using the SymbolTable.
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override
Definition Reshape.cpp:74
bool validate(const tensorflow::NodeDef &) const override
Definition Reshape.cpp:64
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
loco::Node * node(const TensorName &tensor_name) const
Queries the node enrolled (registered) under the given name and returns it if found. Will throw runtime_error if not found...
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
Definition Log.h:23
FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35