Skip to content

Commit dbb8aaa

Browse files
committed
add ref + fix linting
1 parent fc39ee4 commit dbb8aaa

File tree

2 files changed

+18
-21
lines changed

2 files changed

+18
-21
lines changed

test/test_ops.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,28 +1767,22 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
17671767

17681768
class TestDiceLoss:
17691769
def get_reduction_method(self, reduction):
1770-
return {
1771-
"sum": torch.sum,
1772-
"mean": torch.mean,
1773-
"none": None
1774-
}[reduction]
1770+
return {"sum": torch.sum, "mean": torch.mean, "none": None}[reduction]
17751771

17761772
@pytest.mark.parametrize("device", cpu_and_gpu())
17771773
def test_dice_loss(self, device):
1778-
input = torch.tensor([[[0.9409, 0.9524],
1779-
[0.9220, 0.1094]],
1780-
[[0.6802, 0.9570],
1781-
[0.7949, 0.1499]],
1782-
[[0.3298, 0.1094],
1783-
[0.4401, 0.7536]],
1784-
[[0.3340, 0.9563],
1785-
[0.9895, 0.5045]]], device=device)
1786-
labels = torch.tensor([[[0, 1], [1, 0]],
1787-
[[1, 0], [0, 1]],
1788-
[[1, 1], [0, 0]],
1789-
[[1, 0], [0, 1]]], device=device)
1774+
input_tensor = torch.tensor(
1775+
[
1776+
[[0.9409, 0.9524], [0.9220, 0.1094]],
1777+
[[0.6802, 0.9570], [0.7949, 0.1499]],
1778+
[[0.3298, 0.1094], [0.4401, 0.7536]],
1779+
[[0.3340, 0.9563], [0.9895, 0.5045]],
1780+
],
1781+
device=device,
1782+
)
1783+
labels = torch.tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]], [[1, 1], [0, 0]], [[1, 0], [0, 1]]], device=device)
17901784
expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device)
1791-
torch.testing.assert_allclose(ops.dice_loss(input, labels, eps=0), expected)
1785+
torch.testing.assert_allclose(ops.dice_loss(input_tensor, labels, eps=0), expected)
17921786

17931787
@pytest.mark.parametrize("shape", ((16, 2, 4, 4), (16, 4, 4, 4), (32, 2), (32, 2, 4, 4, 4)))
17941788
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
@@ -1819,7 +1813,9 @@ def test_gradcheck(self, device):
18191813
shape = (16, 2, 4, 4)
18201814
input_ones = torch.ones(shape, device=device, requires_grad=True)
18211815
label_zeros = torch.zeros(shape, device=device, requires_grad=True)
1822-
assert gradcheck(ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True)
1816+
assert gradcheck(
1817+
ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True
1818+
)
18231819

18241820

18251821
if __name__ == "__main__":

torchvision/ops/dice_loss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..utils import _log_api_usage_once
55

66

7+
# Implementation adapted from https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/dice.html#dice_loss
78
def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "none", eps: float = 1e-7) -> torch.Tensor:
89
r"""Criterion that computes Sørensen-Dice Coefficient loss.
910
@@ -24,8 +25,8 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non
2425
\text{loss}(x, class) = 1 - \text{Dice}(x, class)
2526
2627
Args:
27-
inputs: (Tensor): A float tensor with rank >= 2 and shape (B, C, N1, .... NK)
28-
where B is the Batch Size and C is the number of classes.
28+
inputs: (Tensor): A float tensor with rank >= 2 and shape (B, num_classes, N1, .... NK)
29+
where B is the Batch Size and num_classes is the number of classes.
2930
The predictions for each example.
3031
targets: (Tensor): A one-hot tensor with the same shape as inputs.
3132
The first dimension is the batch size and the second dimension is the

0 commit comments

Comments
 (0)