Skip to content

Commit 19571bc

Browse files
committed
test: update bench script
1 parent f5693f7 commit 19571bc

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

benchmarks/python/masked_scatter.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import os
3+
import platform
34
import subprocess
45
import time
56
from copy import copy
@@ -17,16 +18,38 @@
1718
if not os.path.isdir(RESULTS_DIR):
1819
os.mkdir(RESULTS_DIR)
1920

20-
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
21-
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
22-
2321
TORCH_DEVICE = torch.device(
2422
"mps"
2523
if torch.backends.mps.is_available()
2624
else ("cuda" if torch.cuda.is_available() else "cpu")
2725
)
2826

2927

28+
def get_device_name():
29+
if TORCH_DEVICE.type == "cuda":
30+
try:
31+
out = subprocess.check_output(
32+
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
33+
stderr=subprocess.DEVNULL,
34+
)
35+
return out.decode("utf-8").splitlines()[0].strip()
36+
except Exception:
37+
return "CUDA_GPU"
38+
if TORCH_DEVICE.type == "mps":
39+
try:
40+
out = subprocess.check_output(
41+
["sysctl", "-n", "machdep.cpu.brand_string"],
42+
stderr=subprocess.DEVNULL,
43+
)
44+
return out.decode("utf-8").strip()
45+
except Exception:
46+
return "Apple_Silicon"
47+
return platform.processor() or platform.machine() or "CPU"
48+
49+
50+
DEVICE_NAME = get_device_name()
51+
52+
3053
N_WARMUP = 5
3154
N_ITER_BENCH = 50
3255
N_ITER_FUNC = 20

0 commit comments

Comments
 (0)