ONE - On-device Neural Engine
package.experimental.train.losses.loss.LossFunction Class Reference

Public Member Functions

None __init__ (self, Literal["mean", "sum"] reduction="mean")

Detailed Description

Base class for loss functions with reduction type.

Definition at line 5 of file loss.py.

Constructor & Destructor Documentation

◆ __init__()

None package.experimental.train.losses.loss.LossFunction.__init__(self, Literal["mean", "sum"] reduction = "mean")
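Initialize the loss function with the given reduction type.
Args:
    reduction (Literal["mean", "sum"]): Reduction type, either 'mean' or 'sum'. Defaults to 'mean'.
Raises:
    ValueError: If reduction is not 'mean' or 'sum'.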

Reimplemented in package.experimental.train.losses.cce.CategoricalCrossentropy, and package.experimental.train.losses.mse.MeanSquaredError.
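The base class only validates and stores the reduction type, and the subclasses listed above reimplement `__init__`. The following is a minimal sketch of that pattern, assuming a subclass simply forwards `reduction` to `super().__init__()`. `LossReduction` is a hypothetical stand-in for the `loss_reduction` enum (its definition is not shown on this page), and `MeanSquaredErrorSketch` is a hypothetical class, not the actual `MeanSquaredError` implementation.

```python
from enum import Enum
from typing import Dict, Literal


class LossReduction(Enum):
    """Hypothetical stand-in for the loss_reduction enum used by loss.py."""
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"
    SUM = "sum"


class LossFunction:
    """Base class: validates the reduction string and stores the enum value."""

    def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
        mapping: Dict[str, LossReduction] = {
            "mean": LossReduction.SUM_OVER_BATCH_SIZE,
            "sum": LossReduction.SUM,
        }
        if reduction not in mapping:
            raise ValueError(
                f"Invalid reduction type. Choose from {list(mapping.keys())}.")
        self.reduction: LossReduction = mapping[reduction]


class MeanSquaredErrorSketch(LossFunction):
    """Hypothetical subclass; the real MeanSquaredError may differ."""

    def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
        # Delegate validation and string-to-enum mapping to the base class.
        super().__init__(reduction)


if __name__ == "__main__":
    loss = MeanSquaredErrorSketch(reduction="sum")
    print(loss.reduction)  # LossReduction.SUM
```

Keeping the string-to-enum table in the base class means an invalid value such as "none" is rejected with the same ValueError no matter which subclass is constructed.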

Definition at line 9 of file loss.py.

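from typing import Dict, Literal

# loss_reduction is an enum defined outside this excerpt (its import in loss.py is not shown).


class LossFunction:
    """Base class for loss functions with reduction type."""

    def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
        """
        Initialize the loss function with the given reduction type.
        Args:
            reduction (Literal["mean", "sum"]): Reduction type ('mean', 'sum').
        Raises:
            ValueError: If reduction is not 'mean' or 'sum'.
        """
        # Map the user-facing string onto the corresponding loss_reduction value
        reduction_mapping: Dict[Literal["mean", "sum"], loss_reduction] = {
            "mean": loss_reduction.SUM_OVER_BATCH_SIZE,
            "sum": loss_reduction.SUM,
        }

        # Validate and assign the reduction type
        if reduction not in reduction_mapping:
            raise ValueError(
                f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.")

        self.reduction: loss_reduction = reduction_mapping[reduction]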
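This class only selects an enum value; the reduction itself is not computed in this file. The sketch below illustrates what the two modes would do to a batch of per-sample losses, assuming `SUM` adds the losses and `SUM_OVER_BATCH_SIZE` divides that sum by the batch size, as the names suggest (the same convention as Keras's `Reduction`). `reduce_losses` is a hypothetical helper, not part of this API.

```python
from typing import Literal, Sequence


def reduce_losses(per_sample: Sequence[float],
                  reduction: Literal["mean", "sum"] = "mean") -> float:
    """Hypothetical helper: combine per-sample losses under the assumed
    semantics of loss_reduction.SUM and loss_reduction.SUM_OVER_BATCH_SIZE."""
    if not per_sample:
        raise ValueError("per_sample must be non-empty")
    total = sum(per_sample)
    if reduction == "sum":
        return total                    # assumed loss_reduction.SUM
    if reduction == "mean":
        return total / len(per_sample)  # assumed loss_reduction.SUM_OVER_BATCH_SIZE
    raise ValueError("Invalid reduction type. Choose from ['mean', 'sum'].")


if __name__ == "__main__":
    batch = [0.5, 1.5, 2.0, 4.0]
    print(reduce_losses(batch, "sum"))   # 8.0
    print(reduce_losses(batch, "mean"))  # 2.0
```

Under this assumption, "mean" keeps the loss scale independent of batch size, while "sum" grows with the number of samples in the batch.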

The documentation for this class was generated from the following file: