ONE - On-device Neural Engine
onert::ir::train::UseDefGenerator Class Reference

#include <UseDefGenerator.h>

Collaboration diagram for onert::ir::train::UseDefGenerator:

Public Member Functions

 UseDefGenerator (void)=delete
 
 UseDefGenerator (const TrainableGraph &tgraph)
 
UseDefChains operator() ()
 
void visit (const train::operation::BinaryArithmetic &node) override
 
void visit (const train::operation::Conv2D &node) override
 
void visit (const train::operation::DepthwiseConv2D &node) override
 
void visit (const train::operation::ElementwiseActivation &node) override
 
void visit (const train::operation::FullyConnected &node) override
 
void visit (const train::operation::Loss &node) override
 
void visit (const train::operation::Pad &node) override
 
void visit (const train::operation::Pool2D &node) override
 
void visit (const train::operation::Reduce &node) override
 
void visit (const train::operation::Reshape &node) override
 
void visit (const train::operation::Softmax &node) override
 
- Public Member Functions inherited from onert::ir::train::UseDefGeneratorBase
virtual ~UseDefGeneratorBase ()=default
 
- Public Member Functions inherited from onert::ir::train::TrainableOperationVisitor
virtual ~TrainableOperationVisitor ()=default
 

Detailed Description

Definition at line 57 of file UseDefGenerator.h.
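UseDefGenerator is a TrainableOperationVisitor that walks a TrainableGraph and builds the training-time use-def chains for every operand, covering both the forwarding (inference) and backwarding (backprop) views. A minimal caller sketch is shown below, assuming a fully built TrainableGraph that already contains a Loss operation (the constructor asserts this); the helper name buildTrainingUseDefs is hypothetical.

#include "UseDefGenerator.h"

using namespace onert::ir::train;

// Hypothetical helper: builds the training-time use-def chains for a graph.
// Assumes `tgraph` is fully constructed and contains a Loss operation,
// since the UseDefGenerator constructor asserts that a Loss node exists.
UseDefChains buildTrainingUseDefs(const TrainableGraph &tgraph)
{
  UseDefGenerator generator{tgraph};
  return generator(); // operator() visits every operation and returns the chains
}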

Constructor & Destructor Documentation

◆ UseDefGenerator() [1/2]

onert::ir::train::UseDefGenerator::UseDefGenerator ( void  )
delete

◆ UseDefGenerator() [2/2]

onert::ir::train::UseDefGenerator::UseDefGenerator ( const TrainableGraph & tgraph)

Definition at line 35 of file UseDefGenerator.cc.

36 : _tgraph{tgraph}, _node_to_idx{}, _training_usedefs{}
37{
38 const auto order = _tgraph.topolSortOperations();
39 for (const auto &index : order)
40 {
41 const auto &node = _tgraph.operation(index);
42 assert(_node_to_idx.find(&node) == _node_to_idx.end());
43 _node_to_idx[&node] = index;
44 }
45
46 // Check whether loss exists
47 assert(std::any_of(order.begin(), order.end(),
48 [&](const auto &index) {
49 return _tgraph.operation(index).opcode() == ir::OpCode::Loss;
50 }) &&
51 "Loss does not exist");
52}

References onert::ir::train::TrainableGraph::operation(), and onert::ir::train::TrainableGraph::topolSortOperations().

Member Function Documentation

◆ operator()()

UseDefChains onert::ir::train::UseDefGenerator::operator() ( )

Definition at line 54 of file UseDefGenerator.cc.

55{
56 const auto &graph = _tgraph.graph();
57 assert(ir::verifier::EdgeChecker().verify(graph));
58
59 _training_usedefs.clear();
60 graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
 61 // Initialize as empty UseDefChain
62 const auto empty_usedef_chain = UseDefChain{operand};
63 _training_usedefs.emplace(TrainingOperandIndex{idx, true}, empty_usedef_chain);
64 _training_usedefs.emplace(TrainingOperandIndex{idx, false}, empty_usedef_chain);
65 });
66
67 initForForwardingNodes();
68
69 initForBackwardingNodes();
70
71 return _training_usedefs;
72}

References onert::ir::train::TrainableGraph::graph().
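The returned UseDefChains holds one UseDefChain per training operand, keyed by TrainingOperandIndex; as the emplace calls above show, each operand index appears twice, once with the forwarding flag set to true and once with it set to false. A hedged lookup sketch follows, assuming UseDefChains is a map-like container with an at() accessor (consistent with how _training_usedefs is used in visit(Loss) below); the function name inspectOperand is hypothetical.

// Hedged sketch: looks up both training-time views of a single operand.
void inspectOperand(const onert::ir::train::UseDefChains &usedefs,
                    const onert::ir::OperandIndex &index)
{
  using onert::ir::train::TrainingOperandIndex;

  // Forwarding view: uses/defs of the operand during the forward pass.
  const auto &forwarding = usedefs.at(TrainingOperandIndex{index, true});

  // Backwarding view: uses/defs of the operand's backprop tensor.
  const auto &backwarding = usedefs.at(TrainingOperandIndex{index, false});

  (void)forwarding;
  (void)backwarding;
}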

◆ visit() [1/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::BinaryArithmetic & node)
override

Definition at line 74 of file UseDefGenerator.cc.

75{
76 assert(_node_to_idx.find(&node) != _node_to_idx.end());
77 const auto &op_index = _node_to_idx.at(&node);
78 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
79
80 // Insert uses of forwarding output
81 if (node.param().activation != ir::Activation::NONE)
82 {
83 const auto &out_index = node.getOutputs().at(0);
84 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
85 insertUse(out_forwarding_index, backwarding_op_index);
86 }
87
88 for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
89 {
90 // Insert use of forwarding inputs
91 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
92 insertUse(in_forwarding_index, backwarding_op_index);
93
94 // Set def of backwarding(backprop) inputs
95 const auto outgoing_index = TrainingOperandIndex{in_index, false};
96 insertBackPropDef(outgoing_index, backwarding_op_index);
97 }
98}

References onert::ir::operation::BinaryArithmetic::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::DUPLICATED, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::NONE, onert::ir::operation::BinaryArithmetic::param(), and onert::ir::UNDEFINED.

◆ visit() [2/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Conv2D & node)
override

Definition at line 100 of file UseDefGenerator.cc.

101{
102 assert(_node_to_idx.find(&node) != _node_to_idx.end());
103 const auto &op_index = _node_to_idx.at(&node);
104 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
105
106 // Insert use of forwarding inputs
107 const auto &in_index = node.getInputs().at(train::operation::Conv2D::Input::INPUT);
108 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
109 insertUse(in_forwarding_index, backwarding_op_index);
110
111 const auto &weights_index = node.getInputs().at(train::operation::Conv2D::Input::KERNEL);
112 const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
113 insertUse(weights_forwarding_index, backwarding_op_index);
114
115 // Insert use of forwarding output
116 if (node.param().activation != ir::Activation::NONE)
117 {
118 const auto &out_index = node.getOutputs().at(0);
119 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
120 insertUse(out_forwarding_index, backwarding_op_index);
121 }
122
123 // Set def of backwarding inputs
124 const auto outgoing_index = TrainingOperandIndex{in_index, false};
125 insertBackPropDef(outgoing_index, backwarding_op_index);
126
127 const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
128 insertDef(weights_gradient_index, backwarding_op_index);
129
130 const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS);
131 if (bias_index.valid())
132 {
133 const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
134 insertDef(bias_gradient_index, backwarding_op_index);
135 }
136}

References onert::ir::operation::Conv2D::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Conv2D::INPUT, onert::ir::operation::Conv2D::KERNEL, onert::ir::NONE, and onert::ir::operation::Conv2D::param().

◆ visit() [3/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::DepthwiseConv2D & node)
override

Definition at line 138 of file UseDefGenerator.cc.

139{
140 assert(_node_to_idx.find(&node) != _node_to_idx.end());
141 const auto &op_index = _node_to_idx.at(&node);
142 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
143
144 // Insert use of forwarding inputs
145 const auto &in_index = node.getInputs().at(train::operation::DepthwiseConv2D::Input::INPUT);
146 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
147 insertUse(in_forwarding_index, backwarding_op_index);
148
149 const auto &weights_index = node.getInputs().at(train::operation::DepthwiseConv2D::Input::KERNEL);
150 const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
151 insertUse(weights_forwarding_index, backwarding_op_index);
152
153 // Insert uses of forwarding output
154 if (node.param().activation != ir::Activation::NONE)
155 {
156 const auto &out_index = node.getOutputs().at(0);
157 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
158 insertUse(out_forwarding_index, backwarding_op_index);
159 }
160
161 // Set def of backwarding inputs
162 const auto outgoing_index = TrainingOperandIndex{in_index, false};
163 insertBackPropDef(outgoing_index, backwarding_op_index);
164
165 const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
166 insertDef(weights_gradient_index, backwarding_op_index);
167
168 const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS);
169 if (bias_index.valid())
170 {
171 const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
172 insertDef(bias_gradient_index, backwarding_op_index);
173 }
174}

References onert::ir::operation::DepthwiseConv2D::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::DepthwiseConv2D::INPUT, onert::ir::operation::DepthwiseConv2D::KERNEL, onert::ir::NONE, and onert::ir::operation::DepthwiseConv2D::param().

◆ visit() [4/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::ElementwiseActivation & node)
override

Definition at line 176 of file UseDefGenerator.cc.

177{
178 if (node.param().op_type != operation::ElementwiseActivation::Type::RELU)
179 {
180 throw std::runtime_error{"UseDefGenerator: Not yet supported activation type"};
181 }
182 assert(_node_to_idx.find(&node) != _node_to_idx.end());
183 const auto &op_index = _node_to_idx.at(&node);
184 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
185
186 // Insert use of forwarding output
187 const auto &out_index = node.getOutputs().at(0);
188 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
189 insertUse(out_forwarding_index, backwarding_op_index);
190
191 // Set def of backwarding(backprop) inputs
192 for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
193 {
194 const auto outgoing_index = TrainingOperandIndex{in_index, false};
195 insertBackPropDef(outgoing_index, backwarding_op_index);
196 }
197}

References onert::ir::OperandIndexSequence::at(), onert::ir::DUPLICATED, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::ElementwiseActivation::Param::op_type, onert::ir::operation::ElementwiseActivation::param(), onert::ir::operation::ElementwiseActivation::RELU, and onert::ir::UNDEFINED.

◆ visit() [5/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::FullyConnected & node)
override

Definition at line 199 of file UseDefGenerator.cc.

200{
201 assert(_node_to_idx.find(&node) != _node_to_idx.end());
202 const auto &op_index = _node_to_idx.at(&node);
203 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
204
205 // Insert use of forwarding inputs
206 const auto &in_index = node.getInputs().at(train::operation::FullyConnected::Input::INPUT);
207 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
208 insertUse(in_forwarding_index, backwarding_op_index);
209
210 const auto &weights_index = node.getInputs().at(train::operation::FullyConnected::Input::WEIGHT);
211 const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
212 insertUse(weights_forwarding_index, backwarding_op_index);
213
214 // Insert uses of forwarding output
215 if (node.param().activation != ir::Activation::NONE)
216 {
217 const auto &out_index = node.getOutputs().at(0);
218 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
219 insertUse(out_forwarding_index, backwarding_op_index);
220 }
221
222 // Set def of backwarding inputs
223 const auto outgoing_index = TrainingOperandIndex{in_index, false};
224 insertBackPropDef(outgoing_index, backwarding_op_index);
225
226 const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
227 insertDef(weights_gradient_index, backwarding_op_index);
228
229 const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS);
230 if (bias_index.valid())
231 {
232 const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
233 insertDef(bias_gradient_index, backwarding_op_index);
234 }
235}

References onert::ir::operation::FullyConnected::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Conv2D::BIAS, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::FullyConnected::INPUT, onert::ir::NONE, onert::ir::operation::FullyConnected::param(), and onert::ir::operation::FullyConnected::WEIGHT.

◆ visit() [6/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Loss & node)
override

Definition at line 237 of file UseDefGenerator.cc.

238{
239 assert(_node_to_idx.find(&node) != _node_to_idx.end());
240 const auto &op_index = _node_to_idx.at(&node);
241 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
242
243 for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
244 {
245 // Insert use of forwarding inputs
246 const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
247 insertUse(in_forwarding_index, backwarding_op_index);
248 }
249
250 // Set def of backwarding(backprop) y_pred
251 const auto &y_pred_index = node.getInputs().at(train::operation::Loss::Input::Y_PRED);
252 assert(!_tgraph.operands().at(y_pred_index).isConstant());
253 const auto y_pred_outgoing_index = TrainingOperandIndex{y_pred_index, false};
254 insertBackPropDef(y_pred_outgoing_index, backwarding_op_index);
255
256 // Set def of backwarding(backprop) y_true
257 const auto &y_true_index = node.getInputs().at(train::operation::Loss::Input::Y_TRUE);
258 assert(!_tgraph.operands().at(y_true_index).isConstant());
259 const auto y_true_outgoing_index = TrainingOperandIndex{y_true_index, false};
260 insertBackPropDef(y_true_outgoing_index, backwarding_op_index);
261
262 // Remove use of backwarding output
263 const auto &out_index = node.getOutputs().at(0);
264 const auto incoming_index = TrainingOperandIndex{out_index, false};
265 auto &usedef_chain = _training_usedefs.at(incoming_index);
266 usedef_chain.removeTrainingUse(backwarding_op_index);
267}

References onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::at(), onert::ir::DUPLICATED, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::train::TrainableGraph::operands(), onert::ir::UNDEFINED, onert::ir::operation::Loss::Y_PRED, and onert::ir::operation::Loss::Y_TRUE.

◆ visit() [7/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Pad & node)
override

Definition at line 269 of file UseDefGenerator.cc.

270{
271 assert(_node_to_idx.find(&node) != _node_to_idx.end());
272 const auto &op_index = _node_to_idx.at(&node);
273 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
274
275 // Insert use of forwarding pad
276 const auto &pad_index = node.getInputs().at(train::operation::Pad::Input::PAD);
277 const auto pad_forwarding_index = TrainingOperandIndex{pad_index, true};
278 insertUse(pad_forwarding_index, backwarding_op_index);
279
280 // Insert use of backwarding(backprop) output
281 const auto &out_index = node.getOutputs().at(0);
282 const auto incoming_index = TrainingOperandIndex{out_index, false};
283 insertUse(incoming_index, backwarding_op_index);
284
285 // Set def of backwarding(backprop) input
286 const auto &in_index = node.getInputs().at(train::operation::Pad::Input::INPUT);
287 const auto outgoing_index = TrainingOperandIndex{in_index, false};
288 insertBackPropDef(outgoing_index, backwarding_op_index);
289}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pad::INPUT, and onert::ir::operation::Pad::PAD.

◆ visit() [8/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Pool2D & node)
override

Definition at line 291 of file UseDefGenerator.cc.

292{
293 if (node.param().op_type != ir::operation::Pool2D::PoolType::MAX &&
294 node.param().op_type != ir::operation::Pool2D::PoolType::AVG)
295 {
296 throw std::runtime_error{"UseDefGenerator: Not yet supported pool type"};
297 }
298
299 assert(_node_to_idx.find(&node) != _node_to_idx.end());
300 const auto &op_index = _node_to_idx.at(&node);
301 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
302
303 // Insert uses of forwarding output
304 if (node.param().activation != ir::Activation::NONE)
305 {
306 const auto &out_index = node.getOutputs().at(0);
307 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
308 insertUse(out_forwarding_index, backwarding_op_index);
309 }
310
311 // Insert use of backwarding(backprop) output
312 const auto &out_index = node.getOutputs().at(0);
313 const auto incoming_index = TrainingOperandIndex{out_index, false};
314 insertUse(incoming_index, backwarding_op_index);
315
316 // Set def of backwarding(backprop) input
317 const auto &in_index = node.getInputs().at(train::operation::Pool2D::Input::INPUT);
318 const auto outgoing_index = TrainingOperandIndex{in_index, false};
319 insertBackPropDef(outgoing_index, backwarding_op_index);
320}

References onert::ir::operation::Pool2D::Param::activation, onert::ir::OperandIndexSequence::at(), onert::ir::operation::Pool2D::AVG, onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Pool2D::INPUT, onert::ir::operation::Pool2D::MAX, onert::ir::NONE, onert::ir::operation::Pool2D::Param::op_type, and onert::ir::operation::Pool2D::param().

◆ visit() [9/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Reduce & node)
override

Definition at line 322 of file UseDefGenerator.cc.

323{
324 if (node.param().reduce_type != ir::operation::Reduce::ReduceType::MEAN)
325 {
326 throw std::runtime_error{"UseDefGenerator: Not yet supported reduce type"};
327 }
328
329 assert(_node_to_idx.find(&node) != _node_to_idx.end());
330 const auto &op_index = _node_to_idx.at(&node);
331 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
332
333 // Insert use of backwarding(backprop) output
334 const auto &out_index = node.getOutputs().at(0);
335 const auto incoming_index = TrainingOperandIndex{out_index, false};
336 insertUse(incoming_index, backwarding_op_index);
337
338 // Set def of backwarding(backprop) input
339 const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT);
340 const auto outgoing_index = TrainingOperandIndex{in_index, false};
341 insertBackPropDef(outgoing_index, backwarding_op_index);
342}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), onert::ir::operation::Reduce::INPUT, onert::ir::operation::Reduce::MEAN, onert::ir::operation::Reduce::param(), and onert::ir::operation::Reduce::Param::reduce_type.

◆ visit() [10/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Reshape & node)
override

Definition at line 344 of file UseDefGenerator.cc.

345{
346 assert(_node_to_idx.find(&node) != _node_to_idx.end());
347 const auto &op_index = _node_to_idx.at(&node);
348 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
349
350 // Insert use of backwarding(backprop) output
351 const auto &out_index = node.getOutputs().at(0);
352 const auto incoming_index = TrainingOperandIndex{out_index, false};
353 insertUse(incoming_index, backwarding_op_index);
354
355 // Set def of backwarding(backprop) input
356 const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT);
357 const auto outgoing_index = TrainingOperandIndex{in_index, false};
358 insertBackPropDef(outgoing_index, backwarding_op_index);
359}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), and onert::ir::operation::Reduce::INPUT.

◆ visit() [11/11]

void onert::ir::train::UseDefGenerator::visit ( const train::operation::Softmax & node)
override

Definition at line 361 of file UseDefGenerator.cc.

362{
363 assert(_node_to_idx.find(&node) != _node_to_idx.end());
364 const auto &op_index = _node_to_idx.at(&node);
365 const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
366
367 // Insert uses of forwarding output
368 const auto &out_index = node.getOutputs().at(0);
369 const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
370 insertUse(out_forwarding_index, backwarding_op_index);
371
372 // Insert use of backwarding(backprop) output
373 const auto incoming_index = TrainingOperandIndex{out_index, false};
374 insertUse(incoming_index, backwarding_op_index);
375
376 // Set def of backwarding(backprop) input
377 const auto &in_index = node.getInputs().at(train::operation::Reduce::Input::INPUT);
378 const auto outgoing_index = TrainingOperandIndex{in_index, false};
379 insertBackPropDef(outgoing_index, backwarding_op_index);
380}

References onert::ir::OperandIndexSequence::at(), onert::ir::Operation::getInputs(), onert::ir::Operation::getOutputs(), and onert::ir::operation::Reduce::INPUT.


The documentation for this class was generated from the following files:
UseDefGenerator.h
UseDefGenerator.cc