Add Dice Loss #6960

base: main
Changes from all commits: 6532e35, 248f50a, 7dca543, 636b7a7, 8bf7638, fb8cefd, fcabd6a, d7cafec, f087fc6, fc39ee4, dbb8aaa
test_ops.py

@@ -1765,5 +1765,58 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


class TestDiceLoss:
    def get_reduction_method(self, reduction):
        return {"sum": torch.sum, "mean": torch.mean, "none": None}[reduction]

    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_dice_loss(self, device):
        input_tensor = torch.tensor(
            [
                [[0.9409, 0.9524], [0.9220, 0.1094]],
                [[0.6802, 0.9570], [0.7949, 0.1499]],
                [[0.3298, 0.1094], [0.4401, 0.7536]],
                [[0.3340, 0.9563], [0.9895, 0.5045]],
            ],
            device=device,
        )
        labels = torch.tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]], [[1, 1], [0, 0]], [[1, 0], [0, 1]]], device=device)
        expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device)
        # The expected values are rounded to four decimals, so compare with a matching tolerance.
        torch.testing.assert_close(ops.dice_loss(input_tensor, labels, eps=0), expected, rtol=1e-3, atol=1e-4)

    @pytest.mark.parametrize("shape", ((16, 2, 4, 4), (16, 4, 4, 4), (32, 2), (32, 2, 4, 4, 4)))
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_dice_loss_one(self, shape, reduction, device):
        # A uniform prediction against an all-background target has zero overlap,
        # so the per-sample loss is exactly 1.
        input_ones = torch.ones(shape, device=device)
        label_zeros = torch.zeros(shape, device=device)
        expected = torch.ones(shape[0], device=device)
        reduction_fn = self.get_reduction_method(reduction)
        if reduction_fn is not None:
            expected = reduction_fn(expected)
        torch.testing.assert_close(ops.dice_loss(input_ones, label_zeros, reduction=reduction), expected)

    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_dice_loss_all_zeros(self, device):
        shape = (16, 2, 4, 4)
        # Build logits that favor class 0 everywhere and targets that one-hot encode class 0.
        input_zeros = torch.zeros(shape, device=device)
        input_zeros[:, 0, :, :] = 1.0
        input_zeros[:, 1, :, :] = 0.0
        label_zeros = torch.zeros(shape, device=device)
        label_zeros.copy_(input_zeros)
        # Saturate the class-0 logits so softmax returns ~1 for class 0; the predictions
        # then match the targets exactly and the loss vanishes.
        input_zeros[:, 0, :, :] = 100.0
        expected = torch.zeros(16, device=device)
        torch.testing.assert_close(ops.dice_loss(input_zeros, label_zeros), expected)
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if gradcheck is needed. Can you provide some reference of why you added this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to follow the test cases for Focal Loss, and I believe it checks for grad although not in a separate function. |
||
def test_gradcheck(self, device): | ||
shape = (16, 2, 4, 4) | ||
input_ones = torch.ones(shape, device=device, requires_grad=True) | ||
label_zeros = torch.zeros(shape, device=device, requires_grad=True) | ||
assert gradcheck( | ||
ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True | ||
) | ||

if __name__ == "__main__":
    pytest.main([__file__])
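
As a sanity check, the first expected value in test_dice_loss can be reproduced by hand. The standalone sketch below mirrors the flatten-then-reduce arithmetic of the implementation that follows; the constants are copied from the test above:

    import torch
    import torch.nn.functional as F

    # First sample of test_dice_loss, kept as a batch of one: (B=1, num_classes=2, N=2).
    x = torch.tensor([[[0.9409, 0.9524], [0.9220, 0.1094]]])
    y = torch.tensor([[[0.0, 1.0], [1.0, 0.0]]])

    p = F.softmax(x, dim=1).flatten(start_dim=1)  # per-position softmax over the class axis
    t = y.flatten(start_dim=1)

    intersection = (p * t).sum(dim=1)  # ~1.1944
    cardinality = (p + t).sum(dim=1)   # exactly 4.0: probabilities and one-hot targets each sum to 2
    print(1.0 - 2.0 * intersection / cardinality)  # tensor([0.4028])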
New file:

@@ -0,0 +1,72 @@
import torch
import torch.nn.functional as F

from ..utils import _log_api_usage_once


# Implementation adapted from https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/dice.html#dice_loss
def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-7) -> torch.Tensor:
r"""Criterion that computes Sørensen-Dice Coefficient loss. | ||
|
||
We compute the Sørensen-Dice Coefficient as follows: | ||
|
||
.. math:: | ||
|
||
\text{Dice\_Loss}(X, Y) = 1 - \frac{2 |X \cap Y|}{|X| + |Y|} | ||
|
||
Where: | ||
- :math:`X` expects to be the scores of each class. | ||
- :math:`Y` expects to be the one hot tensor with the class labels. | ||
|
||
the loss, is finally computed as: | ||
|
||
.. math:: | ||
|
||
\text{loss}(x, class) = 1 - \text{Dice}(x, class) | ||
|
||
Args: | ||
inputs: (Tensor): A float tensor with rank >= 2 and shape (B, num_classes, N1, .... NK) | ||
where B is the Batch Size and num_classes is the number of classes. | ||
The predictions for each example. | ||
targets: (Tensor): A one-hot tensor with the same shape as inputs. | ||
The first dimension is the batch size and the second dimension is the | ||
number of classes. | ||
eps: (float, optional): Scalar to enforce numerical stability. | ||
reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'`` | ||
``'none'``: No reduction will be applied to the output. | ||
``'mean'``: The output will be averaged. | ||
``'sum'``: The output will be summed. Default: ``'none'``. | ||
|
||
Return: | ||
Tensor: Loss tensor with the reduction option applied. | ||
""" | ||

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(dice_loss)

    # Compute softmax over the classes axis and flatten everything but the batch dimension.
    p = F.softmax(inputs, dim=1)
    p = p.flatten(start_dim=1)

    targets = targets.flatten(start_dim=1)

    intersection = torch.sum(p * targets, dim=1)
    cardinality = torch.sum(p + targets, dim=1)

    dice_score = 2.0 * intersection / (cardinality + eps)

    loss = 1.0 - dice_score

    # Check the reduction option and return the loss accordingly.
    if reduction == "none":
        pass
    elif reduction == "mean":
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        raise ValueError(
            f"Invalid value for arg 'reduction': '{reduction}'. Supported reduction modes: 'none', 'mean', 'sum'"
        )

    return loss
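
For illustration, a minimal usage sketch, assuming the op ends up exposed as torchvision.ops.dice_loss (as the tests above call it). F.one_hot followed by a permute produces targets in the expected (B, num_classes, ...) layout:

    import torch
    import torch.nn.functional as F
    from torchvision import ops

    logits = torch.randn(8, 3, 32, 32)         # (B, num_classes, H, W) raw, unnormalized scores
    labels = torch.randint(0, 3, (8, 32, 32))  # integer class label per pixel
    # One-hot encode, then move the class axis next to the batch axis: (B, H, W, 3) -> (B, 3, H, W).
    targets = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

    loss = ops.dice_loss(logits, targets, reduction="mean")  # scalar loss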
Review comment on the test file: "I suggest writing this in a new file; test_ops.py is already a huge file, >1.5k lines of code."

Follow-up from the same reviewer: "Ah, a quick check shows that the tests for the other losses have all been written to this file. Let this stay for now. Thoughts @pmeier @NicolasHug?"