Skip to content

Commit b0919ce

Browse files
Shuangping Liu, facebook-github-bot
Shuangping Liu
authored and committed
Fix empty sharding constraints in test_model_parallel.py (#2998)
Summary:
Pull Request resolved: #2998

#### Context
Several unit tests in `test_model_parallel.py` passed **empty constraints** into `self._test_sharding` because the constraints were generated from an empty `self.tables` before `self._build_tables_and_groups` was invoked. Impacted tests are:
* `test_sharding_twcw`
* `test_sharding_variable_batch`
* `test_sharding_multiple_kernels`

#### Changes
* Constraints only depend on table names. A new list `self.table_names` is created in the `setUp()` stage and is used to construct the constraints.
* Updates `self._build_tables_and_groups` to use the generated table names.
* Increases `max_examples` for `test_sharding_multiple_kernels` to cover both the FP32 and FP16 cases.

Reviewed By: TroyGarden
Differential Revision: D75306149
fbshipit-source-id: b93f7656e45a8c79393a1c347437f757aac07557
1 parent 450b887 commit b0919ce

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def setUp(self, backend: str = "nccl") -> None:
4040
self.num_weighted_features = 2
4141
self.num_shared_features = 2
4242

43+
self.table_names = [
44+
"table_" + str(i)
45+
for i in range(self.num_features + self.num_shared_features)
46+
]
4347
self.tables = []
4448
self.mean_tables = []
4549
self.weighted_tables = []
@@ -63,7 +67,7 @@ def _build_tables_and_groups(
6367
EmbeddingBagConfig(
6468
num_embeddings=(i + 1) * 10,
6569
embedding_dim=(i + 2) * 8,
66-
name="table_" + str(i),
70+
name=self.table_names[i],
6771
feature_names=["feature_" + str(i)],
6872
data_type=data_type,
6973
)
@@ -73,7 +77,7 @@ def _build_tables_and_groups(
7377
EmbeddingBagConfig(
7478
num_embeddings=(i + 1) * 10,
7579
embedding_dim=(i + 2) * 8,
76-
name="table_" + str(i + self.num_features),
80+
name=self.table_names[i + self.num_features],
7781
feature_names=["feature_" + str(i)],
7882
data_type=data_type,
7983
)
@@ -85,7 +89,7 @@ def _build_tables_and_groups(
8589
EmbeddingBagConfig(
8690
num_embeddings=(i + 1) * 10,
8791
embedding_dim=(i + 2) * 8,
88-
name="table_" + str(i),
92+
name=self.table_names[i],
8993
feature_names=["feature_" + str(i)],
9094
pooling=PoolingType.MEAN,
9195
data_type=data_type,
@@ -97,7 +101,7 @@ def _build_tables_and_groups(
97101
EmbeddingBagConfig(
98102
num_embeddings=(i + 1) * 10,
99103
embedding_dim=(i + 2) * 8,
100-
name="table_" + str(i + self.num_features),
104+
name=self.table_names[i + self.num_features],
101105
feature_names=["feature_" + str(i)],
102106
pooling=PoolingType.MEAN,
103107
data_type=data_type,
@@ -385,8 +389,8 @@ def test_sharding_cw(
385389
backend=self.backend,
386390
qcomms_config=qcomms_config,
387391
constraints={
388-
table.name: ParameterConstraints(min_partition=4)
389-
for table in self.tables
392+
table_name: ParameterConstraints(min_partition=4)
393+
for table_name in self.table_names
390394
},
391395
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
392396
variable_batch_size=variable_batch_size,
@@ -466,8 +470,8 @@ def test_sharding_twcw(
466470
backend=self.backend,
467471
qcomms_config=qcomms_config,
468472
constraints={
469-
table.name: ParameterConstraints(min_partition=4)
470-
for table in self.tables
473+
table_name: ParameterConstraints(min_partition=4)
474+
for table_name in self.table_names
471475
},
472476
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
473477
variable_batch_size=variable_batch_size,
@@ -681,8 +685,8 @@ def test_sharding_variable_batch(
681685
],
682686
backend=self.backend,
683687
constraints={
684-
table.name: ParameterConstraints(min_partition=4)
685-
for table in self.tables
688+
table_name: ParameterConstraints(min_partition=4)
689+
for table_name in self.table_names
686690
},
687691
variable_batch_per_feature=True,
688692
has_weighted_tables=False,
@@ -700,24 +704,25 @@ def test_sharding_variable_batch(
700704
sharding_type=st.just(ShardingType.COLUMN_WISE.value),
701705
data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
702706
)
703-
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
707+
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
704708
def test_sharding_multiple_kernels(
705709
self, sharding_type: str, data_type: DataType
706710
) -> None:
707711
if self.backend == "gloo":
708712
self.skipTest("ProcessGroupGloo does not support reduce_scatter")
713+
fused_params = {"prefetch_pipeline": True}
709714
constraints = {
710-
table.name: ParameterConstraints(
715+
table_name: ParameterConstraints(
711716
min_partition=4,
712717
compute_kernels=(
713718
[EmbeddingComputeKernel.FUSED.value]
714719
if i % 2 == 0
715720
else [EmbeddingComputeKernel.FUSED_UVM_CACHING.value]
716721
),
722+
sharding_types=[sharding_type],
717723
)
718-
for i, table in enumerate(self.tables)
724+
for i, table_name in enumerate(self.table_names)
719725
}
720-
fused_params = {"prefetch_pipeline": True}
721726
self._test_sharding(
722727
# pyre-ignore[6]
723728
sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],

0 commit comments

Comments
 (0)