diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 905dcc2c251..b09a9e41a6f 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional +import torch from torch import Tensor from typing_extensions import Literal @@ -75,16 +76,26 @@ def _accuracy_reduce( """ if average == "binary": return _safe_divide(tp + tn, tp + tn + fp + fn) + + # Calculate base score + score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn) + + # For top_k > 1, always use the adjust_weights function which properly handles top_k + if top_k > 1: + return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k) + + # For top_k=1, continue with the original logic if average == "micro": - tp = tp.sum(dim=0 if multidim_average == "global" else 1) - fn = fn.sum(dim=0 if multidim_average == "global" else 1) + # Apply sum before returning for micro averaging + tp_sum = tp.sum(dim=0 if multidim_average == "global" else 1) + fn_sum = fn.sum(dim=0 if multidim_average == "global" else 1) if multilabel: - fp = fp.sum(dim=0 if multidim_average == "global" else 1) - tn = tn.sum(dim=0 if multidim_average == "global" else 1) - return _safe_divide(tp + tn, tp + tn + fp + fn) - return _safe_divide(tp, tp + fn) + fp_sum = fp.sum(dim=0 if multidim_average == "global" else 1) + tn_sum = tn.sum(dim=0 if multidim_average == "global" else 1) + return _safe_divide(tp_sum + tn_sum, tp_sum + tn_sum + fp_sum + fn_sum) + return _safe_divide(tp_sum, tp_sum + fn_sum) - score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn) + # For other averaging methods, apply the adjustment return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k) @@ -264,6 +275,44 @@ def multiclass_accuracy( if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + + if top_k > 1 and average == "micro" and preds.ndim == target.ndim + 1: + if preds.ndim == target.ndim: + num_classes = num_classes or (target.max().int().item() + 1) + preds = torch.nn.functional.one_hot(preds, num_classes).to(preds.dtype) + preds = preds.transpose(1, -1) + + if multidim_average == "global": + flat_shape = preds.shape[:2] + (-1,) + flat_preds = preds.reshape(flat_shape) + flat_target = target.reshape(target.shape[0], -1) + else: + flat_shape = preds.shape[:2] + (-1,) + flat_preds = preds.reshape(flat_shape) + flat_target = target.reshape(target.shape[0], -1) + + batch_size = flat_target.shape[0] + num_samples = flat_target.shape[1] + + if ignore_index is not None: + valid_mask = flat_target != ignore_index + else: + valid_mask = torch.ones_like(flat_target, dtype=torch.bool) + + correct_list = [] + for i in range(batch_size): + for j in range(num_samples): + if not valid_mask[i, j]: + continue + sample_preds = flat_preds[i, :, j] + sample_target = flat_target[i, j] + _, top_indices = torch.topk(sample_preds, min(top_k, sample_preds.size(0)), dim=0) + correct_list.append(torch.any(top_indices == sample_target).int()) + + if correct_list: + return torch.stack(correct_list).float().mean() + return torch.tensor(0.0, device=preds.device) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, num_classes or 1, top_k, average, multidim_average, ignore_index diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index cea314fb0eb..6bbe70a1a57 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -381,15 +381,24 @@ def _multiclass_stat_scores_update( ) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute the statistics. - - If ``multidim_average`` is equal to samplewise or ``top_k`` is not 1, we transform both preds and - target into one hot format. + - If ``multidim_average`` is equal to samplewise or ``top_k`` is greater than 1, we transform both preds and + target into one hot format to properly handle top-k predictions. - Else we calculate statistics by first calculating the confusion matrix and afterwards deriving the statistics from that - Remove all datapoints that should be ignored. Depending on if ``ignore_index`` is in the set of labels or outside we have do use different augmentation strategies when one hot encoding. + Notes: + - For top_k > 1, we always use the one-hot encoding path regardless of the averaging method + to ensure top-k logic is properly applied in all cases, including micro averaging. + """ - if multidim_average == "samplewise" or top_k != 1: + # Modified condition to always use one-hot path when top_k > 1, regardless of average method + if multidim_average == "samplewise" or top_k > 1 or (preds.ndim == target.ndim + 1 and average == "micro"): + # Always use one-hot encoding for: + # 1. samplewise averaging + # 2. top_k > 1 + # 3. when inputs have different dimensions (probably logits vs. class indices) and micro averaging ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None if ignore_index is not None and not ignore_in: preds = preds.clone() @@ -400,9 +409,11 @@ def _multiclass_stat_scores_update( preds[idx] = num_classes if top_k > 1: + # For top_k > 1, we need to get the top-k predictions in one-hot format preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) preds_oh = _refine_preds_oh(preds, preds_oh, target, top_k) else: + # Otherwise just one-hot encode the class indices preds_oh = torch.nn.functional.one_hot( preds.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes )