ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::compiler::train::pass::LossInsertionPass Class Reference

#include <LossInsertionPass.h>

Collaboration diagram for onert::compiler::train::pass::LossInsertionPass:

Public Member Functions

 LossInsertionPass (ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info, const ir::SubgraphIndex &subg_index)
 
std::string id () final
 
void run () final
 
- Public Member Functions inherited from onert::compiler::train::pass::Pass
 Pass (ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info)
 
virtual ~Pass ()=default
 
- Public Member Functions inherited from onert::compiler::pass::IPass
virtual ~IPass ()=default
 

Additional Inherited Members

- Protected Attributes inherited from onert::compiler::train::pass::Pass
ir::train::TrainableGraph & _trainable_graph
 
const ir::train::TrainingInfo * _training_info
 

Detailed Description

Definition at line 33 of file LossInsertionPass.h.

Constructor & Destructor Documentation

◆ LossInsertionPass()

onert::compiler::train::pass::LossInsertionPass::LossInsertionPass ( ir::train::TrainableGraph & trainable_graph,
const ir::train::TrainingInfo * training_info,
const ir::SubgraphIndex & subg_index
)
inline

Definition at line 36 of file LossInsertionPass.h.

39 : Pass{trainable_graph, training_info}, _subg_index{subg_index}
40 {
41 }
Pass(ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info)
Definition Pass.h:46

Member Function Documentation

◆ id()

std::string onert::compiler::train::pass::LossInsertionPass::id ( )
inline final virtual

Implements onert::compiler::pass::IPass.

Definition at line 44 of file LossInsertionPass.h.

44{ return "LossInsertionPass"; }

◆ run()

void onert::compiler::train::pass::LossInsertionPass::run ( )
final virtual

Implements onert::compiler::pass::IPass.

Definition at line 32 of file LossInsertionPass.cc.

33{
34 const auto &loss_info = _training_info->lossInfo();
35
36 if (_trainable_graph.getOutputs().size() != 1)
37 {
38 throw std::runtime_error("LossInsertionPass: Not supported multiple outputs");
39 }
40
41 // TODO Consider SparseCategoricalCrossentropy y_true shape
42 // SparseCategoricalCrossentropy loss has a different y_true shape than y_pred.
43
44 // TODO Implement Loop [0, getOutputs().size())
45 // index: a loop index
46 const auto index = 0;
47 const auto &y_pred_index = _trainable_graph.getOutputs().at(index);
48 const auto &y_pred = _trainable_graph.operands().at(y_pred_index);
49 auto y_true_index = _trainable_graph.addOperand(y_pred.shape(), y_pred.typeInfo());
50 ir::OperandIndexSequence inputs{y_pred_index, y_true_index};
51
52 ir::Shape output_shape;
53 if (loss_info.reduction_type == ir::train::LossReductionType::Sum ||
54 loss_info.reduction_type == ir::train::LossReductionType::SumOverBatchSize)
55 {
56 output_shape = ir::Shape{1};
57 }
58 else
59 {
60 throw std::runtime_error("LossInsertionPass: Not supported reduction type");
61 }
62
63 const ir::TypeInfo float_op(ir::DataType::FLOAT32);
64 auto output_index = _trainable_graph.addOperand(output_shape, float_op);
65 ir::OperandIndexSequence outputs{output_index};
66
67 // The y_pred node information may be required in some loss layers (e.g.,
68 // CategoricalCrossEntropy(SoftmaxCrossEntropy));
69 const auto &y_pred_node = _trainable_graph.operations().at(y_pred.getDef());
70 const auto y_pred_op_code = y_pred_node.opcode();
71
72 auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs);
73 auto trainable_loss_op =
74 std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info, y_pred_op_code);
75 trainable_loss_op->enableBackward();
76
77 _trainable_graph.addOperation(std::move(trainable_loss_op));
78
79 _trainable_graph.addInput(y_true_index);
80
81 // TODO Add loss as many as output size
82 _trainable_graph.addLoss(output_index, ir::IOIndex{index});
83}
ir::train::TrainableGraph & _trainable_graph
Definition Pass.h:53
const ir::train::TrainingInfo * _training_info
Definition Pass.h:54
const OperandIndex & at(IOIndex set_index) const
const OperandIndexSequence & getOutputs() const override
OperationIndex addOperation(std::unique_ptr< ITrainableOperation > &&operation)
Add a new trainable operation to the graph.
OperandIndex addOperand(const Shape &shape, const TypeInfo &type)
const Operations & operations() const override
const Operands & operands() const override
void addLoss(const OperandIndex &loss_ind, const IOIndex &pred_io_ind)
void addInput(const OperandIndex &ind, const std::string &name="")
const LossInfo & lossInfo() const
const Object & at(const Index &index) const
Get the object that is associated with the given index.
const luci_interpreter::RuntimeShape output_shape
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
::onert::util::Index< uint32_t, IOIndexTag > IOIndex
Definition Index.h:38

References onert::compiler::train::pass::Pass::_trainable_graph, onert::compiler::train::pass::Pass::_training_info, onert::ir::train::TrainableGraph::addInput(), onert::ir::train::TrainableGraph::addLoss(), onert::ir::train::TrainableGraph::addOperand(), onert::ir::train::TrainableGraph::addOperation(), onert::util::ObjectManager< Index, Object >::at(), onert::ir::OperandIndexSequence::at(), onert::ir::train::TrainableGraph::getOutputs(), onert::ir::train::TrainingInfo::lossInfo(), onert::ir::train::TrainableGraph::operands(), onert::ir::train::TrainableGraph::operations(), output_shape, onert::ir::OperandIndexSequence::size(), onert::ir::train::Sum, and onert::ir::train::SumOverBatchSize.

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: