Skip to content

Commit b79bbb6

Browse files
committed
fix: Ensure consistent default values('micro') for average argument in classification metrics
1 parent 8b47970 commit b79bbb6

File tree

8 files changed

+14
-14
lines changed

8 files changed

+14
-14
lines changed

docs/source/pages/overview.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ of metrics e.g. computation of confidence intervals by resampling of input data.
453453
.. testoutput::
454454
:options: +NORMALIZE_WHITESPACE
455455

456-
{'mean': tensor(0.1333), 'std': tensor(0.1554)}
456+
{'mean': tensor(0.1069), 'std': tensor(0.1180)}
457457

458458
You can see all implemented wrappers under the wrapper section of the API docs.
459459

src/torchmetrics/classification/accuracy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ class MulticlassAccuracy(MulticlassStatScores):
214214
>>> preds = tensor([2, 1, 0, 1])
215215
>>> metric = MulticlassAccuracy(num_classes=3)
216216
>>> metric(preds, target)
217-
tensor(0.8333)
217+
tensor(0.7500)
218218
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
219219
>>> mca(preds, target)
220220
tensor([0.5000, 1.0000, 1.0000])
@@ -228,7 +228,7 @@ class MulticlassAccuracy(MulticlassStatScores):
228228
... [0.05, 0.82, 0.13]])
229229
>>> metric = MulticlassAccuracy(num_classes=3)
230230
>>> metric(preds, target)
231-
tensor(0.8333)
231+
tensor(0.7500)
232232
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
233233
>>> mca(preds, target)
234234
tensor([0.5000, 1.0000, 1.0000])

src/torchmetrics/classification/hamming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
224224
>>> preds = tensor([2, 1, 0, 1])
225225
>>> metric = MulticlassHammingDistance(num_classes=3)
226226
>>> metric(preds, target)
227-
tensor(0.1667)
227+
tensor(0.2500)
228228
>>> mchd = MulticlassHammingDistance(num_classes=3, average=None)
229229
>>> mchd(preds, target)
230230
tensor([0.5000, 0.0000, 0.0000])

src/torchmetrics/classification/negative_predictive_value.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ class MulticlassNegativePredictiveValue(MulticlassStatScores):
220220
>>> preds = tensor([2, 1, 0, 1])
221221
>>> metric = MulticlassNegativePredictiveValue(num_classes=3)
222222
>>> metric(preds, target)
223-
tensor(0.8889)
223+
tensor(0.8750)
224224
>>> metric = MulticlassNegativePredictiveValue(num_classes=3, average=None)
225225
>>> metric(preds, target)
226226
tensor([0.6667, 1.0000, 1.0000])
@@ -371,7 +371,7 @@ class MultilabelNegativePredictiveValue(MultilabelStatScores):
371371
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
372372
>>> metric = MultilabelNegativePredictiveValue(num_labels=3)
373373
>>> metric(preds, target)
374-
tensor(0.5000)
374+
tensor(0.6667)
375375
>>> mls = MultilabelNegativePredictiveValue(num_labels=3, average=None)
376376
>>> mls(preds, target)
377377
tensor([1.0000, 0.5000, 0.0000])

src/torchmetrics/classification/precision_recall.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ class MulticlassPrecision(MulticlassStatScores):
237237
>>> preds = tensor([2, 1, 0, 1])
238238
>>> metric = MulticlassPrecision(num_classes=3)
239239
>>> metric(preds, target)
240-
tensor(0.8333)
240+
tensor(0.7500)
241241
>>> mcp = MulticlassPrecision(num_classes=3, average=None)
242242
>>> mcp(preds, target)
243243
tensor([1.0000, 0.5000, 1.0000])
@@ -402,7 +402,7 @@ class MultilabelPrecision(MultilabelStatScores):
402402
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
403403
>>> metric = MultilabelPrecision(num_labels=3)
404404
>>> metric(preds, target)
405-
tensor(0.5000)
405+
tensor(0.6667)
406406
>>> mlp = MultilabelPrecision(num_labels=3, average=None)
407407
>>> mlp(preds, target)
408408
tensor([1.0000, 0.0000, 0.5000])
@@ -696,7 +696,7 @@ class MulticlassRecall(MulticlassStatScores):
696696
>>> preds = tensor([2, 1, 0, 1])
697697
>>> metric = MulticlassRecall(num_classes=3)
698698
>>> metric(preds, target)
699-
tensor(0.8333)
699+
tensor(0.7500)
700700
>>> mcr = MulticlassRecall(num_classes=3, average=None)
701701
>>> mcr(preds, target)
702702
tensor([0.5000, 1.0000, 1.0000])

src/torchmetrics/classification/specificity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ class MulticlassSpecificity(MulticlassStatScores):
214214
>>> preds = tensor([2, 1, 0, 1])
215215
>>> metric = MulticlassSpecificity(num_classes=3)
216216
>>> metric(preds, target)
217-
tensor(0.8889)
217+
tensor(0.8750)
218218
>>> mcs = MulticlassSpecificity(num_classes=3, average=None)
219219
>>> mcs(preds, target)
220220
tensor([1.0000, 0.6667, 1.0000])

src/torchmetrics/classification/stat_scores.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def __init__(
309309
self,
310310
num_classes: int,
311311
top_k: int = 1,
312-
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
312+
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
313313
multidim_average: Literal["global", "samplewise"] = "global",
314314
ignore_index: Optional[int] = None,
315315
validate_args: bool = True,
@@ -461,7 +461,7 @@ def __init__(
461461
self,
462462
num_labels: int,
463463
threshold: float = 0.5,
464-
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
464+
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
465465
multidim_average: Literal["global", "samplewise"] = "global",
466466
ignore_index: Optional[int] = None,
467467
validate_args: bool = True,

src/torchmetrics/functional/classification/accuracy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def multiclass_accuracy(
167167
preds: Tensor,
168168
target: Tensor,
169169
num_classes: int,
170-
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
170+
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
171171
top_k: int = 1,
172172
multidim_average: Literal["global", "samplewise"] = "global",
173173
ignore_index: Optional[int] = None,
@@ -276,7 +276,7 @@ def multilabel_accuracy(
276276
target: Tensor,
277277
num_labels: int,
278278
threshold: float = 0.5,
279-
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
279+
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
280280
multidim_average: Literal["global", "samplewise"] = "global",
281281
ignore_index: Optional[int] = None,
282282
validate_args: bool = True,

0 commit comments

Comments
 (0)