def __init__(self, reduction: Literal["mean", "sum"] = "mean") -> None:
    """
    Initialize the Categorical Cross-Entropy loss function.

    Args:
        reduction (str): Reduction type ('mean', 'sum').

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Translate the user-facing reduction name into the framework's
    # reduction value.
    # NOTE(review): ``loss_reduction`` is defined outside this view; it is
    # presumably an alias for ``tf.keras.losses.Reduction`` — confirm.
    reduction_mapping: Dict[Literal["mean", "sum"], loss_reduction] = {
        "mean": loss_reduction.SUM_OVER_BATCH_SIZE,
        "sum": loss_reduction.SUM,
    }
    try:
        self.reduction: loss_reduction = reduction_mapping[reduction]
    except KeyError:
        # Re-raise as ValueError (the documented contract); suppress the
        # KeyError context since it adds nothing for the caller.
        raise ValueError(
            f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}."
        ) from None