ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::ConcatV2GraphBuilder Class Reference

#include <Concat.h>

Collaboration diagram for moco::ConcatV2GraphBuilder: [diagram image not available in this text rendering]

Public Member Functions

bool validate (const tensorflow::NodeDef &) const final
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const final
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

Definition at line 25 of file Concat.h.

Member Function Documentation

◆ build()

void moco::ConcatV2GraphBuilder::build ( const tensorflow::NodeDef & node,
GraphBuilderContext * context
) const
final, virtual

Implements moco::GraphBuilder.

Definition at line 80 of file Concat.cpp.

82{
83 assert(context != nullptr);
84
85 auto graph = context->graph();
86 auto tensor_names = context->tensor_names();
87 auto updates = context->updates();
88
89 const int num_inputs = node.input_size() - 1;
90 std::vector<TensorName> input_names;
91 auto concat_node = graph->nodes()->create<TFConcatV2>(num_inputs);
92 concat_node->name(node.name());
93
94 for (int ni = 0; ni < num_inputs; ++ni)
95 {
96 input_names.push_back(TensorName(node.input(ni)));
97 }
98 // last one is the axis
99 input_names.push_back(TensorName(node.input(num_inputs)));
100
101 // register string-name to the last node as output of concat(s)
102 TensorName output_name(node.name(), 0);
103 tensor_names->enroll(output_name, concat_node);
104
105 auto update = std::make_unique<TFConcatV2GraphUpdate>(concat_node, input_names);
106 updates->enroll(std::move(update));
107}
Symbol tooltips from the listing above:
- update: FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
- name: NodeName name(void) const — Definition at line 50 of file TFNodeDecl.h.

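References moco::SymbolTable::enroll(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::GraphBuilderContext::tensor_names(), moco::update(), and moco::GraphBuilderContext::updates(). (The moco::update() entry appears to be a Doxygen cross-reference artifact. In the listing, `update` is a local variable holding the std::make_unique<TFConcatV2GraphUpdate>(...) result, not a call to a free function.)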

◆ validate()

bool moco::ConcatV2GraphBuilder::validate ( const tensorflow::NodeDef & node ) const
final, virtual

Implements moco::GraphBuilder.

Definition at line 70 of file Concat.cpp.

71{
72 if (!plier::tf::has_attrs(node, {"T", "N", "Tidx"}))
73 return false;
74
75 // Concat node SHOULD have 3 or more inputs, that is 2 + axis
76 const int num_inputs = node.input_size() - 1;
77 return (num_inputs >= 2) && (num_inputs == plier::tf::get_int_attr(node, "N"));
78}
Symbol tooltips from the listing above:
- has_attrs: bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names) — Definition at line 35 of file Convert.cpp.
- get_int_attr: int64_t get_int_attr(const tensorflow::NodeDef &node, const std::string &attr_name) — Definition at line 87 of file Convert.cpp.

References plier::tf::get_int_attr(), and plier::tf::has_attrs().


The documentation for this class was generated from the following files:
- Concat.h
- Concat.cpp