26 expecteds: List[np.ndarray]) ->
None:
28 Update the metric's state based on the outputs and expecteds.
31 outputs (list of np.ndarray): List of model outputs for each output layer.
32 expecteds (list of np.ndarray): List of expected ground truth values for each output layer.
34 if len(outputs) != len(expecteds):
36 "The number of outputs and expecteds must match. "
37 f
"Got {len(outputs)} outputs and {len(expecteds)} expecteds.")
39 for output, expected
in zip(outputs, expecteds):
40 if output.shape[self.axis] != expected.shape[self.axis]:
42 f
"Output and expected shapes must match along the specified axis {self.axis}. "
43 f
"Got output shape {output.shape} and expected shape {expected.shape}."
46 batch_size = output.shape[self.axis]
47 for b
in range(batch_size):
48 output_idx = np.argmax(output[b])
49 expected_idx = np.argmax(expected[b])
50 if output_idx == expected_idx:
52 self.
total += batch_size