ONE - On-device Neural Engine
Loading...
Searching...
No Matches
loss.py
Go to the documentation of this file.
1from typing import Literal, Dict
2from onert.native.libnnfw_api_pybind import loss_reduction
3
4
6 """
7 Base class for loss functions with reduction type.
8 """
def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
    """
    Initialize the loss function with the given reduction type.

    The user-facing reduction string is translated into the native
    ``loss_reduction`` enum and stored in ``self.reduction``:
    'mean' maps to ``loss_reduction.SUM_OVER_BATCH_SIZE`` and
    'sum' maps to ``loss_reduction.SUM``.

    Args:
        reduction (str): Reduction type, either 'mean' or 'sum'.
            Defaults to 'mean'.

    Raises:
        ValueError: If ``reduction`` is not 'mean' or 'sum'.
    """
    # Map the Python-side string onto the native binding's enum value.
    reduction_mapping: Dict[Literal["mean", "sum"], loss_reduction] = {
        "mean": loss_reduction.SUM_OVER_BATCH_SIZE,
        "sum": loss_reduction.SUM,
    }

    # Literal[...] is only a static hint; callers can still pass any value,
    # so validate at runtime and report the offending value.
    if reduction not in reduction_mapping:
        raise ValueError(f"Invalid reduction type {reduction!r}. "
                         f"Choose from {list(reduction_mapping)}.")

    self.reduction: loss_reduction = reduction_mapping[reduction]
None __init__(self, Literal["mean", "sum"] reduction="mean")
Definition loss.py:9