ONE - On-device Neural Engine
package.experimental.train.losses.registry.LossRegistry Class Reference

Static Public Member Functions

LossFunction create_loss (str name)
loss_type map_loss_function_to_enum (LossFunction loss_instance)

Static Protected Attributes

dict _losses

Detailed Description

Registry for creating and mapping losses by name or instance.

Definition at line 8 of file registry.py.

Member Function Documentation

◆ create_loss()

LossFunction package.experimental.train.losses.registry.LossRegistry.create_loss(str name)
static
Create a loss instance by name.
Args:
    name (str): Name of the loss.
Returns:
    BaseLoss: Loss instance.
Raises:
    ValueError: If no loss is registered under name (custom losses are not supported yet).

Definition at line 18 of file registry.py.

def create_loss(name: str) -> LossFunction:
    """
    Create a loss instance by name.
    Args:
        name (str): Name of the loss.
    Returns:
        BaseLoss: Loss instance.
    """
    if name not in LossRegistry._losses:
        raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet")
    return LossRegistry._losses[name]()
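The lookup can be exercised outside the ONE package with a self-contained sketch. LossFunction, CategoricalCrossentropy, and MeanSquaredError below are hypothetical stubs standing in for the real classes; only the lookup logic mirrors the listing above.

```python
from typing import Dict, Type


class LossFunction:
    """Hypothetical stand-in for the real LossFunction base class."""


class CategoricalCrossentropy(LossFunction):
    """Hypothetical stub; the real class lives in the ONE package."""


class MeanSquaredError(LossFunction):
    """Hypothetical stub; the real class lives in the ONE package."""


class LossRegistry:
    # Name -> class table, mirroring the documented _losses attribute.
    _losses: Dict[str, Type[LossFunction]] = {
        "categorical_crossentropy": CategoricalCrossentropy,
        "mean_squared_error": MeanSquaredError,
    }

    @staticmethod
    def create_loss(name: str) -> LossFunction:
        # Reject names that were never registered.
        if name not in LossRegistry._losses:
            raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet")
        # Each lookup calls the class, so every call returns a fresh instance.
        return LossRegistry._losses[name]()


mse = LossRegistry.create_loss("mean_squared_error")
print(type(mse).__name__)  # MeanSquaredError

try:
    LossRegistry.create_loss("huber")
except ValueError as err:
    print(err)  # Unknown Loss: huber. Custom loss is not supported yet
```

Names are matched exactly, so a differently cased name such as "Mean_Squared_Error" also raises ValueError. Because _losses stores classes rather than instances, two calls with the same name return two distinct objects.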

◆ map_loss_function_to_enum()

loss_type package.experimental.train.losses.registry.LossRegistry.map_loss_function_to_enum(LossFunction loss_instance)
static
Maps a LossFunction instance to the appropriate enum value.
Args:
    loss_instance (BaseLoss): An instance of a loss function.
Returns:
    loss_type: Corresponding enum value for the loss function.
Raises:
    TypeError: If the loss_instance is not a recognized LossFunction type.

Definition at line 31 of file registry.py.

def map_loss_function_to_enum(loss_instance: LossFunction) -> loss_type:
    """
    Maps a LossFunction instance to the appropriate enum value.
    Args:
        loss_instance (BaseLoss): An instance of a loss function.
    Returns:
        loss_type: Corresponding enum value for the loss function.
    Raises:
        TypeError: If the loss_instance is not a recognized LossFunction type.
    """
    loss_to_enum: Dict[Type[LossFunction], loss_type] = {
        CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
        MeanSquaredError: loss_type.MEAN_SQUARED_ERROR
    }
    for loss_class, enum_value in loss_to_enum.items():
        if isinstance(loss_instance, loss_class):
            return enum_value
    raise TypeError(
        f"Unsupported loss function type: {type(loss_instance).__name__}. "
        f"Supported types are: {list(loss_to_enum.keys())}.")
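Because the mapping uses isinstance(), a subclass of a supported loss resolves to its parent's enum value, while an unrelated loss raises TypeError. The sketch below shows both cases with stand-in definitions; loss_type, the loss classes, ScaledMSE, and Huber are hypothetical placeholders, not ONE's actual definitions.

```python
from enum import Enum, auto
from typing import Dict, Type


class loss_type(Enum):
    """Hypothetical stand-in for the enum used by the ONE trainer."""
    CATEGORICAL_CROSSENTROPY = auto()
    MEAN_SQUARED_ERROR = auto()


class LossFunction:
    """Hypothetical base class."""


class CategoricalCrossentropy(LossFunction):
    pass


class MeanSquaredError(LossFunction):
    pass


class ScaledMSE(MeanSquaredError):
    """Hypothetical user subclass, used only to show isinstance() matching."""


class Huber(LossFunction):
    """Hypothetical loss with no enum entry."""


def map_loss_function_to_enum(loss_instance: LossFunction) -> loss_type:
    loss_to_enum: Dict[Type[LossFunction], loss_type] = {
        CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
        MeanSquaredError: loss_type.MEAN_SQUARED_ERROR,
    }
    # isinstance() also matches subclasses; the loop returns the first match.
    for loss_class, enum_value in loss_to_enum.items():
        if isinstance(loss_instance, loss_class):
            return enum_value
    raise TypeError(
        f"Unsupported loss function type: {type(loss_instance).__name__}. "
        f"Supported types are: {list(loss_to_enum.keys())}.")


print(map_loss_function_to_enum(ScaledMSE()))  # loss_type.MEAN_SQUARED_ERROR
```

If two entries in loss_to_enum were related by inheritance, dictionary insertion order would decide which enum value a shared subclass receives, since the loop stops at the first isinstance() match.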

Field Documentation

◆ _losses

dict package.experimental.train.losses.registry.LossRegistry._losses
static protected
Initial value:
= {
    "categorical_crossentropy": CategoricalCrossentropy,
    "mean_squared_error": MeanSquaredError
}

Definition at line 12 of file registry.py.


The documentation for this class was generated from the following file: