@@ -1775,18 +1775,22 @@ def get_reduction_method(self, reduction):
1775
1775
1776
1776
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
1777
1777
def test_dice_loss (self , device ):
1778
- input = torch .tensor ([[[0.9409 , 0.9220 ], [0.9524 , 0.1094 ]],
1779
- [[0.6802 , 0.7949 ], [0.9570 , 0.1499 ]],
1780
- [[0.3298 , 0.4401 ], [0.1094 , 0.7536 ]],
1781
- [[0.3340 , 0.9895 ], [0.9563 , 0.5045 ]]], device = device )
1778
+ input = torch .tensor ([[[0.9409 , 0.9524 ],
1779
+ [0.9220 , 0.1094 ]],
1780
+ [[0.6802 , 0.9570 ],
1781
+ [0.7949 , 0.1499 ]],
1782
+ [[0.3298 , 0.1094 ],
1783
+ [0.4401 , 0.7536 ]],
1784
+ [[0.3340 , 0.9563 ],
1785
+ [0.9895 , 0.5045 ]]], device = device )
1782
1786
labels = torch .tensor ([[[0 , 1 ], [1 , 0 ]],
1783
1787
[[1 , 0 ], [0 , 1 ]],
1784
- [[1 , 0 ], [1 , 0 ]],
1788
+ [[1 , 1 ], [0 , 0 ]],
1785
1789
[[1 , 0 ], [0 , 1 ]]], device = device )
1786
1790
expected = torch .tensor ([0.4028 , 0.6101 , 0.5916 , 0.6347 ], device = device )
1787
1791
torch .testing .assert_allclose (ops .dice_loss (input , labels , eps = 0 ), expected )
1788
1792
1789
- @pytest .mark .parametrize ("shape" , ((16 , 4 , 4 , 2 ), (32 , 2 ), (32 , 4 , 4 , 4 , 2 )))
1793
+ @pytest .mark .parametrize ("shape" , ((16 , 2 , 4 , 4 ), ( 16 , 4 , 4 , 4 ), (32 , 2 ), (32 , 2 , 4 , 4 , 4 )))
1790
1794
@pytest .mark .parametrize ("reduction" , ["none" , "mean" , "sum" ])
1791
1795
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
1792
1796
def test_dice_loss_one (self , shape , reduction , device ):
@@ -1800,19 +1804,19 @@ def test_dice_loss_one(self, shape, reduction, device):
1800
1804
1801
1805
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
1802
1806
def test_dice_loss_all_zeros (self , device ):
1803
- shape = (16 , 4 , 4 , 2 )
1807
+ shape = (16 , 2 , 4 , 4 )
1804
1808
input_zeros = torch .zeros (shape , device = device )
1805
- input_zeros [:, : , :, 0 ] = 1.0
1806
- input_zeros [:, : , :, 1 ] = 0.0
1809
+ input_zeros [:, 0 , :, : ] = 1.0
1810
+ input_zeros [:, 1 , :, : ] = 0.0
1807
1811
label_zeros = torch .zeros (shape , device = device )
1808
1812
label_zeros .copy_ (input_zeros )
1809
- input_zeros [:, : , :, 0 ] = 100.0
1813
+ input_zeros [:, 0 , :, : ] = 100.0
1810
1814
expected = torch .zeros (16 , device = device )
1811
1815
torch .testing .assert_close (ops .dice_loss (input_zeros , label_zeros ), expected )
1812
1816
1813
1817
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
1814
1818
def test_gradcheck (self , device ):
1815
- shape = (16 , 4 , 4 , 2 )
1819
+ shape = (16 , 2 , 4 , 4 )
1816
1820
input_ones = torch .ones (shape , device = device , requires_grad = True )
1817
1821
label_zeros = torch .zeros (shape , device = device , requires_grad = True )
1818
1822
assert gradcheck (ops .dice_loss , (input_ones , label_zeros ), eps = 1e-2 , atol = 1e-2 , raise_exception = True , fast_mode = True )
0 commit comments