ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::compiler::train::pass::LossInsertionPass Class Reference

#include <LossInsertionPass.h>

Collaboration diagram for onert::compiler::train::pass::LossInsertionPass:

Public Member Functions

 LossInsertionPass (ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info, const ir::SubgraphIndex &subg_index)
 
std::string id () final
 
void run () final
 
- Public Member Functions inherited from onert::compiler::train::pass::Pass
 Pass (ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info)
 
virtual ~Pass ()=default
 
- Public Member Functions inherited from onert::compiler::pass::IPass
virtual ~IPass ()=default
 

Additional Inherited Members

- Protected Attributes inherited from onert::compiler::train::pass::Pass
ir::train::TrainableGraph & _trainable_graph
 
const ir::train::TrainingInfo * _training_info
 

Detailed Description

Definition at line 27 of file LossInsertionPass.h.

Constructor & Destructor Documentation

◆ LossInsertionPass()

onert::compiler::train::pass::LossInsertionPass::LossInsertionPass ( ir::train::TrainableGraph & trainable_graph,
const ir::train::TrainingInfo * training_info,
const ir::SubgraphIndex & subg_index
)
inline

Definition at line 30 of file LossInsertionPass.h.

33 : Pass{trainable_graph, training_info}, _subg_index{subg_index}
34 {
35 }
Pass(ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info)
Definition Pass.h:34

Member Function Documentation

◆ id()

std::string onert::compiler::train::pass::LossInsertionPass::id ( )
inline final virtual

Implements onert::compiler::pass::IPass.

Definition at line 38 of file LossInsertionPass.h.

38{ return "LossInsertionPass"; }

◆ run()

void onert::compiler::train::pass::LossInsertionPass::run ( )
final virtual

Implements onert::compiler::pass::IPass.

Definition at line 26 of file LossInsertionPass.cc.

27{
28 const auto &loss_info = _training_info->lossInfo();
29
30 if (_trainable_graph.getOutputs().size() != 1)
31 {
32 throw std::runtime_error("LossInsertionPass: Not supported multiple outputs");
33 }
34
35 // TODO Consider SparseCategoricalCrossentropy y_true shape
36 // SparseCategoricalCrossentropy loss has a different y_true shape than y_pred.
37
38 // TODO Implement Loop [0, getOutputs().size())
39 // index: a loop index
40 const auto index = 0;
41 const auto &y_pred_index = _trainable_graph.getOutputs().at(index);
42 const auto &y_pred = _trainable_graph.operands().at(y_pred_index);
43 auto y_true_index = _trainable_graph.addOperand(y_pred.shape(), y_pred.typeInfo());
44 ir::OperandIndexSequence inputs{y_pred_index, y_true_index};
45
46 ir::Shape output_shape;
47 if (loss_info.reduction_type == ir::train::LossReductionType::Sum ||
48 loss_info.reduction_type == ir::train::LossReductionType::SumOverBatchSize)
49 {
50 output_shape = ir::Shape{1};
51 }
52 else
53 {
54 throw std::runtime_error("LossInsertionPass: Not supported reduction type");
55 }
56
57 const ir::TypeInfo float_op(ir::DataType::FLOAT32);
58 auto output_index = _trainable_graph.addOperand(output_shape, float_op);
59 ir::OperandIndexSequence outputs{output_index};
60
61 // The y_pred node information may be required in some loss layers (e.g.,
62 // CategoricalCrossEntropy(SoftmaxCrossEntropy));
63 const auto &y_pred_node = _trainable_graph.operations().at(y_pred.getDef());
64 const auto y_pred_op_code = y_pred_node.opcode();
65
66 auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs);
67 auto trainable_loss_op =
68 std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info, y_pred_op_code);
69 trainable_loss_op->enableBackward();
70
71 _trainable_graph.addOperation(std::move(trainable_loss_op));
72
73 _trainable_graph.addInput(y_true_index);
74
75 // TODO Add loss as many as output size
76 _trainable_graph.addLoss(output_index, ir::IOIndex{index});
77}
ir::train::TrainableGraph & _trainable_graph
Definition Pass.h:41
const ir::train::TrainingInfo * _training_info
Definition Pass.h:42
const OperandIndex & at(IOIndex set_index) const
const OperandIndexSequence & getOutputs() const override
OperationIndex addOperation(std::unique_ptr< ITrainableOperation > &&operation)
Add a new trainable operation to the graph.
OperandIndex addOperand(const Shape &shape, const TypeInfo &type)
const Operations & operations() const override
const Operands & operands() const override
void addLoss(const OperandIndex &loss_ind, const IOIndex &pred_io_ind)
void addInput(const OperandIndex &ind, const std::string &name="")
const LossInfo & lossInfo() const
const Object & at(const Index &index) const
Get the object that is associated with the given index.
const luci_interpreter::RuntimeShape output_shape
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
::onert::util::Index< uint32_t, IOIndexTag > IOIndex
Definition Index.h:36

References onert::compiler::train::pass::Pass::_trainable_graph, onert::compiler::train::pass::Pass::_training_info, onert::ir::train::TrainableGraph::addInput(), onert::ir::train::TrainableGraph::addLoss(), onert::ir::train::TrainableGraph::addOperand(), onert::ir::train::TrainableGraph::addOperation(), onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::at(), onert::ir::train::TrainableGraph::getOutputs(), onert::ir::train::TrainingInfo::lossInfo(), onert::ir::train::TrainableGraph::operands(), onert::ir::train::TrainableGraph::operations(), output_shape, onert::ir::OperandIndexSequence::size(), onert::ir::train::Sum, and onert::ir::train::SumOverBatchSize.


The documentation for this class was generated from the following files: