ONE - On-device Neural Engine
Loading...
Searching...
No Matches
registry.py
Go to the documentation of this file.
1from typing import Type, Dict
2from .loss import LossFunction
3from .cce import CategoricalCrossentropy
4from .mse import MeanSquaredError
5from onert.native.libnnfw_api_pybind import loss as loss_type
6
7
"""
Registry that creates loss instances by name and maps loss instances to the
native ``loss`` enum values used by the onert bindings.
"""
# Public loss name -> LossFunction subclass. Only the losses listed here can be
# created by name; there is no registration hook for custom losses in this view.
_losses: Dict[str, Type[LossFunction]] = {
    "categorical_crossentropy": CategoricalCrossentropy,
    "mean_squared_error": MeanSquaredError
}
16
@staticmethod
def create_loss(name: str) -> LossFunction:
    """
    Create a loss instance by name.

    Args:
        name (str): Registered name of the loss, e.g. "categorical_crossentropy"
            or "mean_squared_error".

    Returns:
        LossFunction: A new instance of the loss class registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a registered loss name.
    """
    # Single dictionary lookup (EAFP) instead of a membership test followed by
    # a second lookup. `from None` hides the internal KeyError so callers only
    # see the ValueError they are expected to catch.
    try:
        loss_cls = LossRegistry._losses[name]
    except KeyError:
        raise ValueError(f"Unknown Loss: {name}. Custom loss is not supported yet") from None
    return loss_cls()
29
@staticmethod
def map_loss_function_to_enum(loss_instance: LossFunction) -> loss_type:
    """
    Map a LossFunction instance to the corresponding native ``loss_type`` enum value.

    Args:
        loss_instance (LossFunction): An instance of a supported loss function.
            Instances of subclasses of a supported loss are accepted as well.

    Returns:
        loss_type: The enum value corresponding to the loss function's class.

    Raises:
        TypeError: If ``loss_instance`` is not an instance of
            CategoricalCrossentropy or MeanSquaredError (or a subclass).
    """
    loss_to_enum: Dict[Type[LossFunction], loss_type] = {
        CategoricalCrossentropy: loss_type.CATEGORICAL_CROSSENTROPY,
        MeanSquaredError: loss_type.MEAN_SQUARED_ERROR,
    }
    # isinstance() rather than an exact type(...) lookup so subclasses of a
    # supported loss still map; the first matching entry in insertion order wins.
    for loss_class, enum_value in loss_to_enum.items():
        if isinstance(loss_instance, loss_class):
            return enum_value
    raise TypeError(
        f"Unsupported loss function type: {type(loss_instance).__name__}. "
        f"Supported types are: {list(loss_to_enum.keys())}.")
loss_type map_loss_function_to_enum(LossFunction loss_instance)
Definition registry.py:31