
Commit a3eee19

basilwong authored and facebook-github-bot committed
Update Model Input Callable Protocol to enable configuring index/offset/length type (#2768)
Summary: Pull Request resolved: #2768

# Diff Specific Changes

Updates the Model Input Callable Protocol in TorchRec to enable configuring the index/offset/length type. The changes add new parameters to the `ModelInput` class constructor, which allow users to specify the data type of indices, offsets, and lengths.

# Context

Doc: https://docs.google.com/document/d/1YVfxsafqXkxAAdRyXbjmSH4AEz3-6DBiTGjs1rT8ZHQ/edit?usp=sharing

Updating the TorchRec unit test suite to cover int32 and int64 indices/offsets support.

# Summary

Specifically for the [test_model_parallel](https://www.internalfb.com/code/fbsource/[3505ccb75a649a7d21218bcda126d1e8392afc5a]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) suite, the change is fairly straightforward:

1. The [ModelParallelTestShared](https://www.internalfb.com/code/fbsource/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) class defines a [test suite python library](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/test_utils/TARGETS?lines=65-69) referenced by multiple unit tests in the TorchRec codebase, including [test_model_parallel_nccl](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/tests/TARGETS?lines=85-100), which is the one of interest for this case. The method all of the unit tests in this class use is [`_test_sharding`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=132). Within the `_test_sharding` function, the "callable" argument passed to [`_run_multi_process_test`](https://www.internalfb.com/code/symbol/fbsource/py/fbcode/caffe2.torch.fb.hpc.tests.sparse_data_dist_test.SparseDataDistTest._run_multi_process_test) is [`sharding_single_rank_test`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=296), which shows how the input data/model is generated. Additional arguments need to be added to both the [`_test_sharding`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=132) and [`_run_multi_process_test`](https://www.internalfb.com/code/symbol/fbsource/py/fbcode/caffe2.torch.fb.hpc.tests.sparse_data_dist_test.SparseDataDistTest._run_multi_process_test) functions.

2. The [`sharding_single_rank_test`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=296) function is where we define the additional kwargs. This function leverages [`gen_model_and_input`](https://www.internalfb.com/code/fbsource/[f7e6a3281d924b465e0e90ff079aa9df83ae9530]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=131) to define the test model and, more importantly for our purposes, the input tables:

   ```
   generate=(
       cast(VariableBatchModelInputCallable, ModelInput.generate_variable_batch_input)
       if variable_batch_per_feature
       else ModelInput.generate
   ),
   ```

3. The [ModelInput](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=48) class' [`generate`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=55) and [`generate_variable_batch_input`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=589) methods are used to generate the input tensors used in the unit tests. All we need to do is add new arguments that enable configuring the index/offset type of the tables.

# Diff stack change summary

a. Update `generate_variable_batch_input` to enable configuring the index/offset/length type
b. Update `generate` to enable configuring the index/offset/length type
c. Update the Model Input Callable Protocol to enable configuring the index/offset/length type
d. `test_model_parallel`: new test for different table index types
e. Deprecate the `long_indices` argument in favor of `torch.dtype` arguments

Reviewed By: TroyGarden

Differential Revision: D70055498

fbshipit-source-id: 2e06c3d647e206bd92aa2becb05fd27f05818f62
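To make the Protocol change concrete, here is a minimal, hypothetical sketch of the pattern (not the actual TorchRec code): a `typing.Protocol` for the input-generating callable whose signature carries the new `use_offsets`/`indices_dtype`/`offsets_dtype`/`lengths_dtype` keyword parameters, with string dtype names standing in for `torch.dtype` so the sketch is self-contained. The `generate` body here is a placeholder; the real method builds `ModelInput` tensors.

```python
from typing import List, Protocol, Tuple


class ModelInputCallable(Protocol):
    """Sketch of a model-input callable that accepts configurable dtypes."""

    def __call__(
        self,
        batch_size: int = 4,
        use_offsets: bool = False,
        indices_dtype: str = "int64",  # stand-in for torch.dtype
        offsets_dtype: str = "int64",
        lengths_dtype: str = "int64",
    ) -> Tuple[str, List[str]]: ...


def generate(
    batch_size: int = 4,
    use_offsets: bool = False,
    indices_dtype: str = "int64",
    offsets_dtype: str = "int64",
    lengths_dtype: str = "int64",
) -> Tuple[str, List[str]]:
    # Placeholder body: a real implementation would allocate index/offset/
    # length tensors with the requested dtypes and return (global, per-rank).
    layout = "offsets" if use_offsets else "lengths"
    return (f"global:{indices_dtype}/{layout}", [f"rank0:{indices_dtype}/{layout}"])


# Structural typing check: `generate` satisfies the Protocol, so the test
# harness can accept either `generate` or `generate_variable_batch_input`.
gen: ModelInputCallable = generate
print(gen(indices_dtype="int32", use_offsets=True))
```

Because `Protocol` uses structural typing, any existing generator gains the new knobs simply by adding the keyword parameters with `int64` defaults, keeping old call sites working unchanged.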
1 parent cc48945 · commit a3eee19

File tree

1 file changed: +34 −2 lines changed


torchrec/distributed/test_utils/test_sharding.py

Lines changed: 34 additions & 2 deletions
```diff
@@ -126,6 +126,10 @@ def __call__(
             Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
         ] = None,
         variable_batch_size: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
         long_indices: bool = True,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...

@@ -140,6 +144,10 @@ def __call__(
         weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
         pooling_avg: int = 10,
         global_constant_batch: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...

@@ -161,10 +169,14 @@ def gen_model_and_input(
     variable_batch_size: bool = False,
     batch_size: int = 4,
     feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
-    long_indices: bool = True,
+    use_offsets: bool = False,
+    indices_dtype: torch.dtype = torch.int64,
+    offsets_dtype: torch.dtype = torch.int64,
+    lengths_dtype: torch.dtype = torch.int64,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
     input_type: str = "kjt",  # "kjt" or "td"
+    long_indices: bool = True,
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:

@@ -205,6 +217,10 @@ def gen_model_and_input(
                 tables=tables,
                 weighted_tables=weighted_tables or [],
                 global_constant_batch=global_constant_batch,
+                use_offsets=use_offsets,
+                indices_dtype=indices_dtype,
+                offsets_dtype=offsets_dtype,
+                lengths_dtype=lengths_dtype,
             )
         )
     elif generate == ModelInput.generate:

@@ -218,8 +234,12 @@ def gen_model_and_input(
                 num_float_features=num_float_features,
                 variable_batch_size=variable_batch_size,
                 batch_size=batch_size,
-                long_indices=long_indices,
                 input_type=input_type,
+                use_offsets=use_offsets,
+                indices_dtype=indices_dtype,
+                offsets_dtype=offsets_dtype,
+                lengths_dtype=lengths_dtype,
+                long_indices=long_indices,
             )
         )
     else:

@@ -233,6 +253,10 @@ def gen_model_and_input(
                 num_float_features=num_float_features,
                 variable_batch_size=variable_batch_size,
                 batch_size=batch_size,
+                use_offsets=use_offsets,
+                indices_dtype=indices_dtype,
+                offsets_dtype=offsets_dtype,
+                lengths_dtype=lengths_dtype,
                 long_indices=long_indices,
             )
         )

@@ -336,6 +360,10 @@ def sharding_single_rank_test(
     input_type: str = "kjt",  # "kjt" or "td"
     allow_zero_batch_size: bool = False,
     custom_all_reduce: bool = False,  # 2D parallel
+    use_offsets: bool = False,
+    indices_dtype: torch.dtype = torch.int64,
+    offsets_dtype: torch.dtype = torch.int64,
+    lengths_dtype: torch.dtype = torch.int64,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         batch_size = (

@@ -363,6 +391,10 @@ def sharding_single_rank_test(
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
             input_type=input_type,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
```
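The diff above is pure kwarg plumbing: the dtype options enter at the top-level test entry point and are forwarded unchanged down each layer until they reach the input generator. A hypothetical stdlib-only sketch of that pattern (function names mirror the real ones, but the bodies are stand-ins, not TorchRec code):

```python
def model_input_generate(use_offsets=False, indices_dtype="int64",
                         offsets_dtype="int64", lengths_dtype="int64"):
    # Stand-in for ModelInput.generate: record which options reached the
    # bottom layer (the real method builds tensors with these dtypes).
    return {"use_offsets": use_offsets, "indices_dtype": indices_dtype,
            "offsets_dtype": offsets_dtype, "lengths_dtype": lengths_dtype}


def gen_model_and_input(**dtype_kwargs):
    # Middle layer: forwards the new kwargs without inspecting them.
    return model_input_generate(**dtype_kwargs)


def sharding_single_rank_test(use_offsets=False, indices_dtype="int64",
                              offsets_dtype="int64", lengths_dtype="int64"):
    # Top layer: the unit test configures the dtypes here; defaults keep
    # every existing call site behaving exactly as before (int64).
    return gen_model_and_input(use_offsets=use_offsets,
                               indices_dtype=indices_dtype,
                               offsets_dtype=offsets_dtype,
                               lengths_dtype=lengths_dtype)


cfg = sharding_single_rank_test(indices_dtype="int32", lengths_dtype="int32")
print(cfg["indices_dtype"], cfg["lengths_dtype"])  # int32 int32
```

Defaulting every new parameter at every layer is what lets the diff touch only signatures and call sites while leaving all existing tests' behavior unchanged.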
