def map_loss_function_to_enum(loss_instance):
    """
    Translate a loss-function object into its ``loss_type`` enum member.

    The lookup uses ``isinstance``, so an instance of a subclass of a
    supported loss class is accepted as well. Supported classes are tried
    in the order they appear in the dispatch table; the first match wins.

    Args:
        loss_instance: The loss-function object to classify.

    Returns:
        The ``loss_type`` member registered for the matching loss class.

    Raises:
        TypeError: If ``loss_instance`` is not an instance of any supported
            loss class.
    """
    # Dispatch table: supported loss class -> enum member.
    loss_to_enum = {
        CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
        MeanSquaredError: loss_type.MEAN_SQUARED_ERROR,
    }

    # Lazily yield the enum member of every class the instance belongs to;
    # only the first one is ever consumed.
    candidates = (
        enum_member
        for supported_cls, enum_member in loss_to_enum.items()
        if isinstance(loss_instance, supported_cls)
    )

    # A private sentinel distinguishes "no match" from any real enum value.
    no_match = object()
    resolved = next(candidates, no_match)

    if resolved is no_match:
        raise TypeError(
            f"Unsupported loss function type: {type(loss_instance).__name__}. "
            f"Supported types are: {list(loss_to_enum.keys())}.")

    return resolved