ONE - On-device Neural Engine
losses.loss.LossFunction Class Reference

Public Member Functions

 __init__ (self, reduction="mean")
 

Data Fields

 reduction
 

Detailed Description

Base class for loss functions with reduction type.

Definition at line 4 of file loss.py.
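The base-class pattern can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions: the real `loss_reduction` values come from outside loss.py's class body and are not shown on this page, so a hypothetical `LossReduction` enum stands in for them, and `MeanAbsoluteError` is a hypothetical subclass (not part of this package) showing how a derived loss can reuse the inherited validation by calling `super().__init__()`.

```python
from enum import Enum, auto


class LossReduction(Enum):
    """Hypothetical stand-in for the real loss_reduction values."""
    SUM_OVER_BATCH_SIZE = auto()
    SUM = auto()


class LossFunction:
    """Base class: validates the reduction string and stores the enum value."""

    def __init__(self, reduction="mean"):
        reduction_mapping = {
            "mean": LossReduction.SUM_OVER_BATCH_SIZE,
            "sum": LossReduction.SUM,
        }
        # Reject anything that is not an exact key of the mapping.
        if reduction not in reduction_mapping:
            raise ValueError(
                f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.")
        self.reduction = reduction_mapping[reduction]


class MeanAbsoluteError(LossFunction):
    """Hypothetical subclass that delegates reduction handling to the base class."""

    def __init__(self, reduction="mean"):
        super().__init__(reduction)  # validation and mapping happen here
        self.name = "mae"  # hypothetical subclass-specific attribute
```

With this sketch, `MeanAbsoluteError("sum").reduction` is `LossReduction.SUM`, and `MeanAbsoluteError("max")` raises `ValueError` from the base constructor, so each subclass gets the same reduction contract without duplicating the check.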

Constructor & Destructor Documentation

◆ __init__()

losses.loss.LossFunction.__init__(self, reduction="mean")
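Initialize the loss function with the given reduction type.
Args:
    reduction (str): Reduction type, either 'mean' (mapped to
        loss_reduction.SUM_OVER_BATCH_SIZE) or 'sum' (mapped to
        loss_reduction.SUM). Defaults to 'mean'.
Raises:
    ValueError: If reduction is not 'mean' or 'sum'.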

Reimplemented in losses.cce.CategoricalCrossentropy and losses.mse.MeanSquaredError.

Definition at line 8 of file loss.py.

def __init__(self, reduction="mean"):
    """
    Initialize the loss function with the given reduction type.

    Args:
        reduction (str): Reduction type ('mean', 'sum').

    Raises:
        ValueError: If reduction is not 'mean' or 'sum'.
    """
    # loss_reduction is defined at module level in loss.py (outside this excerpt).
    reduction_mapping = {
        "mean": loss_reduction.SUM_OVER_BATCH_SIZE,
        "sum": loss_reduction.SUM
    }

    # Validate and assign the reduction type
    if reduction not in reduction_mapping:
        raise ValueError(
            f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.")

    self.reduction = reduction_mapping[reduction]
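The validate-then-look-up step can be isolated to show exactly which inputs are accepted and what the error message looks like. This is a sketch: `resolve_reduction` and `MAPPING` are hypothetical names, and plain strings stand in for the real `loss_reduction` members. Because the check is an exact dictionary-key lookup, it is case-sensitive, so 'Mean' is rejected just like 'max'.

```python
def resolve_reduction(reduction, mapping):
    """Mirror the constructor's validation and lookup (hypothetical helper)."""
    if reduction not in mapping:
        # list(mapping.keys()) renders as ['mean', 'sum'] inside the f-string.
        raise ValueError(
            f"Invalid reduction type. Choose from {list(mapping.keys())}.")
    return mapping[reduction]


# Strings stand in for loss_reduction.SUM_OVER_BATCH_SIZE and loss_reduction.SUM.
MAPPING = {"mean": "SUM_OVER_BATCH_SIZE", "sum": "SUM"}

try:
    resolve_reduction("Mean", MAPPING)
except ValueError as err:
    print(err)  # Invalid reduction type. Choose from ['mean', 'sum'].
```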

Field Documentation

◆ reduction

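losses.loss.LossFunction.reduction

The loss_reduction value (SUM_OVER_BATCH_SIZE for 'mean', SUM for 'sum') selected from the reduction string passed to __init__(). It stores the mapped value, not the original string.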

Definition at line 24 of file loss.py.
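What the two stored values compute is not spelled out on this page. The sketch below assumes the conventional meaning of these names (as used by Keras-style loss reductions): SUM adds the per-sample losses, and SUM_OVER_BATCH_SIZE divides that sum by the number of samples in the batch. `reduce_losses` is a hypothetical helper for illustration, not part of this package; the actual reduction is performed elsewhere in the training runtime.

```python
def reduce_losses(per_sample_losses, reduction="mean"):
    """Reduce per-sample losses to a scalar (hypothetical helper, assumed semantics)."""
    if reduction not in ("mean", "sum"):
        raise ValueError(f"Invalid reduction type: {reduction!r}")
    if not per_sample_losses:
        raise ValueError("per_sample_losses must not be empty")
    total = sum(per_sample_losses)
    if reduction == "sum":
        return total  # corresponds to loss_reduction.SUM
    return total / len(per_sample_losses)  # corresponds to SUM_OVER_BATCH_SIZE


losses = [0.5, 1.5, 1.0]
print(reduce_losses(losses, "mean"))  # 1.0
print(reduce_losses(losses, "sum"))   # 3.0
```

Under this assumption, 'mean' keeps the loss scale independent of batch size, while 'sum' grows with the number of samples.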


The documentation for this class was generated from the following file: