ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::ir::train::UseDefGenerator Class Reference

#include <UseDefGenerator.h>

Collaboration diagram for onert::ir::train::UseDefGenerator:

Public Member Functions

 UseDefGenerator (void)=delete
 
 UseDefGenerator (const TrainableGraph &tgraph)
 
UseDefChains operator() ()
 
void visit (const train::operation::BinaryArithmetic &node) override
 
void visit (const train::operation::Conv2D &node) override
 
void visit (const train::operation::DepthwiseConv2D &node) override
 
void visit (const train::operation::ElementwiseActivation &node) override
 
void visit (const train::operation::FullyConnected &node) override
 
void visit (const train::operation::Loss &node) override
 
void visit (const train::operation::Pad &node) override
 
void visit (const train::operation::Pool2D &node) override
 
void visit (const train::operation::Reduce &node) override
 
void visit (const train::operation::Reshape &node) override
 
void visit (const train::operation::Softmax &node) override
 
- Public Member Functions inherited from onert::ir::train::UseDefGeneratorBase
virtual ~UseDefGeneratorBase ()=default
 
- Public Member Functions inherited from onert::ir::train::TrainableOperationVisitor
virtual ~TrainableOperationVisitor ()=default
 

Detailed Description

Definition at line 47 of file UseDefGenerator.h.

Constructor & Destructor Documentation

◆ UseDefGenerator() [1/2]

onert::ir::train::UseDefGenerator::UseDefGenerator ( void ) = delete

◆ UseDefGenerator() [2/2]

onert::ir::train::UseDefGenerator::UseDefGenerator ( const TrainableGraph &tgraph )

Definition at line 31 of file UseDefGenerator.cc.

32 : _tgraph{tgraph}, _node_to_idx{}, _training_usedefs{}
33{
34 const auto order = _tgraph.topolSortOperations();
35 for (const auto &index : order)
36 {
37 const auto &node = _tgraph.operation(index);
38 assert(_node_to_idx.find(&node) == _node_to_idx.end());
39 _node_to_idx[&node] = index;
40 }
41
42 // Check whether loss exists
43 assert(std::any_of(order.begin(), order.end(),
44 [&](const auto &index) {
45 return _tgraph.operation(index).opcode() == ir::OpCode::Loss;
46 }) &&
47 "Loss does not exist");
48}
const ITrainableOperation & operation(OperationIndex index) const
std::vector< ir::OperationIndex > topolSortOperations() const
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

References onert::ir::train::TrainableGraph::operation(), and onert::ir::train::TrainableGraph::topolSortOperations().

Member Function Documentation

◆ operator()()

UseDefChains onert::ir::train::UseDefGenerator::operator() ( )

Definition at line 50 of file UseDefGenerator.cc.

51{
52 const auto &graph = _tgraph.graph();
53 assert(ir::verifier::EdgeChecker().verify(graph));
54
55 _training_usedefs.clear();
56 graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
57 // Initialize as empty UseDefChain
58 const auto empty_usedef_chain = UseDefChain{operand};
59 _training_usedefs.emplace(TrainingOperandIndex{idx, true}, empty_usedef_chain);
60 _training_usedefs.emplace(TrainingOperandIndex{idx, false}, empty_usedef_chain);
61 });
62
63 initForForwardingNodes();
64
65 initForBackwardingNodes();
66
67 return _training_usedefs;
68}
TrainingIndex< OperandIndex > TrainingOperandIndex
Type that provides index of operand for training.
Definition Index.h:128
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:33

References onert::ir::train::TrainableGraph::graph().

◆ visit() [1/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::BinaryArithmetic &node )
override

Definition at line 70 of file UseDefGenerator.cc.

71{
72 assert(_node_to_idx.find(&node) != _node_to_idx.end());
73 const auto &op_index = _node_to_idx.at(&node);
74 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
75
76 // Insert uses of forwarding output
77 if (node.param().activation != ir::Activation::NONE)
78 {
79 const auto &out_index = node.getOutputs().at(0);
80 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
81 insertUse(out_forwarding_index, backwarding_op_index);
82 }
83
84 for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
85 {
86 // Insert use of forwarding inputs
87 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
88 insertUse(in_forwarding_index, backwarding_op_index);
89
90 // Set def of backwarding(backprop) inputs
91 const auto outgoing_index = TrainingOperandIndex{in_index, false};
92 insertBackPropDef(outgoing_index, backwarding_op_index);
93 }
94}
TrainingIndex< OperationIndex > TrainingOperationIndex
Type that provides index of operation node for training.
Definition Index.h:119

References onert::ir::operation::BinaryArithmetic::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::DUPLICATED, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::NONE, onert::ir::operation::BinaryArithmetic::param(), and onert::ir::UNDEFINED.

◆ visit() [2/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Conv2D &node )
override

Definition at line 96 of file UseDefGenerator.cc.

97{
98 assert(_node_to_idx.find(&node) != _node_to_idx.end());
99 const auto &op_index = _node_to_idx.at(&node);
100 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
101
102 // Insert use of forwarding inputs
103 const auto &in_index = node.getInputs().at(train::operation::Conv2D::Input::INPUT);
104 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
105 insertUse(in_forwarding_index, backwarding_op_index);
106
107 const auto &weights_index = node.getInputs().at(train::operation::Conv2D::Input::KERNEL);
108 const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
109 insertUse(weights_forwarding_index, backwarding_op_index);
110
111 // Insert use of forwarding output
112 if (node.param().activation != ir::Activation::NONE)
113 {
114 const auto &out_index = node.getOutputs().at(0);
115 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
116 insertUse(out_forwarding_index, backwarding_op_index);
117 }
118
119 // Set def of backwarding inputs
120 const auto outgoing_index = TrainingOperandIndex{in_index, false};
121 insertBackPropDef(outgoing_index, backwarding_op_index);
122
123 const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
124 insertDef(weights_gradient_index, backwarding_op_index);
125
126 const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS);
127 if (bias_index.valid())
128 {
129 const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
130 insertDef(bias_gradient_index, backwarding_op_index);
131 }
132}

References onert::ir::operation::Conv2D::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Conv2D::INPUT, onert::ir::operation::Conv2D::KERNEL, onert::ir::NONE, and onert::ir::operation::Conv2D::param().

◆ visit() [3/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::DepthwiseConv2D &node )
override

Definition at line 134 of file UseDefGenerator.cc.

135{
136 assert(_node_to_idx.find(&node) != _node_to_idx.end());
137 const auto &op_index = _node_to_idx.at(&node);
138 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
139
140 // Insert use of forwarding inputs
141 const auto &in_index = node.getInputs().at(train::operation::DepthwiseConv2D::Input::INPUT);
142 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
143 insertUse(in_forwarding_index, backwarding_op_index);
144
145 const auto &weights_index = node.getInputs().at(train::operation::DepthwiseConv2D::Input::KERNEL);
146 const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
147 insertUse(weights_forwarding_index, backwarding_op_index);
148
149 // Insert uses of forwarding output
150 if (node.param().activation != ir::Activation::NONE)
151 {
152 const auto &out_index = node.getOutputs().at(0);
153 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
154 insertUse(out_forwarding_index, backwarding_op_index);
155 }
156
157 // Set def of backwarding inputs
158 const auto outgoing_index = TrainingOperandIndex{in_index, false};
159 insertBackPropDef(outgoing_index, backwarding_op_index);
160
161 const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
162 insertDef(weights_gradient_index, backwarding_op_index);
163
164 const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS);
165 if (bias_index.valid())
166 {
167 const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
168 insertDef(bias_gradient_index, backwarding_op_index);
169 }
170}

References onert::ir::operation::DepthwiseConv2D::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthwiseConv2D::INPUT, onert::ir::operation::DepthwiseConv2D::KERNEL, onert::ir::NONE, and onert::ir::operation::DepthwiseConv2D::param().

◆ visit() [4/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::ElementwiseActivation &node )
override

Definition at line 172 of file UseDefGenerator.cc.

173{
174 if (node.param().op_type != operation::ElementwiseActivation::Type::RELU)
175 {
176 throw std::runtime_error{"UseDefGenerator: Not yet supported activation type"};
177 }
178 assert(_node_to_idx.find(&node) != _node_to_idx.end());
179 const auto &op_index = _node_to_idx.at(&node);
180 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
181
182 // Insert use of forwarding output
183 const auto &out_index = node.getOutputs().at(0);
184 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
185 insertUse(out_forwarding_index, backwarding_op_index);
186
187 // Set def of backwarding(backprop) inputs
188 for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
189 {
190 const auto outgoing_index = TrainingOperandIndex{in_index, false};
191 insertBackPropDef(outgoing_index, backwarding_op_index);
192 }
193}

References onert::ir::OperandIndexSequence::at(), onert::ir::DUPLICATED, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseActivation::Param::op_type, onert::ir::operation::ElementwiseActivation::param(), onert::ir::operation::ElementwiseActivation::RELU, and onert::ir::UNDEFINED.

◆ visit() [5/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::FullyConnected &node )
override

Definition at line 195 of file UseDefGenerator.cc.

196{
197 assert(_node_to_idx.find(&node) != _node_to_idx.end());
198 const auto &op_index = _node_to_idx.at(&node);
199 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
200
201 // Insert use of forwarding inputs
202 const auto &in_index = node.getInputs().at(train::operation::FullyConnected::Input::INPUT);
203 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
204 insertUse(in_forwarding_index, backwarding_op_index);
205
206 const auto &weights_index = node.getInputs().at(train::operation::FullyConnected::Input::WEIGHT);
207 const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
208 insertUse(weights_forwarding_index, backwarding_op_index);
209
210 // Insert uses of forwarding output
211 if (node.param().activation != ir::Activation::NONE)
212 {
213 const auto &out_index = node.getOutputs().at(0);
214 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
215 insertUse(out_forwarding_index, backwarding_op_index);
216 }
217
218 // Set def of backwarding inputs
219 const auto outgoing_index = TrainingOperandIndex{in_index, false};
220 insertBackPropDef(outgoing_index, backwarding_op_index);
221
222 const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
223 insertDef(weights_gradient_index, backwarding_op_index);
224
225 const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS);
226 if (bias_index.valid())
227 {
228 const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
229 insertDef(bias_gradient_index, backwarding_op_index);
230 }
231}

References onert::ir::operation::FullyConnected::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::FullyConnected::INPUT, onert::ir::NONE, onert::ir::operation::FullyConnected::param(), and onert::ir::operation::FullyConnected::WEIGHT.

◆ visit() [6/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Loss &node )
override

Definition at line 233 of file UseDefGenerator.cc.

234{
235 assert(_node_to_idx.find(&node) != _node_to_idx.end());
236 const auto &op_index = _node_to_idx.at(&node);
237 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
238
239 for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
240 {
241 // Insert use of forwarding inputs
242 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
243 insertUse(in_forwarding_index, backwarding_op_index);
244 }
245
246 // Set def of backwarding(backprop) y_pred
247 const auto &y_pred_index = node.getInputs().at(train::operation::Loss::Input::Y_PRED);
248 assert(!_tgraph.operands().at(y_pred_index).isConstant());
249 const auto y_pred_outgoing_index = TrainingOperandIndex{y_pred_index, false};
250 insertBackPropDef(y_pred_outgoing_index, backwarding_op_index);
251
252 // Set def of backwarding(backprop) y_true
253 const auto &y_true_index = node.getInputs().at(train::operation::Loss::Input::Y_TRUE);
254 assert(!_tgraph.operands().at(y_true_index).isConstant());
255 const auto y_true_outgoing_index = TrainingOperandIndex{y_true_index, false};
256 insertBackPropDef(y_true_outgoing_index, backwarding_op_index);
257
258 // Remove use of backwarding output
259 const auto &out_index = node.getOutputs().at(0);
260 const auto incoming_index = TrainingOperandIndex{out_index, false};
261 auto &usedef_chain = _training_usedefs.at(incoming_index);
262 usedef_chain.removeTrainingUse(backwarding_op_index);
263}
const Operands & operands() const override
const Object & at(const Index &index) const
Get the object that is associated with the given index.

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::at(), onert::ir::DUPLICATED, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::train::TrainableGraph::operands(), onert::ir::UNDEFINED, onert::ir::operation::Loss::Y_PRED, and onert::ir::operation::Loss::Y_TRUE.

◆ visit() [7/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Pad &node )
override

Definition at line 265 of file UseDefGenerator.cc.

266{
267 assert(_node_to_idx.find(&node) != _node_to_idx.end());
268 const auto &op_index = _node_to_idx.at(&node);
269 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
270
271 // Insert use of forwarding pad
272 const auto &pad_index = node.getInputs().at(train::operation::Pad::Input::PAD);
273 const auto pad_forwarding_index = TrainingOperandIndex{pad_index, true};
274 insertUse(pad_forwarding_index, backwarding_op_index);
275
276 // Insert use of backwarding(backprop) output
277 const auto &out_index = node.getOutputs().at(0);
278 const auto incoming_index = TrainingOperandIndex{out_index, false};
279 insertUse(incoming_index, backwarding_op_index);
280
281 // Set def of backwarding(backprop) input
282 const auto &in_index = node.getInputs().at(train::operation::Pad::Input::INPUT);
283 const auto outgoing_index = TrainingOperandIndex{in_index, false};
284 insertBackPropDef(outgoing_index, backwarding_op_index);
285}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pad::INPUT, and onert::ir::operation::Pad::PAD.

