Skip to content

Commit 259c4bd

Browse files
committed
fix:Reference Metric in multiclass pecision recall unittests provides wrong answer when ignore_index is specified
1 parent 2365437 commit 259c4bd

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/unittests/classification/test_precision_recall.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ def _reference_sklearn_precision_recall_multiclass(
202202
if preds.ndim == target.ndim + 1:
203203
preds = torch.argmax(preds, 1)
204204

205+
valid_labels = list(range(NUM_CLASSES))
206+
if ignore_index is not None:
207+
valid_labels = [label for label in valid_labels if label != ignore_index]
208+
205209
if multidim_average == "global":
206210
preds = preds.numpy().flatten()
207211
target = target.numpy().flatten()
@@ -210,7 +214,7 @@ def _reference_sklearn_precision_recall_multiclass(
210214
target,
211215
preds,
212216
average=average,
213-
labels=list(range(NUM_CLASSES)) if average is None else None,
217+
labels=valid_labels if average in ("macro", "weighted") else None,
214218
zero_division=zero_division,
215219
)
216220

@@ -235,7 +239,7 @@ def _reference_sklearn_precision_recall_multiclass(
235239
true,
236240
pred,
237241
average=average,
238-
labels=list(range(NUM_CLASSES)) if average is None else None,
242+
labels=valid_labels if average in ("macro", "weighted") else None,
239243
zero_division=zero_division,
240244
)
241245
res.append(0.0 if np.isnan(r).any() else r)

0 commit comments

Comments
 (0)