From 6532e35c535a72d4a1263d6e8ac91de6f49b9bb6 Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Fri, 18 Nov 2022 19:39:04 +0530 Subject: [PATCH 1/9] add `dice_loss` --- torchvision/ops/__init__.py | 1 + torchvision/ops/dice_loss.py | 74 ++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 torchvision/ops/dice_loss.py diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 827505b842d..d4d62048452 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -26,6 +26,7 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth +from .dice_loss import dice_loss _register_custom_op() diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py new file mode 100644 index 00000000000..f0b639f8127 --- /dev/null +++ b/torchvision/ops/dice_loss.py @@ -0,0 +1,74 @@ +import torch +import torch.nn.functional as F + +def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-8) -> torch.Tensor: + """Criterion that computes Sørensen-Dice Coefficient loss. + + We compute the Sørensen-Dice Coefficient as follows: + + .. math:: + + \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|} + + Where: + - :math:`X` expects to be the scores of each class. + - :math:`Y` expects to be thess tensor with the class labels. + + the loss, is finally computed as: + + .. math:: + + \text{loss}(x, class) = 1 - \text{Dice}(x, class) + + Args: + inputs: (Tensor): A float tensor of arbitrary shape. + The predictions for each example. + targets: (Tensor): A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + eps: (float, optional): Scalar to enforce numerical stabiliy. + reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'`` + ``'none'``: No reduction will be applied to the output. + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'``. + + Return: + Tensor: Loss tensor with the reduction option applied. + """ + if not isinstance(inputs, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(inputs)}") + + if not len(inputs.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {inputs.shape}") + + if not inputs.shape[-2:] == targets.shape[-2:]: + raise ValueError(f"input and target shapes must be the same. Got: {inputs.shape} and {targets.shape}") + + if not inputs.device == targets.device: + raise ValueError(f"input and target must be in the same device. Got: {inputs.device} and {targets.device}") + + # compute softmax over the classes axis + p = F.softmax(inputs, dim=1) + + # compute the actual dice score + dims = (1, 2, 3) + intersection = torch.sum(p * targets, dims) + cardinality = torch.sum(p + targets, dims) + + dice_score = 2.0 * intersection / (cardinality + eps) + + loss = 1.0 - dice_score + + # Check reduction option and return loss accordingly + if reduction == "none": + pass + elif reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + else: + raise ValueError( + f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" + ) + + return loss \ No newline at end of file From 248f50a614040333219fd200db97c05bd2945028 Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Fri, 18 Nov 2022 19:47:04 +0530 Subject: [PATCH 2/9] change formatting --- torchvision/ops/dice_loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py index f0b639f8127..a8d18677371 100644 --- a/torchvision/ops/dice_loss.py +++ b/torchvision/ops/dice_loss.py @@ -1,8 +1,9 @@ import torch import torch.nn.functional as F + def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-8) -> torch.Tensor: - """Criterion that computes Sørensen-Dice Coefficient loss. + r"""Criterion that computes Sørensen-Dice Coefficient loss. We compute the Sørensen-Dice Coefficient as follows: @@ -71,4 +72,4 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" ) - return loss \ No newline at end of file + return loss From 7dca5438e6f1b92e5f8e58455a1d001697dc4e60 Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Fri, 18 Nov 2022 20:06:17 +0530 Subject: [PATCH 3/9] add fixes --- torchvision/ops/__init__.py | 2 +- torchvision/ops/dice_loss.py | 14 ++++---------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index d4d62048452..bac28a6a4b8 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -14,6 +14,7 @@ ) from .ciou_loss import complete_box_iou_loss from .deform_conv import deform_conv2d, DeformConv2d +from .dice_loss import dice_loss from .diou_loss import distance_box_iou_loss from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d from .feature_pyramid_network import FeaturePyramidNetwork @@ -26,7 +27,6 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth -from .dice_loss import dice_loss _register_custom_op() diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py index a8d18677371..657502969f8 100644 --- a/torchvision/ops/dice_loss.py +++ b/torchvision/ops/dice_loss.py @@ -1,6 +1,8 @@ import torch import torch.nn.functional as F +from ..utils import _log_api_usage_once + def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-8) -> torch.Tensor: r"""Criterion that computes Sørensen-Dice Coefficient loss. @@ -36,17 +38,9 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non Return: Tensor: Loss tensor with the reduction option applied. """ - if not isinstance(inputs, torch.Tensor): - raise TypeError(f"Input type is not a torch.Tensor. Got {type(inputs)}") - - if not len(inputs.shape) == 4: - raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {inputs.shape}") - - if not inputs.shape[-2:] == targets.shape[-2:]: - raise ValueError(f"input and target shapes must be the same. Got: {inputs.shape} and {targets.shape}") - if not inputs.device == targets.device: - raise ValueError(f"input and target must be in the same device. Got: {inputs.device} and {targets.device}") + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(dice_loss) # compute softmax over the classes axis p = F.softmax(inputs, dim=1) From 8bf7638995e289d44149b16a4bdc4f378ad9355e Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Fri, 16 Dec 2022 15:48:53 +0530 Subject: [PATCH 4/9] change implementation for dice_loss --- torchvision/ops/dice_loss.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py index 657502969f8..2890f40005e 100644 --- a/torchvision/ops/dice_loss.py +++ b/torchvision/ops/dice_loss.py @@ -4,18 +4,18 @@ from ..utils import _log_api_usage_once -def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-8) -> torch.Tensor: +def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-7) -> torch.Tensor: r"""Criterion that computes Sørensen-Dice Coefficient loss. We compute the Sørensen-Dice Coefficient as follows: .. math:: - \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|} + \text{Dice\_Loss}(X, Y) = 1 - \frac{2 |X \cap Y|}{|X| + |Y|} Where: - :math:`X` expects to be the scores of each class. - - :math:`Y` expects to be thess tensor with the class labels. + - :math:`Y` expects to be the one hot tensor with the class labels. the loss, is finally computed as: @@ -24,12 +24,13 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non \text{loss}(x, class) = 1 - \text{Dice}(x, class) Args: - inputs: (Tensor): A float tensor of arbitrary shape. + inputs: (Tensor): A float tensor with rank >= 2 and shape (B, N1, .... NK, C) + where B is the Batch Size and C is the number of classes. The predictions for each example. - targets: (Tensor): A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs + targets: (Tensor): A float tensor with the same shape as inputs. Stores the one-hot + labes for each element in inputs. (0 for the negative class and 1 for the positive class). - eps: (float, optional): Scalar to enforce numerical stabiliy. + eps: (float, optional): Scalar to enforce numerical stability. reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'`` ``'none'``: No reduction will be applied to the output. ``'mean'``: The output will be averaged. @@ -43,12 +44,13 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non _log_api_usage_once(dice_loss) # compute softmax over the classes axis - p = F.softmax(inputs, dim=1) + p = F.softmax(inputs, dim=-1) + p = p.flatten(start_dim=1) - # compute the actual dice score - dims = (1, 2, 3) - intersection = torch.sum(p * targets, dims) - cardinality = torch.sum(p + targets, dims) + targets = targets.flatten(start_dim=1) + + intersection = torch.sum(p * targets, dim=-1) + cardinality = torch.sum(p + targets, dim=-1) dice_score = 2.0 * intersection / (cardinality + eps) From fb8cefdc838c6c7f64603f133900d80670a52edb Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Fri, 16 Dec 2022 16:15:32 +0530 Subject: [PATCH 5/9] change documentation --- torchvision/ops/dice_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py index 2890f40005e..bde6df82349 100644 --- a/torchvision/ops/dice_loss.py +++ b/torchvision/ops/dice_loss.py @@ -27,9 +27,9 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non inputs: (Tensor): A float tensor with rank >= 2 and shape (B, N1, .... NK, C) where B is the Batch Size and C is the number of classes. The predictions for each example. - targets: (Tensor): A float tensor with the same shape as inputs. Stores the one-hot - labes for each element in inputs. - (0 for the negative class and 1 for the positive class). + targets: (Tensor): A one-hot tensor with the same shape as inputs. + The first dimension is the batch size and last dimension is the + number of classes. eps: (float, optional): Scalar to enforce numerical stability. reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'`` ``'none'``: No reduction will be applied to the output. From fcabd6aaca867565108457fded5e08439aedb23a Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Thu, 22 Dec 2022 10:17:25 +0530 Subject: [PATCH 6/9] add basic tests --- test/test_ops.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 99b58bb93a7..65672f41b4c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1770,5 +1770,53 @@ def test_is_leaf_node(self, dim, p, block_size, inplace): assert len(graph_node_names[0]) == 1 + op_obj.n_inputs +class TestDiceLoss: + def get_reduction_method(self, reduction): + return { + "sum": torch.sum, + "mean": torch.mean, + "none": None + }[reduction] + + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_dice_loss_one(self, device): + shape = (16, 4, 4, 2) + input_ones = torch.ones(shape, device=device) + label_zeros = torch.zeros(shape, device=device) + expected = torch.ones(16, device=device) + torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros), expected) + + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_dice_loss_all_zeros(self, device): + shape = (16, 4, 4, 2) + input_zeros = torch.zeros(shape, device=device) + input_zeros[:, :, :, 0] = 1.0 + input_zeros[:, :, :, 1] = 0.0 + label_zeros = torch.zeros(shape, device=device) + label_zeros.copy_(input_zeros) + input_zeros[:, :, :, 0] = 100.0 + expected = torch.zeros(16, device=device) + torch.testing.assert_close(ops.dice_loss(input_zeros, label_zeros), expected) + + @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_reduction_methods(self, reduction, device): + shape = (16, 4, 4, 2) + input_ones = torch.ones(shape, device=device) + label_zeros = torch.zeros(shape, device=device) + expected = torch.ones(16, device=device) + reduction_fn = self.get_reduction_method(reduction) + if reduction_fn is not None: + expected = reduction_fn(expected) + torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros, reduction=reduction), expected) + + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_gradcheck(self, device): + shape = (16, 4, 4, 2) + input_ones = torch.ones(shape, device=device, requires_grad=True) + label_zeros = torch.zeros(shape, device=device, requires_grad=True) + assert gradcheck(ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True) + + if __name__ == "__main__": pytest.main([__file__]) From d7cafec40b2bbb40af85a753bc1dcf11ce87423b Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Mon, 2 Jan 2023 17:25:19 +0530 Subject: [PATCH 7/9] add addtional tests --- test/test_ops.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 65672f41b4c..950196a9fdb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1779,12 +1779,29 @@ def get_reduction_method(self, reduction): }[reduction] @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_dice_loss_one(self, device): - shape = (16, 4, 4, 2) + def test_dice_loss(self, device): + input = torch.tensor([[[0.9409, 0.9220], [0.9524, 0.1094]], + [[0.6802, 0.7949], [0.9570, 0.1499]], + [[0.3298, 0.4401], [0.1094, 0.7536]], + [[0.3340, 0.9895], [0.9563, 0.5045]]], device=device) + labels = torch.tensor([[[0, 1], [1, 0]], + [[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[1, 0], [0, 1]]], device=device) + expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device) + torch.testing.assert_allclose(ops.dice_loss(input, labels, eps=0), expected) + + @pytest.mark.parametrize("shape", ((16, 4, 4, 2), (32, 2), (32, 4, 4, 4, 2))) + @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_dice_loss_one(self, shape, reduction, device): input_ones = torch.ones(shape, device=device) label_zeros = torch.zeros(shape, device=device) - expected = torch.ones(16, device=device) - torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros), expected) + expected = torch.ones(shape[0], device=device) + reduction_fn = self.get_reduction_method(reduction) + if reduction_fn is not None: + expected = reduction_fn(expected) + torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros, reduction=reduction), expected) @pytest.mark.parametrize("device", cpu_and_gpu()) def test_dice_loss_all_zeros(self, device): @@ -1798,18 +1815,6 @@ def test_dice_loss_all_zeros(self, device): expected = torch.zeros(16, device=device) torch.testing.assert_close(ops.dice_loss(input_zeros, label_zeros), expected) - @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_reduction_methods(self, reduction, device): - shape = (16, 4, 4, 2) - input_ones = torch.ones(shape, device=device) - label_zeros = torch.zeros(shape, device=device) - expected = torch.ones(16, device=device) - reduction_fn = self.get_reduction_method(reduction) - if reduction_fn is not None: - expected = reduction_fn(expected) - torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros, reduction=reduction), expected) - @pytest.mark.parametrize("device", cpu_and_gpu()) def test_gradcheck(self, device): shape = (16, 4, 4, 2) From fc39ee4d58eaac5ebb97fdf80aec76bb8bba2f4f Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Wed, 1 Feb 2023 12:34:58 +0530 Subject: [PATCH 8/9] change input dimension from (B, H, W, C) to (B, C, H, W) --- test/test_ops.py | 26 +++++++++++++++----------- torchvision/ops/dice_loss.py | 10 +++++----- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1cb147e9c5e..10fd752d620 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1775,18 +1775,22 @@ def get_reduction_method(self, reduction): @pytest.mark.parametrize("device", cpu_and_gpu()) def test_dice_loss(self, device): - input = torch.tensor([[[0.9409, 0.9220], [0.9524, 0.1094]], - [[0.6802, 0.7949], [0.9570, 0.1499]], - [[0.3298, 0.4401], [0.1094, 0.7536]], - [[0.3340, 0.9895], [0.9563, 0.5045]]], device=device) + input = torch.tensor([[[0.9409, 0.9524], + [0.9220, 0.1094]], + [[0.6802, 0.9570], + [0.7949, 0.1499]], + [[0.3298, 0.1094], + [0.4401, 0.7536]], + [[0.3340, 0.9563], + [0.9895, 0.5045]]], device=device) labels = torch.tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]], - [[1, 0], [1, 0]], + [[1, 1], [0, 0]], [[1, 0], [0, 1]]], device=device) expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device) torch.testing.assert_allclose(ops.dice_loss(input, labels, eps=0), expected) - @pytest.mark.parametrize("shape", ((16, 4, 4, 2), (32, 2), (32, 4, 4, 4, 2))) + @pytest.mark.parametrize("shape", ((16, 2, 4, 4), (16, 4, 4, 4), (32, 2), (32, 2, 4, 4, 4))) @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) @pytest.mark.parametrize("device", cpu_and_gpu()) def test_dice_loss_one(self, shape, reduction, device): @@ -1800,19 +1804,19 @@ def test_dice_loss_one(self, shape, reduction, device): @pytest.mark.parametrize("device", cpu_and_gpu()) def test_dice_loss_all_zeros(self, device): - shape = (16, 4, 4, 2) + shape = (16, 2, 4, 4) input_zeros = torch.zeros(shape, device=device) - input_zeros[:, :, :, 0] = 1.0 - input_zeros[:, :, :, 1] = 0.0 + input_zeros[:, 0, :, :] = 1.0 + input_zeros[:, 1, :, :] = 0.0 label_zeros = torch.zeros(shape, device=device) label_zeros.copy_(input_zeros) - input_zeros[:, :, :, 0] = 100.0 + input_zeros[:, 0, :, :] = 100.0 expected = torch.zeros(16, device=device) torch.testing.assert_close(ops.dice_loss(input_zeros, label_zeros), expected) @pytest.mark.parametrize("device", cpu_and_gpu()) def test_gradcheck(self, device): - shape = (16, 4, 4, 2) + shape = (16, 2, 4, 4) input_ones = torch.ones(shape, device=device, requires_grad=True) label_zeros = torch.zeros(shape, device=device, requires_grad=True) assert gradcheck(ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True) diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py index bde6df82349..e27b8aa6ca9 100644 --- a/torchvision/ops/dice_loss.py +++ b/torchvision/ops/dice_loss.py @@ -24,11 +24,11 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non \text{loss}(x, class) = 1 - \text{Dice}(x, class) Args: - inputs: (Tensor): A float tensor with rank >= 2 and shape (B, N1, .... NK, C) + inputs: (Tensor): A float tensor with rank >= 2 and shape (B, C, N1, .... NK) where B is the Batch Size and C is the number of classes. The predictions for each example. targets: (Tensor): A one-hot tensor with the same shape as inputs. - The first dimension is the batch size and last dimension is the + The first dimension is the batch size and the second dimension is the number of classes. eps: (float, optional): Scalar to enforce numerical stability. reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'`` @@ -44,13 +44,13 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non _log_api_usage_once(dice_loss) # compute softmax over the classes axis - p = F.softmax(inputs, dim=-1) + p = F.softmax(inputs, dim=1) p = p.flatten(start_dim=1) targets = targets.flatten(start_dim=1) - intersection = torch.sum(p * targets, dim=-1) - cardinality = torch.sum(p + targets, dim=-1) + intersection = torch.sum(p * targets, dim=1) + cardinality = torch.sum(p + targets, dim=1) dice_score = 2.0 * intersection / (cardinality + eps) From dbb8aaa17487455ee7a33afd42140d099a7f77f8 Mon Sep 17 00:00:00 2001 From: Priya Nagda Date: Sat, 18 Feb 2023 03:57:41 +0530 Subject: [PATCH 9/9] add ref + fix linting --- test/test_ops.py | 34 +++++++++++++++------------------- torchvision/ops/dice_loss.py | 5 +++-- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 10fd752d620..71aa41de2b0 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1767,28 +1767,22 @@ def test_is_leaf_node(self, dim, p, block_size, inplace): class TestDiceLoss: def get_reduction_method(self, reduction): - return { - "sum": torch.sum, - "mean": torch.mean, - "none": None - }[reduction] + return {"sum": torch.sum, "mean": torch.mean, "none": None}[reduction] @pytest.mark.parametrize("device", cpu_and_gpu()) def test_dice_loss(self, device): - input = torch.tensor([[[0.9409, 0.9524], - [0.9220, 0.1094]], - [[0.6802, 0.9570], - [0.7949, 0.1499]], - [[0.3298, 0.1094], - [0.4401, 0.7536]], - [[0.3340, 0.9563], - [0.9895, 0.5045]]], device=device) - labels = torch.tensor([[[0, 1], [1, 0]], - [[1, 0], [0, 1]], - [[1, 1], [0, 0]], - [[1, 0], [0, 1]]], device=device) + input_tensor = torch.tensor( + [ + [[0.9409, 0.9524], [0.9220, 0.1094]], + [[0.6802, 0.9570], [0.7949, 0.1499]], + [[0.3298, 0.1094], [0.4401, 0.7536]], + [[0.3340, 0.9563], [0.9895, 0.5045]], + ], + device=device, + ) + labels = torch.tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]], [[1, 1], [0, 0]], [[1, 0], [0, 1]]], device=device) expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device) - torch.testing.assert_allclose(ops.dice_loss(input, labels, eps=0), expected) + torch.testing.assert_allclose(ops.dice_loss(input_tensor, labels, eps=0), expected) @pytest.mark.parametrize("shape", ((16, 2, 4, 4), (16, 4, 4, 4), (32, 2), (32, 2, 4, 4, 4))) @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) @@ -1819,7 +1813,9 @@ def test_gradcheck(self, device): shape = (16, 2, 4, 4) input_ones = torch.ones(shape, device=device, requires_grad=True) label_zeros = torch.zeros(shape, device=device, requires_grad=True) - assert gradcheck(ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True) + assert gradcheck( + ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True + ) if __name__ == "__main__": diff --git a/torchvision/ops/dice_loss.py b/torchvision/ops/dice_loss.py index e27b8aa6ca9..fcc3d8156ea 100644 --- a/torchvision/ops/dice_loss.py +++ b/torchvision/ops/dice_loss.py @@ -4,6 +4,7 @@ from ..utils import _log_api_usage_once +# Implementation adapted from https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/dice.html#dice_loss def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-7) -> torch.Tensor: r"""Criterion that computes Sørensen-Dice Coefficient loss. @@ -24,8 +25,8 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non \text{loss}(x, class) = 1 - \text{Dice}(x, class) Args: - inputs: (Tensor): A float tensor with rank >= 2 and shape (B, C, N1, .... NK) - where B is the Batch Size and C is the number of classes. + inputs: (Tensor): A float tensor with rank >= 2 and shape (B, num_classes, N1, .... NK) + where B is the Batch Size and num_classes is the number of classes. The predictions for each example. targets: (Tensor): A one-hot tensor with the same shape as inputs. The first dimension is the batch size and the second dimension is the