
Commit 5ee179c

TroyGarden authored and facebook-github-bot committed
fix flaky test due to input_kjt.weight dtype (#2763)
Summary:
Pull Request resolved: #2763

# context
* The [test_model_parallel_nccl](https://fb.workplace.com/groups/970281557043698/posts/1863456557726189/?comment_id=1867254224013089) test has been reported to be flaky: [paste](https://www.internalfb.com/intern/everpaste/?color=0&handle=GJBrgxaEWkfR-ycEAP_fNV5sl_l1br0LAAAz)
* After an in-depth investigation, the root cause is that the dtype of the generated input KJT._weights is always `torch.float` (i.e., `torch.float32`), while the test embedding table's dtype can be `DataType.FP16` (i.e., `torch.float16`).

# changes
* Convert the input_kjt._weights dtype to be consistent with `EmbeddingBag.weight.dtype` in the (unsharded) EBC.

Reviewed By: dstaay-fb

Differential Revision: D70126859

fbshipit-source-id: 52fc46ced5a3119f168dc4e41eff949ed4a9ec66
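To make the root cause concrete: PyTorch's `nn.EmbeddingBag` requires `per_sample_weights` to match its weight dtype, so an FP16 table combined with the always-FP32 generated weights fails at forward time. A minimal standalone sketch of the failure mode (an illustration, not code from this commit; half-precision `EmbeddingBag` on CPU may additionally require a recent PyTorch build):

    import torch

    # hypothetical repro mirroring the flaky test: FP16 table, FP32 per-sample weights
    bag = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum").half()
    values = torch.tensor([1, 2, 3])
    offsets = torch.tensor([0])
    kjt_weights = torch.rand(3)  # always torch.float32, like the generated KJT._weights

    try:
        bag(values, offsets, per_sample_weights=kjt_weights)  # Half weight vs. Float weights
    except RuntimeError as err:
        print(f"dtype mismatch: {err}")

    # the fix: cast the per-sample weights to the bag's weight dtype
    out = bag(values, offsets, per_sample_weights=kjt_weights.to(bag.weight.dtype))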
1 parent ac739f4 commit 5ee179c

3 files changed: +24 −7 lines changed

torchrec/distributed/test_utils/test_model.py

Lines changed: 1 addition & 2 deletions
@@ -223,8 +223,7 @@ def _validate_pooling_factor(
         else:
             raise ValueError(f"For IdList features, unknown input type {input_type}")
 
-        for idx in range(len(idscore_ind_ranges)):
-            ind_range = idscore_ind_ranges[idx]
+        for idx, ind_range in enumerate(idscore_ind_ranges):
             lengths_ = torch.abs(
                 torch.randn(batch_size * world_size, device=device)
                 + (
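The hunk above is a behavior-preserving refactor; a tiny sketch of the equivalence (placeholder values, not from the test):

    idscore_ind_ranges = [(0, 10), (0, 5)]  # hypothetical placeholder values

    # before: index first, then an explicit lookup
    for idx in range(len(idscore_ind_ranges)):
        ind_range = idscore_ind_ranges[idx]
        print(idx, ind_range)

    # after: enumerate yields the index and the element in one step
    for idx, ind_range in enumerate(idscore_ind_ranges):
        print(idx, ind_range)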

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 18 additions & 4 deletions
@@ -59,7 +59,11 @@
     ShardingPlan,
     ShardingType,
 )
-from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig
+from torchrec.modules.embedding_configs import (
+    BaseEmbeddingConfig,
+    DataType,
+    EmbeddingBagConfig,
+)
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
 from torchrec.optim.optimizers import in_backward_optimizer_filter

@@ -520,9 +524,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
     )
 
     # Compare predictions of sharded vs unsharded models.
-    if qcomms_config is None:
-        torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
-    else:
+    if qcomms_config is not None:
         # With quantized comms, we can relax constraints a bit
         rtol = 0.003
         if CommType.FP8 in [
@@ -534,6 +536,18 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
         torch.testing.assert_close(
             global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol
         )
+    elif (
+        weighted_tables is not None
+        and weighted_tables[0].data_type == DataType.FP16
+    ):  # https://www.internalfb.com/intern/diffing/?paste_number=1740410921
+        torch.testing.assert_close(
+            global_pred,
+            torch.cat(all_local_pred),
+            atol=1e-4,  # relaxed atol due to FP16 in weights
+            rtol=1e-4,  # relaxed rtol due to FP16 in weights
+        )
+    else:
+        torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
 
 
 def create_device_mesh_for_2D(
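For intuition on the relaxed 1e-4 tolerances: merely rounding values through FP16 introduces per-element errors far above `torch.testing.assert_close`'s float32 defaults (rtol=1.3e-6, atol=1e-5). A quick illustration (not part of the commit):

    import torch

    # FP16 keeps a 10-bit mantissa, so values in [0, 1) lose up to ~2e-4
    # when round-tripped -- well above the float32 default tolerances
    w = torch.rand(100_000)
    print((w - w.half().float()).abs().max())  # on the order of 1e-4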

torchrec/modules/embedding_modules.py

Lines changed: 5 additions & 1 deletion
@@ -248,7 +248,11 @@ def forward(
             res = embedding_bag(
                 input=f.values(),
                 offsets=f.offsets(),
-                per_sample_weights=f.weights() if self._is_weighted else None,
+                per_sample_weights=(
+                    f.weights().to(embedding_bag.weight.dtype)
+                    if self._is_weighted
+                    else None
+                ),
             ).float()
             pooled_embeddings.append(res)
         return KeyedTensor(
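With the cast in place, a weighted (unsharded) EBC whose table is configured as `DataType.FP16` accepts the usual float32 KJT weights. A hedged end-to-end sketch (table and feature names are made up; torchrec APIs assumed from the imports in this commit):

    import torch
    from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1",  # hypothetical table name
                embedding_dim=8,
                num_embeddings=100,
                feature_names=["f1"],  # hypothetical feature name
                data_type=DataType.FP16,  # table weight stored as torch.float16
            )
        ],
        is_weighted=True,
    )

    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([1, 2, 3]),
        lengths=torch.tensor([2, 1]),
        weights=torch.rand(3),  # torch.float32; forward now casts it to FP16
    )
    print(ebc(kjt)["f1"])  # no dtype mismatch after this fix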
