ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::PlaceholderGraphBuilder Class Reference [final]

GraphBuilder for Placeholder node. More...

#include <Placeholder.h>

Collaboration diagram for moco::PlaceholderGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const override
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const override
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

GraphBuilder for Placeholder node.

Definition at line 28 of file Placeholder.h.

Member Function Documentation

◆ build()

void moco::PlaceholderGraphBuilder::build ( const tensorflow::NodeDef &  node,
GraphBuilderContext *  context 
) const
[override, virtual]

Implements moco::GraphBuilder.

Definition at line 43 of file Placeholder.cpp.

45{
46 assert(context != nullptr);
47
48 loco::Graph *graph = context->graph();
49 SymbolTable *tensor_names = context->tensor_names();
50
52 const auto &shape = plier::tf::get_shape_attr(node, "shape");
53 // TODO handle for unknown rank
54 assert(!shape.unknown_rank());
55 int64_t num_dims = shape.dim_size();
56
57 // TODO support other types
58 assert(dtype == loco::DataType::FLOAT32);
59
60 // Create a "Placeholder" node as an input
61 auto placeholder_node = graph->nodes()->create<moco::TFPlaceholder>();
62 placeholder_node->name(node.name());
63 placeholder_node->dtype(dtype);
64
65 // Setting shape info.
66 placeholder_node->rank(num_dims);
67 for (int64_t d = 0; d < num_dims; d++)
68 {
69 assert(shape.dim(d).size() < std::numeric_limits<uint32_t>::max());
70 int64_t dim_value = shape.dim(d).size();
71 if (dim_value >= 0)
72 {
73 uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
74 placeholder_node->dim(d) = dim_value32;
75 }
76 else
77 {
78 placeholder_node->dim(d).unset();
79 // TODO Remove assert() and do implement
80 // NOTE Current implementation assumes dim is all know
81 assert(false);
82 }
83 }
84
85 // register string-name to node
86 TensorName output_name(node.name(), 0);
87 tensor_names->enroll(output_name, placeholder_node);
88}
A neural network graph.
Definition Graph.h:161
IR for tf.placeholder.
DataType
"scalar" value type
Definition DataType.h:27
TFPlaceholder * placeholder_node(loco::Graph *g, const loco::GraphInputIndex &idx)
Definition TFNode.cpp:84
tensorflow::DataType get_datatype_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:43
loco::DataType as_loco_datatype(const tensorflow::DataType dtype)
Definition Convert.cpp:123
const tensorflow::TensorShapeProto & get_shape_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:52
std::string TensorName
NodeName name(void) const
Definition TFNodeDecl.h:50

References plier::tf::as_loco_datatype(), moco::SymbolTable::enroll(), plier::tf::get_datatype_attr(), plier::tf::get_shape_attr(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::placeholder_node(), and moco::GraphBuilderContext::tensor_names().

◆ validate()

bool moco::PlaceholderGraphBuilder::validate ( const tensorflow::NodeDef &  node) const
[override, virtual]

Implements moco::GraphBuilder.

Definition at line 30 of file Placeholder.cpp.

31{
32 if (!plier::tf::has_attrs(node, {"dtype", "shape"}))
33 return false;
34
36 if (dtype != loco::DataType::FLOAT32)
37 return false;
38 // TODO support other types
39
40 return true;
41}
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35

References plier::tf::as_loco_datatype(), plier::tf::get_datatype_attr(), and plier::tf::has_attrs().


The documentation for this class was generated from the following files: