Skip to content

Commit 8f077ab

Browse files
authored
[Fix] Skip autocast_simple_backward_bf16 on platform without avx2_vnni_2 support (#5571) (#5578)
* [Fix] Skip autocast_simple_backward_bf16 on Arrow lake and Lunar lake * fix format
1 parent f6cb15e commit 8f077ab

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

tests/gpu/examples/test_autocast.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
import platform
32
import torch
43
from torch.testing._internal.common_utils import TestCase
54

@@ -9,13 +8,6 @@
98
dpcpp_device = torch.device("xpu")
109
checking_atol = 1e-2
1110
checking_rtol = 3e-2
12-
cpu_name = platform.processor()
13-
skipIfNoAvx2Vnni2 = unittest.skipIf(
14-
# Arrow lake and Lunar lake
15-
cpu_name == "Intel64 Family 6 Model 198 Stepping 2, GenuineIntel"
16-
or cpu_name == "Intel64 Family 6 Model 189 Stepping 1, GenuineIntel",
17-
"Skip on Arrow/Lunar lake because they do not support avx2_vnni_2",
18-
)
1911

2012

2113
class TestNet(torch.nn.Module):
@@ -66,7 +58,10 @@ def test_autocast_simple_forward_fp16(self):
6658
print(y_xpu.to("cpu"))
6759
self.assertEqual(y_xpu.dtype, torch.float16)
6860

69-
@skipIfNoAvx2Vnni2
61+
@unittest.skipIf(
62+
not torch.cpu._is_vnni_supported(),
63+
"Test requires CPU with avx2_vnni_2",
64+
)
7065
def test_autocast_simple_backward_bf16(self):
7166
model = TestNet()
7267
x = torch.ones([2, 3, 8, 6], dtype=torch.float)

0 commit comments

Comments
 (0)