diff --git a/test/test_ops.py b/test/test_ops.py
index eb2e31c9bcf..71aa41de2b0 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1765,5 +1765,58 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
         assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
 
 
+class TestDiceLoss:
+    def get_reduction_method(self, reduction):
+        return {"sum": torch.sum, "mean": torch.mean, "none": None}[reduction]
+
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_dice_loss(self, device):
+        input_tensor = torch.tensor(
+            [
+                [[0.9409, 0.9524], [0.9220, 0.1094]],
+                [[0.6802, 0.9570], [0.7949, 0.1499]],
+                [[0.3298, 0.1094], [0.4401, 0.7536]],
+                [[0.3340, 0.9563], [0.9895, 0.5045]],
+            ],
+            device=device,
+        )
+        labels = torch.tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]], [[1, 1], [0, 0]], [[1, 0], [0, 1]]], device=device)
+        expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device)
+        torch.testing.assert_close(ops.dice_loss(input_tensor, labels, eps=0), expected)
+
+    @pytest.mark.parametrize("shape", ((16, 2, 4, 4), (16, 4, 4, 4), (32, 2), (32, 2, 4, 4, 4)))
+    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_dice_loss_one(self, shape, reduction, device):
+        input_ones = torch.ones(shape, device=device)
+        label_zeros = torch.zeros(shape, device=device)
+        expected = torch.ones(shape[0], device=device)
+        reduction_fn = self.get_reduction_method(reduction)
+        if reduction_fn is not None:
+            expected = reduction_fn(expected)
+        torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros, reduction=reduction), expected)
+
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_dice_loss_all_zeros(self, device):
+        shape = (16, 2, 4, 4)
+        input_zeros = torch.zeros(shape, device=device)
+        input_zeros[:, 0, :, :] = 1.0
+        input_zeros[:, 1, :, :] = 0.0
+        label_zeros = torch.zeros(shape, device=device)
+        label_zeros.copy_(input_zeros)
+        input_zeros[:, 0, :, :] = 100.0
+        expected = torch.zeros(16, device=device)
+        torch.testing.assert_close(ops.dice_loss(input_zeros, label_zeros), expected)
+
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_gradcheck(self, device):
+        shape = (16, 2, 4, 4)
+        input_ones = torch.ones(shape, device=device, requires_grad=True)
+        label_zeros = torch.zeros(shape, device=device, requires_grad=True)
+        assert gradcheck(
+            ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py
index 827505b842d..bac28a6a4b8 100644
--- a/torchvision/ops/__init__.py
+++ b/torchvision/ops/__init__.py
@@ -14,6 +14,7 @@
 )
 from .ciou_loss import complete_box_iou_loss
 from .deform_conv import deform_conv2d, DeformConv2d
+from .dice_loss import dice_loss
 from .diou_loss import distance_box_iou_loss
 from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d
 from .feature_pyramid_network import FeaturePyramidNetwork
diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py
new file mode 100644
index 00000000000..fcc3d8156ea
--- /dev/null
+++ b/torchvision/ops/dice_loss.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn.functional as F
+
+from ..utils import _log_api_usage_once
+
+
+# Implementation adapted from https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/dice.html#dice_loss
+def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-7) -> torch.Tensor:
+    r"""Criterion that computes the Sørensen-Dice loss.
+
+    We compute the Sørensen-Dice coefficient as follows:
+
+    .. math::
+
+        \text{Dice}(X, Y) = \frac{2 |X \cap Y|}{|X| + |Y|}
+
+    Where:
+        - :math:`X` is the predicted probability for each class (softmax of ``inputs``).
+        - :math:`Y` is the one-hot tensor with the class labels.
+
+    The loss is then computed as:
+
+    .. math::
+
+        \text{loss}(X, Y) = 1 - \text{Dice}(X, Y)
+
+    Args:
+        inputs (Tensor): A float tensor with rank >= 2 and shape ``(B, num_classes, N1, ..., NK)``,
+            where ``B`` is the batch size and ``num_classes`` is the number of classes.
+            The predictions for each example.
+        targets (Tensor): A one-hot tensor with the same shape as ``inputs``. The first
+            dimension is the batch size and the second dimension is the number of classes.
+        reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``
+            ``'none'``: No reduction will be applied to the output.
+            ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed. Default: ``'none'``.
+        eps (float, optional): Scalar to enforce numerical stability. Default: ``1e-7``.
+
+    Returns:
+        Tensor: Loss tensor with the reduction option applied.
+    """
+
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(dice_loss)
+
+    # compute softmax over the classes axis
+    p = F.softmax(inputs, dim=1)
+    p = p.flatten(start_dim=1)
+
+    targets = targets.flatten(start_dim=1)
+
+    intersection = torch.sum(p * targets, dim=1)
+    cardinality = torch.sum(p + targets, dim=1)
+
+    dice_score = 2.0 * intersection / (cardinality + eps)
+
+    loss = 1.0 - dice_score
+
+    # Check reduction option and return loss accordingly
+    if reduction == "none":
+        pass
+    elif reduction == "mean":
+        # Guard against an empty batch: keep the result a tensor that still tracks gradients.
+        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
+    elif reduction == "sum":
+        loss = loss.sum()
+    else:
+        raise ValueError(
+            f"Invalid value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
+        )
+
+    return loss
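Outside the patch itself, here is a minimal usage sketch of the new op, assuming a torchvision build that includes this change. The shapes, the random tensors, and the F.one_hot conversion are illustrative only; they follow the (B, num_classes, N1, ..., NK) layout described in the docstring:

    import torch
    import torch.nn.functional as F
    from torchvision import ops

    # Logits for a batch of 8 examples, 3 classes, on a 16x16 grid.
    # dice_loss applies softmax over dim=1 internally, so raw scores are passed in.
    logits = torch.randn(8, 3, 16, 16, requires_grad=True)

    # Integer labels converted to the one-hot layout the op expects: (B, num_classes, H, W).
    labels = torch.randint(0, 3, (8, 16, 16))
    targets = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

    per_sample = ops.dice_loss(logits, targets)                   # default reduction="none" -> shape (8,)
    mean_loss = ops.dice_loss(logits, targets, reduction="mean")  # scalar; mean_loss.backward() fills logits.grad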