23 Update the metric's state based on the outputs and expecteds.
25 outputs (list of np.ndarray): List of model outputs for each output layer.
26 expecteds (list of np.ndarray): List of expected ground truth values for each output layer.
28 if len(outputs) != len(expecteds):
30 "The number of outputs and expecteds must match. "
31 f
"Got {len(outputs)} outputs and {len(expecteds)} expecteds.")
33 for output, expected
in zip(outputs, expecteds):
34 if output.shape[self.
axis] != expected.shape[self.
axis]:
36 f
"Output and expected shapes must match along the specified axis {self.axis}. "
37 f
"Got output shape {output.shape} and expected shape {expected.shape}."
40 batch_size = output.shape[self.
axis]
41 for b
in range(batch_size):
42 output_idx = np.argmax(output[b])
43 expected_idx = np.argmax(expected[b])
44 if output_idx == expected_idx:
46 self.
total += batch_size