Commit 0108153

basilwong authored and facebook-github-bot committed

test_model_parallel: new test for different table index types (#2770)
Summary: Pull Request resolved: #2770

# Diff Specific Changes

This diff introduces a new test for different table index types in the test_model_parallel.py file. The code changes add new arguments to the TestModelParallel class so that the index type of each table can be specified. The new test verifies correctness when different index types are used as input.

# Context

Doc: https://docs.google.com/document/d/1YVfxsafqXkxAAdRyXbjmSH4AEz3-6DBiTGjs1rT8ZHQ/edit?usp=sharing

This stack updates the TorchRec unit test suite to cover int32 and int64 indices/offsets support.

# Summary

Specifically for the [test_model_parallel](https://www.internalfb.com/code/fbsource/[3505ccb75a649a7d21218bcda126d1e8392afc5a]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) suite, the change is fairly straightforward:

1. The [ModelParallelTestShared](https://www.internalfb.com/code/fbsource/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) class defines a [test suite python library](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/test_utils/TARGETS?lines=65-69) referenced by multiple unit tests in the TorchRec codebase, including [test_model_parallel_nccl](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/tests/TARGETS?lines=85-100), the one of particular interest here. All of the unit tests in this class use the [`_test_sharding`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=132) method. Within `_test_sharding`, the "callable" argument passed to [`_run_multi_process_test`](https://www.internalfb.com/code/symbol/fbsource/py/fbcode/caffe2.torch.fb.hpc.tests.sparse_data_dist_test.SparseDataDistTest._run_multi_process_test) is [`sharding_single_rank_test`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=296), which shows how the input data/model is generated. Additional arguments therefore need to be added to both `_test_sharding` and `_run_multi_process_test`.

2. The [`sharding_single_rank_test`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=296) function is where the additional kwargs are defined. It leverages [`gen_model_and_input`](https://www.internalfb.com/code/fbsource/[f7e6a3281d924b465e0e90ff079aa9df83ae9530]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=131) to define the test model and, more importantly for our purposes, the input tables:
```python
generate=(
    cast(VariableBatchModelInputCallable, ModelInput.generate_variable_batch_input)
    if variable_batch_per_feature
    else ModelInput.generate
),
```

3. The [ModelInput](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=48) class's [`generate`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=55) and [`generate_variable_batch_input`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=589) methods generate the input tensors used in the unit tests. All we need to do is add new arguments that make the index/offset type of the tables configurable (a hedged sketch follows the commit message below).

# Diff stack change summary

a. Update `generate_variable_batch_input` to enable configuring the index/offset/length type
b. Update `generate` to enable configuring the index/offset/length type
c. Update the Model Input Callable Protocol to enable configuring the index/offset/length type
d. test_model_parallel: new test for different table index types
e. Deprecate the `long_indices` argument in favor of `torch.dtype` arguments

Reviewed By: TroyGarden, ys97529

Differential Revision: D70055534

fbshipit-source-id: d338d74b30ffbecb36cf3bf91481e28355b610b3
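As a concrete illustration of step 3 above, the sketch below shows one way the new keyword arguments could drive input generation. This is a hedged sketch, not TorchRec's actual implementation: `generate_kjt` is a hypothetical stand-in for the `ModelInput.generate` path, and only the argument names (`use_offsets`, `indices_dtype`, `offsets_dtype`, `lengths_dtype`) are taken from this diff.

```python
# Hypothetical sketch of dtype-configurable input generation; generate_kjt is
# a stand-in for the ModelInput.generate path, not TorchRec's actual code.
from typing import List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def generate_kjt(
    keys: List[str],
    batch_size: int,
    max_pooling_factor: int = 10,
    num_embeddings: int = 100,
    use_offsets: bool = False,
    indices_dtype: torch.dtype = torch.int64,
    offsets_dtype: torch.dtype = torch.int64,
    lengths_dtype: torch.dtype = torch.int64,
) -> KeyedJaggedTensor:
    # One pooling-factor-bounded length per (key, sample) slot.
    lengths = torch.randint(
        0, max_pooling_factor, (len(keys) * batch_size,), dtype=lengths_dtype
    )
    # Flat indices into the embedding tables, in the requested index dtype.
    values = torch.randint(
        0, num_embeddings, (int(lengths.sum().item()),), dtype=indices_dtype
    )
    if use_offsets:
        # Offsets are the prefix sum of lengths with a leading zero.
        offsets = torch.cat(
            [torch.zeros(1, dtype=offsets_dtype), lengths.cumsum(0).to(offsets_dtype)]
        )
        return KeyedJaggedTensor.from_offsets_sync(
            keys=keys, values=values, offsets=offsets
        )
    return KeyedJaggedTensor.from_lengths_sync(
        keys=keys, values=values, lengths=lengths
    )
```

With this shape, a test can assert that the generated `KeyedJaggedTensor` preserves the requested dtypes, e.g. `generate_kjt(["f1"], 4, indices_dtype=torch.int32).values().dtype == torch.int32`.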
Parent: a3eee19

File tree: 1 file changed (+63, −0 lines)


torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 63 additions & 0 deletions
```diff
@@ -152,6 +152,10 @@ def _test_sharding(
         use_inter_host_allreduce: bool = False,
         allow_zero_batch_size: bool = False,
         custom_all_reduce: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
     ) -> None:
         self._build_tables_and_groups(data_type=data_type)
         self._run_multi_process_test(
@@ -176,6 +180,10 @@ def _test_sharding(
             use_inter_host_allreduce=use_inter_host_allreduce,
             allow_zero_batch_size=allow_zero_batch_size,
             custom_all_reduce=custom_all_reduce,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
         )
 
 
@@ -901,3 +909,58 @@ def test_sharding_grid_8gpu(
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             pooling=pooling,
         )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        dtype=st.sampled_from([torch.int32, torch.int64]),
+        use_offsets=st.booleans(),
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+            ],
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    def test_sharding_diff_table_index_type(
+        self,
+        dtype: torch.dtype,
+        use_offsets: bool,
+        sharder_type: str,
+        kernel_type: str,
+    ) -> None:
+        """
+        Test that the model correctly handles input indices and offsets
+        with both int32 and int64 data types.
+        """
+        sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(
+                    sharder_type=sharder_type,
+                    sharding_type=ShardingType.ROW_WISE.value,  # or any other relevant sharding type
+                    kernel_type=kernel_type,
+                    device=self.device,
+                ),
+            ),
+        ]
+        # TODO - how to pass dtype so that sampled data uses different type indices/offsets?
+        self._test_sharding(
+            sharders=sharders,
+            backend=self.backend,
+            apply_optimizer_in_backward_config=None,
+            variable_batch_size=False,
+            pooling=PoolingType.SUM,
+            use_offsets=use_offsets,
+            indices_dtype=dtype,
+            offsets_dtype=dtype,
+            lengths_dtype=dtype,
+        )
```
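For readers unfamiliar with the hypothesis pattern the new test uses, below is a minimal, self-contained sketch of the same structure. `run_sharding_check` and `test_index_dtype_matrix` are hypothetical stand-ins; the real test drives `self._test_sharding` across multiple processes and GPUs.

```python
# Minimal sketch of the hypothesis pattern used by the new test.
# run_sharding_check is a hypothetical stand-in for self._test_sharding.
import torch
from hypothesis import Verbosity, given, settings, strategies as st


def run_sharding_check(dtype: torch.dtype, use_offsets: bool) -> None:
    # Stand-in for the real sharded-vs-unsharded comparison: just confirm
    # that generated indices carry the requested dtype.
    indices = torch.randint(0, 100, (8,), dtype=dtype)
    assert indices.dtype == dtype


@given(
    dtype=st.sampled_from([torch.int32, torch.int64]),
    use_offsets=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
def test_index_dtype_matrix(dtype: torch.dtype, use_offsets: bool) -> None:
    run_sharding_check(dtype, use_offsets)
```

As in the diff, `max_examples=2` keeps runtime low; with a 2×2 dtype/offsets matrix, each run samples only a subset of the combinations.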
