ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::ir::train::operation::Loss Class Reference

#include <Loss.h>

Collaboration diagram for onert::ir::train::operation::Loss:

Public Member Functions

 Loss (const OperationType &operation, const LossInfo &info, ir::OpCode y_pred_op_code)
 
std::unique_ptr< ITrainableOperation > clone () const override
 
void accept (OperationVisitor &v) const override
 
void accept (TrainableOperationVisitor &v) const override
 
bool hasTrainableParameter () const override
 
std::string name () const override
 
const LossInfo & param () const
 
ir::OpCode y_pred_op_code () const
 
- Public Member Functions inherited from onert::ir::operation::Loss
 Loss (const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
 
OpCode opcode () const final
 
- Public Member Functions inherited from onert::ir::Operation
 Operation (OperandConstraint input_constr, const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, OperandConstraint output_constr=OperandConstraint::createAny())
 
 Operation (OperandConstraint input_constr, OperandConstraint output_constr=OperandConstraint::createAny())
 
 Operation (const Operation &)=default
 
 Operation (Operation &&)=default
 
Operation & operator= (const Operation &)=default
 
Operation & operator= (Operation &&)=default
 
virtual ~Operation ()
 
void replaceInputs (const OperandIndex &from, const OperandIndex &to) override
 
void replaceOutputs (const OperandIndex &from, const OperandIndex &to) override
 
OperandIndexSequence & getInputs ()
 
const OperandIndexSequence & getInputs () const override
 
const OperandIndexSequence & getOutputs () const override
 
void setInputs (const OperandIndexSequence &indexes)
 
void setOutputs (const OperandIndexSequence &indexes)
 
- Public Member Functions inherited from onert::ir::IOperation
virtual ~IOperation ()=default
 
- Public Member Functions inherited from onert::ir::train::TrainableOperation
virtual ~TrainableOperation ()=default
 
void disableWeightsUpdate () final
 
void enableWeightsUpdate () final
 
virtual bool isWeightsUpdateEnabled () const final
 
void enableBackward () final
 
void disableBackward () final
 
virtual bool isRequiredForBackward () const final
 
- Public Member Functions inherited from onert::ir::train::ITrainableOperation
virtual ~ITrainableOperation ()=default
 

Additional Inherited Members

- Public Types inherited from onert::ir::operation::Loss
enum Input { Y_PRED = 0, Y_TRUE = 1 }
 

Detailed Description

Definition at line 35 of file Loss.h.

Constructor & Destructor Documentation

◆ Loss()

onert::ir::train::operation::Loss::Loss ( const OperationType & operation,
const LossInfo & info,
ir::OpCode  y_pred_op_code 
)

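Definition at line 39 of file Loss.cc. (The excerpt below starts at line 40, after the signature. In Loss.cc the second parameter is named `param` rather than `info`, which is why the initializer reads `_param{param}`.)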

40 : OperationType{operation.getInputs(), operation.getOutputs()}, _param{param},
41 _y_pred_op_code{y_pred_op_code}
42{
43 // DO NOTHING
44}
OperationType
const LossInfo & param() const
Definition Loss.h:51
ir::OpCode y_pred_op_code() const
Definition Loss.h:52

Member Function Documentation

◆ accept() [1/2]

void onert::ir::train::operation::Loss::accept ( OperationVisitor & v) const
override virtual

Reimplemented from onert::ir::operation::Loss.

Definition at line 35 of file Loss.cc.

35{ v.visit(*this); }

◆ accept() [2/2]

void onert::ir::train::operation::Loss::accept ( TrainableOperationVisitor & v) const
override virtual

Implements onert::ir::train::ITrainableOperation.

Definition at line 37 of file Loss.cc.

37{ v.visit(*this); }

◆ clone()

std::unique_ptr< ITrainableOperation > onert::ir::train::operation::Loss::clone ( ) const
override virtual

Implements onert::ir::train::ITrainableOperation.

Definition at line 33 of file Loss.cc.

33{ return std::make_unique<Loss>(*this); }

◆ hasTrainableParameter()

bool onert::ir::train::operation::Loss::hasTrainableParameter ( ) const
inline override virtual

Implements onert::ir::train::ITrainableOperation.

Definition at line 47 of file Loss.h.

47{ return false; }

◆ name()

std::string onert::ir::train::operation::Loss::name ( ) const
inline override virtual

Reimplemented from onert::ir::IOperation.

Definition at line 48 of file Loss.h.

48{ return toString(_param.loss_code) + toString(opcode()); };
OpCode opcode() const final
Definition Loss.h:44
std::string toString(LossCode opcode)
Convert the loss code to its name.
Definition LossCode.cc:28

References onert::ir::train::LossInfo::loss_code, onert::ir::operation::Loss::opcode(), and onert::ir::train::toString().

◆ param()

const LossInfo & onert::ir::train::operation::Loss::param ( ) const
inline

Definition at line 51 of file Loss.h.

51{ return _param; }

Referenced by onert::backend::train::KernelGenerator::visit().

◆ y_pred_op_code()

ir::OpCode onert::ir::train::operation::Loss::y_pred_op_code ( ) const
inline

Definition at line 52 of file Loss.h.

52{ return _y_pred_op_code; }

Referenced by onert::backend::train::KernelGenerator::visit().


The documentation for this class was generated from the following files: