File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
tests/unittests/classification Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -202,6 +202,10 @@ def _reference_sklearn_precision_recall_multiclass(
202202 if preds .ndim == target .ndim + 1 :
203203 preds = torch .argmax (preds , 1 )
204204
205+ valid_labels = list (range (NUM_CLASSES ))
206+ if ignore_index is not None :
207+ valid_labels = [label for label in valid_labels if label != ignore_index ]
208+
205209 if multidim_average == "global" :
206210 preds = preds .numpy ().flatten ()
207211 target = target .numpy ().flatten ()
@@ -210,7 +214,7 @@ def _reference_sklearn_precision_recall_multiclass(
210214 target ,
211215 preds ,
212216 average = average ,
213- labels = list ( range ( NUM_CLASSES )) if average is None else None ,
217+ labels = valid_labels if average in ( "macro" , "weighted" ) else None ,
214218 zero_division = zero_division ,
215219 )
216220
@@ -235,7 +239,7 @@ def _reference_sklearn_precision_recall_multiclass(
235239 true ,
236240 pred ,
237241 average = average ,
238- labels = list ( range ( NUM_CLASSES )) if average is None else None ,
242+ labels = valid_labels if average in ( "macro" , "weighted" ) else None ,
239243 zero_division = zero_division ,
240244 )
241245 res .append (0.0 if np .isnan (r ).any () else r )
You can’t perform that action at this time.
0 commit comments