39 COpCallGraphUpdate(
locoex::COpCall *node,
const std::vector<moco::TensorName> &input_names)
40 : _node(node), _input_names(input_names)
48 const std::vector<moco::TensorName> _input_names;
53 for (
int n = 0; n < _input_names.size(); n++)
56 _node->input(n, target);
72 assert(context !=
nullptr);
79 auto call_node = graph->nodes()->create<
locoex::COpCall>(tf_node.input_size());
81 call_node->
op(tf_node.op());
82 call_node->name(tf_node.name());
83 call_node->dtype(_signature->
dtype(tf_node.name()));
85 auto shape = _signature->
shape(tf_node.name());
86 call_node->rank(shape->rank());
87 for (
int d = 0; d < shape->rank(); d++)
88 call_node->dim(d) = shape->dim(d);
90 for (
auto iter = tf_node.attr().begin(); iter != tf_node.attr().end(); iter++)
92 auto name = iter->first;
93 auto val = iter->second;
95 if (val.value_case() == tensorflow::AttrValue::kF)
97 call_node->attr(name, std::make_unique<locoex::COpAttrFloat>(val.f()));
99 else if (val.value_case() == tensorflow::AttrValue::kI)
101 call_node->attr(name, std::make_unique<locoex::COpAttrInt>(val.i()));
106 throw oops::UserExn(
"Unsupported attribute type", tf_node.name());
113 tensor_names->
enroll(output_name, call_node);
116 std::vector<TensorName> input_names;
117 for (
int i = 0; i < tf_node.input_size(); ++i)
119 input_names.emplace_back(
TensorName(tf_node.input(i)));
121 auto update = std::make_unique<COpCallGraphUpdate>(call_node, input_names);
Logical unit of computation.
Class to calls custom operation.
void op(const std::string &op)
Class to store context to build loco graph IR from TensorFlow.
SymbolTable * tensor_names()
Interface to connect the graph.
virtual void input(const SymbolTable *) const =0
Do the graph input connections using the SymbolTable.
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
loco::Node * node(const TensorName &tensor_name) const
Queries enrolled(registered) with name and return node if found Will throw runtime_error if not found...
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
bool validate(const tensorflow::NodeDef &) const override
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override
FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
Option< std::string > target(optname("--target"), overview("select target language to emit for given architecture." "Valid values are '" NNC_TARGET_ARM_CPP "', '" NNC_TARGET_X86_CPP "', '" NNC_TARGET_ARM_GPU_CPP "', '" NNC_TARGET_INTERPRETER "'"), std::string(), optional(false), optvalues(NNC_TARGET_ARM_CPP "," NNC_TARGET_X86_CPP "," NNC_TARGET_ARM_GPU_CPP "," NNC_TARGET_INTERPRETER), nullptr, separators("="))
void dtype(const std::string &node_name, loco::DataType dtype)
Adds node name and its dtype provided from user.
void shape(const std::string &node_name, const angkor::TensorShape &shape)
Adds node name and its shape provided from user.