ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::tf::COpCallGraphBuilder Class Reference (final)

GraphBuilder for COpCall node. More...

#include <COpCall.h>

Collaboration diagram for moco::tf::COpCallGraphBuilder:

Public Member Functions

 COpCallGraphBuilder (const ModelSignature *signature)
 
bool validate (const tensorflow::NodeDef &) const override
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const override
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

GraphBuilder for COpCall node.

Definition at line 32 of file COpCall.h.

Constructor & Destructor Documentation

◆ COpCallGraphBuilder()

moco::tf::COpCallGraphBuilder::COpCallGraphBuilder ( const ModelSignature * signature )
inline

Definition at line 35 of file COpCall.h.

35 : _signature(signature)
36 { /* empty */
37 }

Member Function Documentation

◆ build()

void moco::tf::COpCallGraphBuilder::build ( const tensorflow::NodeDef &  tf_node,
GraphBuilderContext * context
) const
override virtual

Implements moco::GraphBuilder.

Definition at line 69 of file COpCall.cpp.

71{
72 assert(context != nullptr);
73
74 loco::Graph *graph = context->graph();
75 SymbolTable *tensor_names = context->tensor_names();
76 UpdateQueue *updates = context->updates();
77
78 // Create a "COpCall" node for CustomOp and set attributes
79 auto call_node = graph->nodes()->create<locoex::COpCall>(tf_node.input_size());
80 {
81 call_node->op(tf_node.op());
82 call_node->name(tf_node.name());
83 call_node->dtype(_signature->dtype(tf_node.name()));
84
85 auto shape = _signature->shape(tf_node.name());
86 call_node->rank(shape->rank());
87 for (int d = 0; d < shape->rank(); d++)
88 call_node->dim(d) = shape->dim(d);
89
90 for (auto iter = tf_node.attr().begin(); iter != tf_node.attr().end(); iter++)
91 {
92 auto name = iter->first;
93 auto val = iter->second;
94
95 if (val.value_case() == tensorflow::AttrValue::kF)
96 {
97 call_node->attr(name, std::make_unique<locoex::COpAttrFloat>(val.f()));
98 }
99 else if (val.value_case() == tensorflow::AttrValue::kI)
100 {
101 call_node->attr(name, std::make_unique<locoex::COpAttrInt>(val.i()));
102 }
103 // TODO define more types
104 else
105 {
106 throw oops::UserExn("Unsupported attribute type", tf_node.name());
107 }
108 }
109 }
110
111 // register this node with its name
112 TensorName output_name(tf_node.name(), 0);
113 tensor_names->enroll(output_name, call_node);
114
115 // Queue node input update
116 std::vector<TensorName> input_names;
117 for (int i = 0; i < tf_node.input_size(); ++i)
118 {
119 input_names.emplace_back(TensorName(tf_node.input(i)));
120 }
121 auto update = std::make_unique<COpCallGraphUpdate>(call_node, input_names);
122 updates->enroll(std::move(update));
123}
A neural network graph.
Definition Graph.h:161
Class to call a custom operation.
Definition COpCall.h:38
void op(const std::string &op)
Definition COpCall.h:43
Exception to be reported to the user.
Definition UserExn.h:42
FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
std::string TensorName
void dtype(const std::string &node_name, loco::DataType dtype)
Adds a node name and its dtype, as provided by the user.
void shape(const std::string &node_name, const angkor::TensorShape &shape)
Adds a node name and its shape, as provided by the user.

References moco::ModelSignature::dtype(), moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), moco::GraphBuilderContext::graph(), locoex::COpCall::op(), moco::ModelSignature::shape(), moco::GraphBuilderContext::tensor_names(), moco::update(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::tf::COpCallGraphBuilder::validate ( const tensorflow::NodeDef &  tf_node) const
override virtual

Implements moco::GraphBuilder.

Definition at line 67 of file COpCall.cpp.

67{ return true; }

The documentation for this class was generated from the following files: