ONE - On-device Neural Engine
losses.registry.LossRegistry Class Reference

Static Public Member Functions

create_loss(name)

map_loss_function_to_enum(loss_instance)


Static Protected Attributes

dict _losses
 

Detailed Description

Registry for creating and mapping losses by name or instance.

Definition at line 6 of file registry.py.
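The registry supports a round trip: a name string becomes a loss instance through create_loss, and that instance becomes an enum value through map_loss_function_to_enum. The self-contained Python below is a minimal sketch of that round trip. It assumes hypothetical stand-ins: BaseLoss, the two loss classes, and the LossType enum (with made-up values 0 and 1) replace the real classes and the loss_type enum that ONE imports from elsewhere, because their definitions do not appear on this page.

```python
from enum import Enum


# Hypothetical stand-ins for the real ONE loss classes and the loss_type enum.
class BaseLoss:
    pass


class CategoricalCrossentropy(BaseLoss):
    pass


class MeanSquaredError(BaseLoss):
    pass


class LossType(Enum):
    # Assumed values; the real enum values are not documented here.
    CATEGORICAL_CROSSENTROPY = 0
    MEAN_SQUARED_ERROR = 1


class LossRegistry:
    # Forward table: registered name -> loss class (not an instance).
    _losses = {
        "categorical_crossentropy": CategoricalCrossentropy,
        "mean_squared_error": MeanSquaredError,
    }

    @staticmethod
    def create_loss(name):
        # Name -> new instance of the registered class.
        if name not in LossRegistry._losses:
            raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet")
        return LossRegistry._losses[name]()

    @staticmethod
    def map_loss_function_to_enum(loss_instance):
        # Instance -> enum value, matched with isinstance().
        loss_to_enum = {
            CategoricalCrossentropy: LossType.CATEGORICAL_CROSSENTROPY,
            MeanSquaredError: LossType.MEAN_SQUARED_ERROR,
        }
        for loss_class, enum_value in loss_to_enum.items():
            if isinstance(loss_instance, loss_class):
                return enum_value
        raise TypeError(f"Unsupported loss function type: {type(loss_instance).__name__}")


# Round trip: name -> instance -> enum value.
loss = LossRegistry.create_loss("mean_squared_error")
print(LossRegistry.map_loss_function_to_enum(loss))  # LossType.MEAN_SQUARED_ERROR
```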

Member Function Documentation

◆ create_loss()

losses.registry.LossRegistry.create_loss(name)
static
Create a loss instance by name.
Args:
    name (str): Name of the loss.
Returns:
    BaseLoss: Loss instance.
Raises:
    ValueError: If name is not a registered loss (custom losses are not supported yet).

Definition at line 16 of file registry.py.

```python
def create_loss(name):
    """
    Create a loss instance by name.
    Args:
        name (str): Name of the loss.
    Returns:
        BaseLoss: Loss instance.
    """
    if name not in LossRegistry._losses:
        raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet")
    return LossRegistry._losses[name]()
```
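create_loss tests name for exact membership in the keys of _losses, so the lookup is case-sensitive and accepts no aliases or class names. The sketch below reproduces that lookup against hypothetical stand-in classes (not the real ONE losses) to show which spellings succeed.

```python
# Hypothetical stand-ins for the real loss classes.
class CategoricalCrossentropy:
    pass


class MeanSquaredError:
    pass


# Same name -> class table as LossRegistry._losses.
_losses = {
    "categorical_crossentropy": CategoricalCrossentropy,
    "mean_squared_error": MeanSquaredError,
}


def create_loss(name):
    # Mirrors LossRegistry.create_loss: exact dictionary key lookup, then instantiate.
    if name not in _losses:
        raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet")
    return _losses[name]()


for candidate in ["mean_squared_error", "MeanSquaredError", "mse"]:
    try:
        print(candidate, "->", type(create_loss(candidate)).__name__)
    except ValueError as err:
        print(candidate, "-> ValueError:", err)
```

Only the snake_case key "mean_squared_error" resolves; the class name and the abbreviation both raise ValueError.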

◆ map_loss_function_to_enum()

losses.registry.LossRegistry.map_loss_function_to_enum(loss_instance)
static
Maps a LossFunction instance to the appropriate enum value.
Args:
    loss_instance (BaseLoss): An instance of a loss function.
Returns:
    loss_type: Corresponding enum value for the loss function.
Raises:
    TypeError: If the loss_instance is not a recognized LossFunction type.

Definition at line 29 of file registry.py.

```python
def map_loss_function_to_enum(loss_instance):
    """
    Maps a LossFunction instance to the appropriate enum value.
    Args:
        loss_instance (BaseLoss): An instance of a loss function.
    Returns:
        loss_type: Corresponding enum value for the loss function.
    Raises:
        TypeError: If the loss_instance is not a recognized LossFunction type.
    """
    # Loss to Enum mapping
    loss_to_enum = {
        CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
        MeanSquaredError: loss_type.MEAN_SQUARED_ERROR
    }
    for loss_class, enum_value in loss_to_enum.items():
        if isinstance(loss_instance, loss_class):
            return enum_value
    raise TypeError(
        f"Unsupported loss function type: {type(loss_instance).__name__}. "
        f"Supported types are: {list(loss_to_enum.keys())}.")
```
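Because the match uses isinstance() rather than an exact type comparison, a subclass of a supported loss resolves to its parent's enum value, while an unrelated loss raises TypeError. The sketch below demonstrates both cases. BaseLoss, the loss classes, LossType (with assumed values), ScaledMSE, and HingeLoss are all hypothetical stand-ins, not the real ONE definitions.

```python
from enum import Enum


# Hypothetical stand-ins for the real ONE classes and the loss_type enum.
class BaseLoss:
    pass


class CategoricalCrossentropy(BaseLoss):
    pass


class MeanSquaredError(BaseLoss):
    pass


class LossType(Enum):
    # Assumed values; the real enum values are not documented here.
    CATEGORICAL_CROSSENTROPY = 0
    MEAN_SQUARED_ERROR = 1


def map_loss_function_to_enum(loss_instance):
    # Same isinstance()-based lookup as the documented method.
    loss_to_enum = {
        CategoricalCrossentropy: LossType.CATEGORICAL_CROSSENTROPY,
        MeanSquaredError: LossType.MEAN_SQUARED_ERROR,
    }
    for loss_class, enum_value in loss_to_enum.items():
        if isinstance(loss_instance, loss_class):
            return enum_value
    raise TypeError(
        f"Unsupported loss function type: {type(loss_instance).__name__}. "
        f"Supported types are: {list(loss_to_enum.keys())}.")


class ScaledMSE(MeanSquaredError):
    """Hypothetical user subclass: isinstance() still matches its parent."""


class HingeLoss(BaseLoss):
    """Hypothetical loss with no entry in the mapping."""


print(map_loss_function_to_enum(ScaledMSE()))  # LossType.MEAN_SQUARED_ERROR
try:
    map_loss_function_to_enum(HingeLoss())
except TypeError as err:
    print(err)
```

Since Python dicts preserve insertion order, a class deriving from both supported losses would map to whichever appears first in loss_to_enum. Note also that the error message lists class objects (their repr) rather than plain names.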

Field Documentation

◆ _losses

dict losses.registry.LossRegistry._losses
static protected
Initial value:
```python
_losses = {
    "categorical_crossentropy": CategoricalCrossentropy,
    "mean_squared_error": MeanSquaredError
}
```

Definition at line 10 of file registry.py.