◆ visit() [8/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Pool2D &node )
override

Definition at line 287 of file UseDefGenerator.cc.

288{
289 if (node.param().op_type != ir::operation::Pool2D::PoolType::MAX &&
290 node.param().op_type != ir::operation::Pool2D::PoolType::AVG)
291 {
292 throw std::runtime_error{"UseDefGenerator: Not yet supported pool type"};
293 }
294
295 assert(_node_to_idx.find(&node) != _node_to_idx.end());
296 const auto &op_index = _node_to_idx.at(&node);
297 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
298
299 // Insert uses of forwarding output
300 if (node.param().activation != ir::Activation::NONE)
301 {
302 const auto &out_index = node.getOutputs().at(0);
303 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
304 insertUse(out_forwarding_index, backwarding_op_index);
305 }
306
307 // Insert use of backwarding(backprop) output
308 const auto &out_index = node.getOutputs().at(0);
309 const auto incoming_index = TrainingOperandIndex{out_index, false};
310 insertUse(incoming_index, backwarding_op_index);
311
312 // Set def of backwarding(backprop) input
313 const auto &in_index = node.getInputs().at(train::operation::Pool2D::Input::INPUT);
314 const auto outgoing_index = TrainingOperandIndex{in_index, false};
315 insertBackPropDef(outgoing_index, backwarding_op_index);
316}

References onert::ir::operation::Pool2D::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Pool2D::AVG, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pool2D::INPUT, onert::ir::operation::Pool2D::MAX, onert::ir::NONE, onert::ir::operation::Pool2D::Param::op_type, and onert::ir::operation::Pool2D::param().

◆ visit() [9/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Reduce &node )
override

Definition at line 318 of file UseDefGenerator.cc.

319{
320 if (node.param().reduce_type != ir::operation::Reduce::ReduceType::MEAN)
321 {
322 throw std::runtime_error{"UseDefGenerator: Not yet supported reduce type"};
323 }
324
325 assert(_node_to_idx.find(&node) != _node_to_idx.end());
326 const auto &op_index = _node_to_idx.at(&node);
327 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
328
329 // Insert use of backwarding(backprop) output
330 const auto &out_index = node.getOutputs().at(0);
331 const auto incoming_index = TrainingOperandIndex{out_index, false};
332 insertUse(incoming_index, backwarding_op_index);
333
334 // Set def of backwarding(backprop) input
335 const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT);
336 const auto outgoing_index = TrainingOperandIndex{in_index, false};
337 insertBackPropDef(outgoing_index, backwarding_op_index);
338}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reduce::INPUT, onert::ir::operation::Reduce::MEAN, onert::ir::operation::Reduce::param(), and onert::ir::operation::Reduce::Param::reduce_type.

◆ visit() [10/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Reshape &node )
override

Definition at line 340 of file UseDefGenerator.cc.

341{
342 assert(_node_to_idx.find(&node) != _node_to_idx.end());
343 const auto &op_index = _node_to_idx.at(&node);
344 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
345
346 // Insert use of backwarding(backprop) output
347 const auto &out_index = node.getOutputs().at(0);
348 const auto incoming_index = TrainingOperandIndex{out_index, false};
349 insertUse(incoming_index, backwarding_op_index);
350
351 // Set def of backwarding(backprop) input
352 const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT);
353 const auto outgoing_index = TrainingOperandIndex{in_index, false};
354 insertBackPropDef(outgoing_index, backwarding_op_index);
355}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), and onert::ir::operation::Reduce::INPUT.

◆ visit() [11/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Softmax &node )
override

Definition at line 357 of file UseDefGenerator.cc.

358{
359 assert(_node_to_idx.find(&node) != _node_to_idx.end());
360 const auto &op_index = _node_to_idx.at(&node);
361 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
362
363 // Insert uses of forwarding output
364 const auto &out_index = node.getOutputs().at(0);
365 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
366 insertUse(out_forwarding_index, backwarding_op_index);
367
368 // Insert use of backwarding(backprop) output
369 const auto incoming_index = TrainingOperandIndex{out_index, false};
370 insertUse(incoming_index, backwarding_op_index);
371
372 // Set def of backwarding(backprop) input
373 const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT);
374 const auto outgoing_index = TrainingOperandIndex{in_index, false};
375 insertBackPropDef(outgoing_index, backwarding_op_index);
376}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), and onert::ir::operation::Reduce::INPUT.


The documentation for this class was generated from the following files: