Skip to content

Add Dice Loss #6960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from .ciou_loss import complete_box_iou_loss
from .deform_conv import deform_conv2d, DeformConv2d
from .dice_loss import dice_loss
from .diou_loss import distance_box_iou_loss
from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork
Expand Down
69 changes: 69 additions & 0 deletions torchvision/ops/dice_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
import torch.nn.functional as F

from ..utils import _log_api_usage_once


def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-8) -> torch.Tensor:
r"""Criterion that computes Sørensen-Dice Coefficient loss.

We compute the Sørensen-Dice Coefficient as follows:

.. math::

\text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

Where:
- :math:`X` expects to be the scores of each class.
- :math:`Y` expects to be thess tensor with the class labels.

the loss, is finally computed as:

.. math::

\text{loss}(x, class) = 1 - \text{Dice}(x, class)

Args:
inputs: (Tensor): A float tensor of arbitrary shape.
The predictions for each example.
targets: (Tensor): A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
eps: (float, optional): Scalar to enforce numerical stabiliy.
reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``
``'none'``: No reduction will be applied to the output.
``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``.

Return:
Tensor: Loss tensor with the reduction option applied.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(dice_loss)

# compute softmax over the classes axis
p = F.softmax(inputs, dim=1)

# compute the actual dice score
dims = (1, 2, 3)
intersection = torch.sum(p * targets, dims)
cardinality = torch.sum(p + targets, dims)

dice_score = 2.0 * intersection / (cardinality + eps)

loss = 1.0 - dice_score

# Check reduction option and return loss accordingly
if reduction == "none":
pass
elif reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)

return loss