Skip to content

Commit f2d1620

Browse files
[ROCm][CI] Fix flaky GPTQ compile correctness test (vllm-project#38161)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
1 parent 37a8300 commit f2d1620

2 files changed

Lines changed: 41 additions & 34 deletions

File tree

tests/compile/fullgraph/test_basic_correctness.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def test_compile_correctness(
137137
all_args.append(
138138
final_args + [f"-cc.mode={mode.name}", "-cc.backend=inductor"]
139139
)
140+
all_envs.append({})
140141

141142
# inductor will change the output, so we only compare if the output
142143
# is close, not exactly the same.
@@ -157,6 +158,5 @@ def test_compile_correctness(
157158
]:
158159
all_args.append(final_args + [f"-cc.mode={mode.name}", "-cc.backend=eager"])
159160
all_envs.append({})
160-
all_envs.append({})
161161

162-
compare_all_settings(model, all_args * 3, all_envs, method=method)
162+
compare_all_settings(model, all_args, all_envs, method=method)

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,40 +1348,47 @@ def initialize_single_dummy_weight(
13481348
high: float = 1e-3,
13491349
seed: int = 1234,
13501350
) -> None:
1351-
if torch.is_floating_point(param):
1352-
if current_platform.is_tpu():
1353-
generator = torch.Generator(device="cpu")
1354-
generator.manual_seed(seed)
1355-
# Note: The param.uniform_ function cannot be used in this
1356-
# context because it demands more TPU HBM than directly copying
1357-
# from a CPU tensor.
1358-
# Note: We avoid using torch.rand_like as it doesn't currently
1359-
# support the generator argument.
1360-
param.copy_(
1361-
(high - low)
1362-
* torch.rand(
1363-
param.shape,
1364-
generator=generator,
1365-
dtype=param.dtype,
1366-
layout=param.layout,
1367-
requires_grad=param.requires_grad,
1368-
device="cpu",
1369-
)
1370-
+ low
1371-
)
1372-
torch._sync(param)
1373-
return
1351+
if not torch.is_floating_point(param):
1352+
if current_platform.is_rocm():
1353+
# On ROCm, integer params (e.g. GPTQ qweight/qzeros) are left
1354+
# as torch.empty() by default, giving non-deterministic values
1355+
# across processes. Zero them for reproducibility.
1356+
param.zero_()
1357+
return
13741358

1375-
generator = torch.Generator(device=param.data.device)
1359+
if current_platform.is_tpu():
1360+
generator = torch.Generator(device="cpu")
13761361
generator.manual_seed(seed)
1377-
if torch.finfo(param.data.dtype).bits < 16:
1378-
# uniform_ doesn't support < 16-bit datatypes (FP8)
1379-
dtype = param.data.dtype
1380-
tmp_param = param.data.to(torch.float16)
1381-
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
1382-
param.data.copy_(tmp_param)
1383-
else:
1384-
param.uniform_(low, high, generator=generator)
1362+
# Note: The param.uniform_ function cannot be used in this
1363+
# context because it demands more TPU HBM than directly copying
1364+
# from a CPU tensor.
1365+
# Note: We avoid using torch.rand_like as it doesn't currently
1366+
# support the generator argument.
1367+
param.copy_(
1368+
(high - low)
1369+
* torch.rand(
1370+
param.shape,
1371+
generator=generator,
1372+
dtype=param.dtype,
1373+
layout=param.layout,
1374+
requires_grad=param.requires_grad,
1375+
device="cpu",
1376+
)
1377+
+ low
1378+
)
1379+
torch._sync(param)
1380+
return
1381+
1382+
generator = torch.Generator(device=param.data.device)
1383+
generator.manual_seed(seed)
1384+
if torch.finfo(param.data.dtype).bits < 16:
1385+
# uniform_ doesn't support < 16-bit datatypes (FP8)
1386+
dtype = param.data.dtype
1387+
tmp_param = param.data.to(torch.float16)
1388+
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
1389+
param.data.copy_(tmp_param)
1390+
else:
1391+
param.uniform_(low, high, generator=generator)
13851392

13861393

13871394
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:

0 commit comments

Comments (0)