ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::ConstGraphBuilder Class Reference

#include <Const.h>

Collaboration diagram for moco::ConstGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const final
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const final
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

Definition at line 25 of file Const.h.

Member Function Documentation

◆ build()

void moco::ConstGraphBuilder::build ( const tensorflow::NodeDef &  node,
GraphBuilderContext *  context 
) const
final virtual

Implements moco::GraphBuilder.

Definition at line 172 of file Const.cpp.

173{
174 loco::Graph *graph = context->graph();
175 SymbolTable *tensor_names = context->tensor_names();
176
177 // Create a "TFConstant" node for Const
178 auto const_node = graph->nodes()->create<TFConst>();
179 const_node->name(node.name());
180
181 // set dtype
182 auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
183 const_node->dtype(dtype);
184
185 // import shape and value
186 const auto &input_tensor = plier::tf::get_tensor_attr(node, "value");
187 const auto &input_shape = input_tensor.tensor_shape();
188 const auto &input_dims = input_shape.dim();
189 assert(input_shape.dim_size() <= 6);
190 const_node->rank(input_shape.dim_size());
191 int index = 0;
192 bool zero_sized_shape = false;
193 for (auto &d : input_dims)
194 {
195 assert(d.size() <= std::numeric_limits<int>::max());
196 if (d.size() == 0)
197 zero_sized_shape = true;
198
199 assert(d.size() >= 0);
200 const_node->dim(index++) = d.size();
201 }
202
203 int num_elements = 1;
204 if (zero_sized_shape)
205 {
206 const_node->rank(0);
207 num_elements = 0;
208 }
209 else
210 {
211 for (uint32_t d = 0; d < const_node->rank(); d++)
212 {
213 num_elements *= const_node->dim(d).value();
214 }
215 }
216
217 switch (dtype)
218 {
219 case loco::DataType::S8:
220 read_value_int8(const_node, num_elements, input_tensor);
221 break;
222
223 case loco::DataType::S32:
224 read_value_int32(const_node, num_elements, input_tensor);
225 break;
226
227 case loco::DataType::FLOAT32:
228 read_value_float32(const_node, num_elements, input_tensor);
229 break;
230
231 // TODO support other types
232
233 default:
234 assert(false);
235 }
236
237 // register string-name to node
238 TensorName output_name(node.name(), 0);
239 tensor_names->enroll(output_name, const_node);
240}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
IR for tf.constant.
Definition TFConst.h:67
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59
const tensorflow::TensorProto & get_tensor_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:61
tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:43
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
Definition Convert.cpp:123
NodeName name(void) const
Definition TFNodeDecl.h:50

References plier::tf::as_loco_datatype(), moco::SymbolTable::enroll(), plier::tf::get_datatype_attr(), plier::tf::get_tensor_attr(), moco::GraphBuilderContext::graph(), moco::index(), moco::TFNode::name(), moco::TFConst::size(), and moco::GraphBuilderContext::tensor_names().

◆ validate()

bool moco::ConstGraphBuilder::validate ( const tensorflow::NodeDef &  node) const
final virtual

Implements moco::GraphBuilder.

Definition at line 142 of file Const.cpp.

143{
144 if (!plier::tf::has_attrs(node, {"dtype", "value"}))
145 return false;
146
147 const auto &input_tensor = plier::tf::get_tensor_attr(node, "value");
148 const auto &input_shape = input_tensor.tensor_shape();
149 const auto &input_dims = input_shape.dim();
150
151 if (!(input_shape.dim_size() <= 6))
152 return false;
153
154 for (auto &d : input_dims)
155 {
156 if (d.size() > std::numeric_limits<int>::max())
157 throw oops::UserExn("Const Shape element overflows", node.name());
158
159 if (d.size() < 0)
160 throw oops::UserExn("Unknown dim size", node.name());
161 }
162
163 auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
164 if (!(dtype == loco::DataType::S32 || dtype == loco::DataType::FLOAT32 ||
165 dtype == loco::DataType::S8))
166 return false;
167 // TODO support other dtype
168
169 return true;
170}
Exception to user.
Definition UserExn.h:42
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35

References plier::tf::as_loco_datatype(), plier::tf::get_datatype_attr(), plier::tf::get_tensor_attr(), and plier::tf::has_attrs().


The documentation for this class was generated from the following files: