Skip to content

Commit fc39ee4

Browse files
committed
change input dimension from (B, H, W, C) to (B, C, H, W)
1 parent f087fc6 commit fc39ee4

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

test/test_ops.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,18 +1775,22 @@ def get_reduction_method(self, reduction):
17751775

17761776
@pytest.mark.parametrize("device", cpu_and_gpu())
17771777
def test_dice_loss(self, device):
1778-
input = torch.tensor([[[0.9409, 0.9220], [0.9524, 0.1094]],
1779-
[[0.6802, 0.7949], [0.9570, 0.1499]],
1780-
[[0.3298, 0.4401], [0.1094, 0.7536]],
1781-
[[0.3340, 0.9895], [0.9563, 0.5045]]], device=device)
1778+
input = torch.tensor([[[0.9409, 0.9524],
1779+
[0.9220, 0.1094]],
1780+
[[0.6802, 0.9570],
1781+
[0.7949, 0.1499]],
1782+
[[0.3298, 0.1094],
1783+
[0.4401, 0.7536]],
1784+
[[0.3340, 0.9563],
1785+
[0.9895, 0.5045]]], device=device)
17821786
labels = torch.tensor([[[0, 1], [1, 0]],
17831787
[[1, 0], [0, 1]],
1784-
[[1, 0], [1, 0]],
1788+
[[1, 1], [0, 0]],
17851789
[[1, 0], [0, 1]]], device=device)
17861790
expected = torch.tensor([0.4028, 0.6101, 0.5916, 0.6347], device=device)
17871791
torch.testing.assert_allclose(ops.dice_loss(input, labels, eps=0), expected)
17881792

1789-
@pytest.mark.parametrize("shape", ((16, 4, 4, 2), (32, 2), (32, 4, 4, 4, 2)))
1793+
@pytest.mark.parametrize("shape", ((16, 2, 4, 4), (16, 4, 4, 4), (32, 2), (32, 2, 4, 4, 4)))
17901794
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
17911795
@pytest.mark.parametrize("device", cpu_and_gpu())
17921796
def test_dice_loss_one(self, shape, reduction, device):
@@ -1800,19 +1804,19 @@ def test_dice_loss_one(self, shape, reduction, device):
18001804

18011805
@pytest.mark.parametrize("device", cpu_and_gpu())
18021806
def test_dice_loss_all_zeros(self, device):
1803-
shape = (16, 4, 4, 2)
1807+
shape = (16, 2, 4, 4)
18041808
input_zeros = torch.zeros(shape, device=device)
1805-
input_zeros[:, :, :, 0] = 1.0
1806-
input_zeros[:, :, :, 1] = 0.0
1809+
input_zeros[:, 0, :, :] = 1.0
1810+
input_zeros[:, 1, :, :] = 0.0
18071811
label_zeros = torch.zeros(shape, device=device)
18081812
label_zeros.copy_(input_zeros)
1809-
input_zeros[:, :, :, 0] = 100.0
1813+
input_zeros[:, 0, :, :] = 100.0
18101814
expected = torch.zeros(16, device=device)
18111815
torch.testing.assert_close(ops.dice_loss(input_zeros, label_zeros), expected)
18121816

18131817
@pytest.mark.parametrize("device", cpu_and_gpu())
18141818
def test_gradcheck(self, device):
1815-
shape = (16, 4, 4, 2)
1819+
shape = (16, 2, 4, 4)
18161820
input_ones = torch.ones(shape, device=device, requires_grad=True)
18171821
label_zeros = torch.zeros(shape, device=device, requires_grad=True)
18181822
assert gradcheck(ops.dice_loss, (input_ones, label_zeros), eps=1e-2, atol=1e-2, raise_exception=True, fast_mode=True)

torchvision/ops/dice_loss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non
2424
\text{loss}(x, class) = 1 - \text{Dice}(x, class)
2525
2626
Args:
27-
inputs: (Tensor): A float tensor with rank >= 2 and shape (B, N1, .... NK, C)
27+
inputs: (Tensor): A float tensor with rank >= 2 and shape (B, C, N1, .... NK)
2828
where B is the Batch Size and C is the number of classes.
2929
The predictions for each example.
3030
targets: (Tensor): A one-hot tensor with the same shape as inputs.
31-
The first dimension is the batch size and last dimension is the
31+
The first dimension is the batch size and the second dimension is the
3232
number of classes.
3333
eps: (float, optional): Scalar to enforce numerical stability.
3434
reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``
@@ -44,13 +44,13 @@ def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, reduction: str = "non
4444
_log_api_usage_once(dice_loss)
4545

4646
# compute softmax over the classes axis
47-
p = F.softmax(inputs, dim=-1)
47+
p = F.softmax(inputs, dim=1)
4848
p = p.flatten(start_dim=1)
4949

5050
targets = targets.flatten(start_dim=1)
5151

52-
intersection = torch.sum(p * targets, dim=-1)
53-
cardinality = torch.sum(p + targets, dim=-1)
52+
intersection = torch.sum(p * targets, dim=1)
53+
cardinality = torch.sum(p + targets, dim=1)
5454

5555
dice_score = 2.0 * intersection / (cardinality + eps)
5656

0 commit comments

Comments
 (0)