@@ -1348,40 +1348,47 @@ def initialize_single_dummy_weight(
     high: float = 1e-3,
     seed: int = 1234,
 ) -> None:
-    if torch.is_floating_point(param):
-        if current_platform.is_tpu():
-            generator = torch.Generator(device="cpu")
-            generator.manual_seed(seed)
-            # Note: The param.uniform_ function cannot be used in this
-            # context because it demands more TPU HBM than directly copying
-            # from a CPU tensor.
-            # Note: We avoid using torch.rand_like as it doesn't currently
-            # support the generator argument.
-            param.copy_(
-                (high - low)
-                * torch.rand(
-                    param.shape,
-                    generator=generator,
-                    dtype=param.dtype,
-                    layout=param.layout,
-                    requires_grad=param.requires_grad,
-                    device="cpu",
-                )
-                + low
-            )
-            torch._sync(param)
-            return
+    if not torch.is_floating_point(param):
+        if current_platform.is_rocm():
+            # On ROCm, integer params (e.g. GPTQ qweight/qzeros) are left
+            # as torch.empty() by default, giving non-deterministic values
+            # across processes. Zero them for reproducibility.
+            param.zero_()
+        return
 
-        generator = torch.Generator(device=param.data.device)
+    if current_platform.is_tpu():
+        generator = torch.Generator(device="cpu")
         generator.manual_seed(seed)
-        if torch.finfo(param.data.dtype).bits < 16:
-            # uniform_ doesn't support < 16-bit datatypes (FP8)
-            dtype = param.data.dtype
-            tmp_param = param.data.to(torch.float16)
-            tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
-            param.data.copy_(tmp_param)
-        else:
-            param.uniform_(low, high, generator=generator)
+        # Note: The param.uniform_ function cannot be used in this
+        # context because it demands more TPU HBM than directly copying
+        # from a CPU tensor.
+        # Note: We avoid using torch.rand_like as it doesn't currently
+        # support the generator argument.
+        param.copy_(
+            (high - low)
+            * torch.rand(
+                param.shape,
+                generator=generator,
+                dtype=param.dtype,
+                layout=param.layout,
+                requires_grad=param.requires_grad,
+                device="cpu",
+            )
+            + low
+        )
+        torch._sync(param)
+        return
+
+    generator = torch.Generator(device=param.data.device)
+    generator.manual_seed(seed)
+    if torch.finfo(param.data.dtype).bits < 16:
+        # uniform_ doesn't support < 16-bit datatypes (FP8)
+        dtype = param.data.dtype
+        tmp_param = param.data.to(torch.float16)
+        tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
+        param.data.copy_(tmp_param)
+    else:
+        param.uniform_(low, high, generator=generator)
 
 
 def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
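
The sub-16-bit branch above exists because Tensor.uniform_ does not support FP8 dtypes, so the code samples in float16 and casts back. A minimal standalone sketch of that workaround, assuming a PyTorch build that provides torch.float8_e4m3fn (the tensor shape and bounds here are illustrative, not taken from the diff):

import torch

# FP8 tensor standing in for a dummy weight (hypothetical shape).
param = torch.empty(4, 4, dtype=torch.float8_e4m3fn)
generator = torch.Generator(device=param.device)
generator.manual_seed(1234)
# uniform_ lacks FP8 support, so sample in float16, then cast back to FP8
# and copy the result into the original tensor in place.
tmp = param.to(torch.float16).uniform_(-1e-3, 1e-3, generator=generator)
param.copy_(tmp.to(param.dtype))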