def map_loss_function_to_enum(loss_instance: LossFunction) -> loss_type:
    """
    Map a LossFunction instance to its corresponding ``loss_type`` enum value.

    Matching uses ``isinstance``, so an instance of a subclass of a supported
    loss class maps to that class's enum value. Supported classes are checked
    in the insertion order of ``loss_to_enum``; the first match wins.

    Args:
        loss_instance (LossFunction): An instance of a loss function.

    Returns:
        loss_type: Corresponding enum value for the loss function.

    Raises:
        TypeError: If ``loss_instance`` is not an instance of any supported
            LossFunction class.
    """
    loss_to_enum: Dict[Type[LossFunction], loss_type] = {
        CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
        MeanSquaredError: loss_type.MEAN_SQUARED_ERROR,
    }
    for loss_class, enum_value in loss_to_enum.items():
        if isinstance(loss_instance, loss_class):
            return enum_value
    # List class names rather than class reprs (e.g. "<class 'pkg.mod.X'>")
    # so the error message stays readable.
    supported = ", ".join(cls.__name__ for cls in loss_to_enum)
    raise TypeError(
        f"Unsupported loss function type: {type(loss_instance).__name__}. "
        f"Supported types are: {supported}."
    )