ONE - On-device Neural Engine
Loading...
Searching...
No Matches
moco::BiasAddGraphBuilder Class Reference

#include <BiasAdd.h>

Collaboration diagram for moco::BiasAddGraphBuilder:

Public Member Functions

bool validate (const tensorflow::NodeDef &) const final
 
void build (const tensorflow::NodeDef &, GraphBuilderContext *) const final
 
- Public Member Functions inherited from moco::GraphBuilder
virtual ~GraphBuilder ()
 

Detailed Description

Definition at line 25 of file BiasAdd.h.

Member Function Documentation

◆ build()

void moco::BiasAddGraphBuilder::build ( const tensorflow::NodeDef &  node,
GraphBuilderContext *  context 
) const
final, virtual

Implements moco::GraphBuilder.

Definition at line 96 of file BiasAdd.cpp.

97{
98 assert(context != nullptr);
99
100 loco::Graph *graph = context->graph();
101 SymbolTable *tensor_names = context->tensor_names();
102 UpdateQueue *updates = context->updates();
103
104 // tensorflow data_format: one of NHWC or NCHW.
105 auto data_layout = plier::tf::get_string_attr(node, "data_format");
106 auto tf_bias_add = graph->nodes()->create<TFBiasAdd>();
107 tf_bias_add->name(node.name());
108 tf_bias_add->data_layout(data_layout);
109
110 // To set the input node of encode_node with biasAdd_name
111 TensorName output_name(node.name(), 0);
112 tensor_names->enroll(output_name, tf_bias_add);
113
114 std::vector<TensorName> input_names;
115 input_names.push_back(TensorName(node.input(0)));
116 input_names.push_back(TensorName(node.input(1)));
117
118 auto update = std::make_unique<TFBiasAddGraphUpdate>(tf_bias_add, input_names);
119 updates->enroll(std::move(update));
120}
A neural network graph.
Definition Graph.h:161
Class to store and query loco::Node* with string name key.
void enroll(const TensorName &tensor_name, loco::Node *node)
Registers a name with corresponding loco::Node *.
Class to store GraphUpdate objects.
void enroll(std::unique_ptr< GraphUpdate > &&update)
Registers GraphUpdate objects.
FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
const std::string & get_string_attr(const tensorflow::NodeDef &node, const std::string &attr_name)
Definition Convert.cpp:79
NodeName name(void) const
Definition TFNodeDecl.h:50

References moco::SymbolTable::enroll(), moco::UpdateQueue::enroll(), plier::tf::get_string_attr(), moco::GraphBuilderContext::graph(), moco::TFNode::name(), moco::GraphBuilderContext::tensor_names(), moco::update(), and moco::GraphBuilderContext::updates().

◆ validate()

bool moco::BiasAddGraphBuilder::validate ( const tensorflow::NodeDef &  node) const
final, virtual

Implements moco::GraphBuilder.

Definition at line 69 of file BiasAdd.cpp.

70{
71 if (node.input_size() != 2)
72 return false;
73
74 // note: even though "data_format" is not entered when a model is written,
75 // TF seems to generate "data_format" field into a pb file
76 if (!plier::tf::has_attrs(node, {"T", "data_format"}))
77 return false;
78
79 // TODO add type check
80 // type of input and bias should be same (except using quantization)
81
82 // Note In case of TF.nn.bias_add,
83 // "value may have any number of dimensions." ...
84 // but "data_format: A string. 'NHWC' and 'NCHW' are supported."
85 // Not sure if value should be 4-D tensor. Let's skip this check for now.
86
87 auto data_layout = plier::tf::get_string_attr(node, "data_format");
88 if (!(data_layout == "NHWC" || data_layout == "NCHW"))
89 {
90 throw oops::UserExn("BiasAdd Unsupported data_format", node.name());
91 }
92
93 return true;
94}
Exception to user.
Definition UserExn.h:42
bool has_attrs(const tensorflow::NodeDef &node, const std::vector< std::string > &attr_names)
Definition Convert.cpp:35

References plier::tf::get_string_attr(), and plier::tf::has_attrs().


The documentation for this class was generated from the following files: