ONE - On-device Neural Engine
metrics.categorical_accuracy.CategoricalAccuracy Class Reference

Public Member Functions

 __init__ (self)
 
 reset_state (self)
 
 update_state (self, outputs, expecteds)
 
 result (self)
 

Data Fields

 correct
 
 total
 
 axis
 

Detailed Description

Metric for computing categorical accuracy.

Definition at line 5 of file categorical_accuracy.py.
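
A minimal usage sketch of the metric's lifecycle follows. The class body is condensed from the listings on this page so that the example is self-contained; in the actual package the class would presumably be imported from `metrics.categorical_accuracy` (an assumption based on the qualified name). The arrays are made-up illustration data, and the error messages are shortened.

```python
import numpy as np


class CategoricalAccuracy:
    """Stand-in condensed from the listings on this page (not the packaged class)."""

    def __init__(self):
        self.correct = 0
        self.total = 0
        self.axis = 0

    def reset_state(self):
        self.correct = 0
        self.total = 0

    def update_state(self, outputs, expecteds):
        if len(outputs) != len(expecteds):
            raise ValueError("The number of outputs and expecteds must match.")
        for output, expected in zip(outputs, expecteds):
            if output.shape[self.axis] != expected.shape[self.axis]:
                raise ValueError("Shapes must match along the batch axis.")
            batch_size = output.shape[self.axis]
            for b in range(batch_size):
                if np.argmax(output[b]) == np.argmax(expected[b]):
                    self.correct += 1
            self.total += batch_size

    def result(self):
        if self.total == 0:
            return 0.0
        return self.correct / self.total


# Illustration data: one output layer, 3 samples, 3 classes.
outputs = [np.array([[0.1, 0.7, 0.2],    # predicted class 1
                     [0.8, 0.1, 0.1],    # predicted class 0
                     [0.3, 0.3, 0.4]])]  # predicted class 2
expecteds = [np.array([[0, 1, 0],        # true class 1 (match)
                       [0, 0, 1],        # true class 2 (mismatch)
                       [0, 0, 1]])]      # true class 2 (match)

metric = CategoricalAccuracy()
empty_result = metric.result()           # nothing accumulated yet
metric.update_state(outputs, expecteds)
accuracy = metric.result()               # 2 correct out of 3 samples
metric.reset_state()
after_reset = metric.result()            # counters cleared again
```

Note that `update_state` takes lists of arrays, one entry per output layer, even when the model has a single output.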

Constructor & Destructor Documentation

◆ __init__()

metrics.categorical_accuracy.CategoricalAccuracy.__init__ (self)

Definition at line 9 of file categorical_accuracy.py.

def __init__(self):
    self.correct = 0
    self.total = 0
    self.axis = 0

Member Function Documentation

◆ reset_state()

metrics.categorical_accuracy.CategoricalAccuracy.reset_state (self)
Reset the metric's state.

Reimplemented from metrics.metric.Metric.

Definition at line 14 of file categorical_accuracy.py.

def reset_state(self):
    """
    Reset the metric's state.
    """
    self.correct = 0
    self.total = 0

References metrics.categorical_accuracy.CategoricalAccuracy.correct, and metrics.categorical_accuracy.CategoricalAccuracy.total.

◆ result()

metrics.categorical_accuracy.CategoricalAccuracy.result (self)
Compute and return the final metric value.
Returns:
    float: Metric value.

Reimplemented from metrics.metric.Metric.

Definition at line 48 of file categorical_accuracy.py.

def result(self):
    """
    Compute and return the final metric value.
    Returns:
        float: Metric value.
    """
    if self.total == 0:
        return 0.0
    return self.correct / self.total

References metrics.categorical_accuracy.CategoricalAccuracy.correct, and metrics.categorical_accuracy.CategoricalAccuracy.total.
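
Because `correct` and `total` accumulate across calls to `update_state`, `result` returns the accuracy over all samples seen since the last reset. This differs from the mean of per-batch accuracies when batch sizes differ. The sketch below illustrates the difference with plain counters; the batch counts are made-up illustration values.

```python
# (correct, batch_size) pairs for two hypothetical batches.
batches = [(2, 2), (1, 4)]

correct = 0
total = 0
for batch_correct, batch_size in batches:
    correct += batch_correct
    total += batch_size

# The guarded ratio that result() computes: 3 / 6.
streaming_accuracy = correct / total if total else 0.0

# Averaging per-batch accuracies instead: (1.0 + 0.25) / 2.
mean_of_batch_accuracies = sum(c / n for c, n in batches) / len(batches)
```

Here the accumulated ratio weights every sample equally, whereas the per-batch mean gives the small batch as much influence as the large one.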

◆ update_state()

metrics.categorical_accuracy.CategoricalAccuracy.update_state (self, outputs, expecteds)
Update the metric's state based on the outputs and expecteds.
Args:
    outputs (list of np.ndarray): List of model outputs for each output layer.
    expecteds (list of np.ndarray): List of expected ground truth values for each output layer.

Reimplemented from metrics.metric.Metric.

Definition at line 21 of file categorical_accuracy.py.

def update_state(self, outputs, expecteds):
    """
    Update the metric's state based on the outputs and expecteds.
    Args:
        outputs (list of np.ndarray): List of model outputs for each output layer.
        expecteds (list of np.ndarray): List of expected ground truth values for each output layer.
    """
    if len(outputs) != len(expecteds):
        raise ValueError(
            "The number of outputs and expecteds must match. "
            f"Got {len(outputs)} outputs and {len(expecteds)} expecteds.")

    for output, expected in zip(outputs, expecteds):
        if output.shape[self.axis] != expected.shape[self.axis]:
            raise ValueError(
                f"Output and expected shapes must match along the specified axis {self.axis}. "
                f"Got output shape {output.shape} and expected shape {expected.shape}."
            )

        batch_size = output.shape[self.axis]
        for b in range(batch_size):
            output_idx = np.argmax(output[b])
            expected_idx = np.argmax(expected[b])
            if output_idx == expected_idx:
                self.correct += 1
        self.total += batch_size

References metrics.categorical_accuracy.CategoricalAccuracy.axis, metrics.categorical_accuracy.CategoricalAccuracy.correct, and metrics.categorical_accuracy.CategoricalAccuracy.total.
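
The per-sample comparison in `update_state` can also be expressed with vectorized NumPy operations. The sketch below is an alternative formulation, not the code in categorical_accuracy.py, and the helper names `count_correct` and `count_correct_loop` are hypothetical. It assumes the batch dimension is axis 0 (matching the default `self.axis = 0`) and flattens each sample before taking the argmax, since `np.argmax` without an `axis` argument operates on the flattened array.

```python
import numpy as np


def count_correct(output, expected):
    """Hypothetical helper: number of samples whose argmax indices agree."""
    batch_size = output.shape[0]
    # Flatten each sample so the result matches np.argmax(output[b]) per sample.
    predicted = np.argmax(output.reshape(batch_size, -1), axis=1)
    target = np.argmax(expected.reshape(batch_size, -1), axis=1)
    return int(np.sum(predicted == target))


def count_correct_loop(output, expected):
    """Reference: the per-sample loop, one argmax comparison per batch entry."""
    return sum(
        1 for b in range(output.shape[0])
        if np.argmax(output[b]) == np.argmax(expected[b])
    )


# Made-up data: 8 samples, 5 classes, random scores and random one-hot targets.
rng = np.random.default_rng(0)
output = rng.random((8, 5))
expected = np.eye(5)[rng.integers(0, 5, size=8)]

vectorized = count_correct(output, expected)
looped = count_correct_loop(output, expected)
```

Both helpers return the same count; the vectorized form avoids the Python-level loop over the batch, which may matter for large batches.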

Field Documentation

◆ axis

metrics.categorical_accuracy.CategoricalAccuracy.axis

◆ correct

◆ total


The documentation for this class was generated from the following file: