@@ -1767,28 +1767,22 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
1767
1767
1768
1768
class TestDiceLoss :
1769
1769
def get_reduction_method (self , reduction ):
1770
- return {
1771
- "sum" : torch .sum ,
1772
- "mean" : torch .mean ,
1773
- "none" : None
1774
- }[reduction ]
1770
+ return {"sum" : torch .sum , "mean" : torch .mean , "none" : None }[reduction ]
1775
1771
1776
1772
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
1777
1773
def test_dice_loss (self , device ):
1778
- input = torch .tensor ([[[0.9409 , 0.9524 ],
1779
- [0.9220 , 0.1094 ]],
1780
- [[0.6802 , 0.9570 ],
1781
- [0.7949 , 0.1499 ]],
1782
- [[0.3298 , 0.1094 ],
1783
- [0.4401 , 0.7536 ]],
1784
- [[0.3340 , 0.9563 ],
1785
- [0.9895 , 0.5045 ]]], device = device )
1786
- labels = torch .tensor ([[[0 , 1 ], [1 , 0 ]],
1787
- [[1 , 0 ], [0 , 1 ]],
1788
- [[1 , 1 ], [0 , 0 ]],
1789
- [[1 , 0 ], [0 , 1 ]]], device = device )
1774
+ input_tensor = torch .tensor (
1775
+ [
1776
+ [[0.9409 , 0.9524 ], [0.9220 , 0.1094 ]],
1777
+ [[0.6802 , 0.9570 ], [0.7949 , 0.1499 ]],
1778
+ [[0.3298 , 0.1094 ], [0.4401 , 0.7536 ]],
1779
+ [[0.3340 , 0.9563 ], [0.9895 , 0.5045 ]],
1780
+ ],
1781
+ device = device ,
1782
+ )
1783
+ labels = torch .tensor ([[[0 , 1 ], [1 , 0 ]], [[1 , 0 ], [0 , 1 ]], [[1 , 1 ], [0 , 0 ]], [[1 , 0 ], [0 , 1 ]]], device = device )
1790
1784
expected = torch .tensor ([0.4028 , 0.6101 , 0.5916 , 0.6347 ], device = device )
1791
- torch .testing .assert_allclose (ops .dice_loss (input , labels , eps = 0 ), expected )
1785
+ torch .testing .assert_allclose (ops .dice_loss (input_tensor , labels , eps = 0 ), expected )
1792
1786
1793
1787
@pytest .mark .parametrize ("shape" , ((16 , 2 , 4 , 4 ), (16 , 4 , 4 , 4 ), (32 , 2 ), (32 , 2 , 4 , 4 , 4 )))
1794
1788
@pytest .mark .parametrize ("reduction" , ["none" , "mean" , "sum" ])
@@ -1819,7 +1813,9 @@ def test_gradcheck(self, device):
1819
1813
shape = (16 , 2 , 4 , 4 )
1820
1814
input_ones = torch .ones (shape , device = device , requires_grad = True )
1821
1815
label_zeros = torch .zeros (shape , device = device , requires_grad = True )
1822
- assert gradcheck (ops .dice_loss , (input_ones , label_zeros ), eps = 1e-2 , atol = 1e-2 , raise_exception = True , fast_mode = True )
1816
+ assert gradcheck (
1817
+ ops .dice_loss , (input_ones , label_zeros ), eps = 1e-2 , atol = 1e-2 , raise_exception = True , fast_mode = True
1818
+ )
1823
1819
1824
1820
1825
1821
if __name__ == "__main__" :
0 commit comments