Commit cc48945

basilwong authored and facebook-github-bot committed
Update the generate function to enable configuring index/offset/length type (#2767)
Summary: Pull Request resolved: #2767

# Diff Specific Changes

Adds new parameters for the index/offset/length type, and modifies the [generate](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=55) function to use them. This diff keeps the `long_indices` argument for backwards compatibility; a follow-up diff will deprecate `long_indices` (since it is now redundant) along with downstream references that depend on it.

# Context

Doc: https://docs.google.com/document/d/1YVfxsafqXkxAAdRyXbjmSH4AEz3-6DBiTGjs1rT8ZHQ/edit?usp=sharing

Updates the TorchRec unit test suite to cover int32 and int64 indices/offsets support.

# Summary

For the [test_model_parallel](https://www.internalfb.com/code/fbsource/[3505ccb75a649a7d21218bcda126d1e8392afc5a]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) suite, the change is fairly straightforward:

1. The [ModelParallelTestShared](https://www.internalfb.com/code/fbsource/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) class defines a [test suite python library](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/test_utils/TARGETS?lines=65-69) referenced by multiple unit tests in the TorchRec codebase, including [test_model_parallel_nccl](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/tests/TARGETS?lines=85-100), which is the one of interest here. All of the unit tests in this class use the [`_test_sharding`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=132) method. Within `_test_sharding`, the "callable" argument passed to [`_run_multi_process_test`](https://www.internalfb.com/code/symbol/fbsource/py/fbcode/caffe2.torch.fb.hpc.tests.sparse_data_dist_test.SparseDataDistTest._run_multi_process_test) is [`sharding_single_rank_test`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=296), which shows how the input data/model is generated. Additional arguments need to be added to both [`_test_sharding`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=132) and [`_run_multi_process_test`](https://www.internalfb.com/code/symbol/fbsource/py/fbcode/caffe2.torch.fb.hpc.tests.sparse_data_dist_test.SparseDataDistTest._run_multi_process_test).

2. The [`sharding_single_rank_test`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=296) function is where the additional kwargs are defined. It leverages [`gen_model_and_input`](https://www.internalfb.com/code/fbsource/[f7e6a3281d924b465e0e90ff079aa9df83ae9530]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=131) to define the test model and, more importantly for our purposes, the input tables:

   ```python
   generate=(
       cast(VariableBatchModelInputCallable, ModelInput.generate_variable_batch_input)
       if variable_batch_per_feature
       else ModelInput.generate
   ),
   ```

3. The [ModelInput](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=48) class' [`generate`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=55) and [`generate_variable_batch_input`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=589) methods generate the input tensors used in the unit tests. All we need to do is add new arguments that enable configuring the index/offset type of the tables.

# Diff stack change summary:

a. Update generate_variable_batch_input to enable configuring index/offset/length type
b. Update generate to enable configuring index/offset/length type
c. Update the Model Input Callable Protocol to enable configuring index/offset/length type
d. test_model_parallel: new test for different table index types
e. Deprecate the long_indices argument in favor of torch.dtype arguments

Reviewed By: TroyGarden

Differential Revision: D70055042

fbshipit-source-id: f0563bca57047b41fbefc61f177ab92caec48a21
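The backwards-compatibility behavior described above can be sketched as a small pure-Python helper. This is a hypothetical illustration, not code from the diff: the legacy `long_indices` flag overrides the new dtype arguments, and the torch dtypes are represented as strings so the sketch runs without torch installed.

```python
# Hypothetical sketch of the compatibility shim in ModelInput.generate:
# long_indices=True keeps the legacy 64-bit behavior, long_indices=False
# falls back to 32-bit indices/lengths. Strings stand in for torch dtypes.

def resolve_index_dtypes(long_indices: bool = True) -> dict:
    """Map the deprecated `long_indices` flag onto the new arguments."""
    if long_indices:
        # Legacy default: int64 indices/lengths, lengths-based KJT input.
        return {
            "indices_dtype": "int64",
            "lengths_dtype": "int64",
            "use_offsets": False,
        }
    # long_indices=False previously meant 32-bit indices.
    return {
        "indices_dtype": "int32",
        "lengths_dtype": "int32",
        "use_offsets": False,
    }
```

Once `long_indices` is deprecated (step e above), callers would pass `indices_dtype`/`lengths_dtype`/`offsets_dtype` directly instead of going through this shim.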
1 parent 449bb82 commit cc48945

File tree: 3 files changed, +108 −43 lines changed

torchrec/distributed/test_utils/test_model.py

Lines changed: 105 additions & 43 deletions
```diff
@@ -71,19 +71,30 @@ def generate(
             ]
         ] = None,
         variable_batch_size: bool = False,
-        long_indices: bool = True,
         tables_pooling: Optional[List[int]] = None,
         weighted_tables_pooling: Optional[List[int]] = None,
         randomize_indices: bool = True,
         device: Optional[torch.device] = None,
         max_feature_lengths: Optional[List[int]] = None,
         input_type: str = "kjt",
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
+        long_indices: bool = True,  # TODO - remove this once code base is updated to support more than long_indices spec
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
         and a list of local (multi-rank training) batches of world_size.
         """
-
+        if long_indices:
+            indices_dtype = torch.int64
+            lengths_dtype = torch.int64
+            use_offsets = False
+        else:
+            indices_dtype = torch.int32
+            lengths_dtype = torch.int32
+            use_offsets = False
         batch_size_by_rank = [batch_size] * world_size
         if variable_batch_size:
             batch_size_by_rank = [
@@ -119,7 +130,6 @@ def _validate_pooling_factor(
                 if tables[idx].num_embeddings_post_pruning is not None
                 else tables[idx].num_embeddings
             )
-
            idlist_features_to_max_length[feature] = (
                 max_feature_lengths[feature_idx] if max_feature_lengths else None
             )
@@ -144,18 +154,21 @@ def _validate_pooling_factor(

         idlist_pooling_factor = list(idlist_features_to_pooling_factor.values())
         idscore_pooling_factor = weighted_tables_pooling
-
         idlist_max_lengths = list(idlist_features_to_max_length.values())

         # Generate global batch.
         global_idlist_lengths = []
         global_idlist_indices = []
+        global_idlist_offsets = []
         global_idscore_lengths = []
         global_idscore_indices = []
+        global_idscore_offsets = []
         global_idscore_weights = []

         for idx in range(len(idlist_ind_ranges)):
             ind_range = idlist_ind_ranges[idx]
+
             if idlist_pooling_factor:
                 lengths_ = torch.max(
                     torch.normal(
@@ -165,17 +178,19 @@ def _validate_pooling_factor(
                         device=device,
                     ),
                     torch.tensor(1.0, device=device),
-                ).int()
+                ).to(lengths_dtype)
             else:
                 lengths_ = torch.abs(
                     torch.randn(batch_size * world_size, device=device) + pooling_avg,
-                ).int()
+                ).to(lengths_dtype)

             if idlist_max_lengths[idx]:
                 lengths_ = torch.clamp(lengths_, max=idlist_max_lengths[idx])

             if variable_batch_size:
-                lengths = torch.zeros(batch_size * world_size, device=device).int()
+                lengths = torch.zeros(batch_size * world_size, device=device).to(
+                    lengths_dtype
+                )
                 for r in range(world_size):
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
                         lengths_[
@@ -186,42 +201,30 @@ def _validate_pooling_factor(
                 lengths = lengths_

             num_indices = cast(int, torch.sum(lengths).item())
+
             if randomize_indices:
                 indices = torch.randint(
                     0,
                     ind_range,
                     (num_indices,),
-                    dtype=torch.long if long_indices else torch.int32,
+                    dtype=indices_dtype,
                     device=device,
                 )
             else:
                 indices = torch.zeros(
-                    (num_indices),
-                    dtype=torch.long if long_indices else torch.int32,
+                    (num_indices,),
+                    dtype=indices_dtype,
                     device=device,
                 )
+
+            # Calculate offsets from lengths
+            offsets = torch.cat(
+                [torch.tensor([0], device=device), lengths.cumsum(0)]
+            ).to(offsets_dtype)
+
             global_idlist_lengths.append(lengths)
             global_idlist_indices.append(indices)
-
-        if input_type == "kjt":
-            global_idlist_input = KeyedJaggedTensor(
-                keys=idlist_features,
-                values=torch.cat(global_idlist_indices),
-                lengths=torch.cat(global_idlist_lengths),
-            )
-        elif input_type == "td":
-            dict_of_nt = {
-                k: torch.nested.nested_tensor_from_jagged(
-                    values=values,
-                    lengths=lengths,
-                )
-                for k, values, lengths in zip(
-                    idlist_features, global_idlist_indices, global_idlist_lengths
-                )
-            }
-            global_idlist_input = TensorDict(source=dict_of_nt)
-        else:
-            raise ValueError(f"For IdList features, unknown input type {input_type}")
+            global_idlist_offsets.append(offsets)

         for idx, ind_range in enumerate(idscore_ind_ranges):
             lengths_ = torch.abs(
@@ -231,9 +234,12 @@ def _validate_pooling_factor(
                 if idscore_pooling_factor
                 else pooling_avg
             )
-            ).int()
+            ).to(lengths_dtype)
+
             if variable_batch_size:
-                lengths = torch.zeros(batch_size * world_size, device=device).int()
+                lengths = torch.zeros(batch_size * world_size, device=device).to(
+                    lengths_dtype
+                )
                 for r in range(world_size):
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
                         lengths_[
@@ -242,39 +248,68 @@ def _validate_pooling_factor(
                 )
             else:
                 lengths = lengths_
+
             num_indices = cast(int, torch.sum(lengths).item())
+
             if randomize_indices:
                 indices = torch.randint(
                     0,
                     # pyre-ignore [6]
                     ind_range,
                     (num_indices,),
-                    dtype=torch.long if long_indices else torch.int32,
+                    dtype=indices_dtype,
                     device=device,
                 )
             else:
                 indices = torch.zeros(
-                    (num_indices),
-                    dtype=torch.long if long_indices else torch.int32,
+                    (num_indices,),
+                    dtype=indices_dtype,
                     device=device,
                 )
             weights = torch.rand((num_indices,), device=device)
+            # Calculate offsets from lengths
+            offsets = torch.cat(
+                [torch.tensor([0], device=device), lengths.cumsum(0)]
+            ).to(offsets_dtype)
+
             global_idscore_lengths.append(lengths)
             global_idscore_indices.append(indices)
             global_idscore_weights.append(weights)
+            global_idscore_offsets.append(offsets)

         if input_type == "kjt":
+            global_idlist_input = KeyedJaggedTensor(
+                keys=idlist_features,
+                values=torch.cat(global_idlist_indices),
+                offsets=torch.cat(global_idlist_offsets) if use_offsets else None,
+                lengths=torch.cat(global_idlist_lengths) if not use_offsets else None,
+            )
+
             global_idscore_input = (
                 KeyedJaggedTensor(
                     keys=idscore_features,
                     values=torch.cat(global_idscore_indices),
-                    lengths=torch.cat(global_idscore_lengths),
+                    offsets=torch.cat(global_idscore_offsets) if use_offsets else None,
+                    lengths=(
+                        torch.cat(global_idscore_lengths) if not use_offsets else None
+                    ),
                     weights=torch.cat(global_idscore_weights),
                 )
                 if global_idscore_indices
                 else None
             )
         elif input_type == "td":
+            dict_of_nt = {
+                k: torch.nested.nested_tensor_from_jagged(
+                    values=values,
+                    lengths=lengths,
+                )
+                for k, values, lengths in zip(
+                    idlist_features, global_idlist_indices, global_idlist_lengths
+                )
+            }
+            global_idlist_input = TensorDict(source=dict_of_nt)
+
             assert (
                 len(idscore_features) == 0
             ), "TensorDict does not support weighted features"
@@ -295,14 +330,20 @@ def _validate_pooling_factor(

         # Split global batch into local batches.
         local_inputs = []
+
         for r in range(world_size):
             local_idlist_lengths = []
             local_idlist_indices = []
+            local_idlist_offsets = []
+
             local_idscore_lengths = []
             local_idscore_indices = []
             local_idscore_weights = []
+            local_idscore_offsets = []

-            for lengths, indices in zip(global_idlist_lengths, global_idlist_indices):
+            for lengths, indices, offsets in zip(
+                global_idlist_lengths, global_idlist_indices, global_idlist_offsets
+            ):
                 local_idlist_lengths.append(
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]]
                 )
@@ -312,9 +353,15 @@ def _validate_pooling_factor(
                 local_idlist_indices.append(
                     indices[lengths_cumsum[r] : lengths_cumsum[r + 1]]
                 )
+                local_idlist_offsets.append(
+                    offsets[r * batch_size : r * batch_size + batch_size_by_rank[r] + 1]
+                )

-            for lengths, indices, weights in zip(
-                global_idscore_lengths, global_idscore_indices, global_idscore_weights
+            for lengths, indices, weights, offsets in zip(
+                global_idscore_lengths,
+                global_idscore_indices,
+                global_idscore_weights,
+                global_idscore_offsets,
             ):
                 local_idscore_lengths.append(
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]]
@@ -329,18 +376,32 @@ def _validate_pooling_factor(
                     weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
                 )

+                local_idscore_offsets.append(
+                    offsets[r * batch_size : r * batch_size + batch_size_by_rank[r] + 1]
+                )
+
             if input_type == "kjt":
                 local_idlist_input = KeyedJaggedTensor(
                     keys=idlist_features,
                     values=torch.cat(local_idlist_indices),
-                    lengths=torch.cat(local_idlist_lengths),
+                    offsets=torch.cat(local_idlist_offsets) if use_offsets else None,
+                    lengths=(
+                        torch.cat(local_idlist_lengths) if not use_offsets else None
+                    ),
                 )

                 local_idscore_input = (
                     KeyedJaggedTensor(
                         keys=idscore_features,
                         values=torch.cat(local_idscore_indices),
-                        lengths=torch.cat(local_idscore_lengths),
+                        offsets=(
+                            torch.cat(local_idscore_offsets) if use_offsets else None
+                        ),
+                        lengths=(
+                            torch.cat(local_idscore_lengths)
+                            if not use_offsets
+                            else None
+                        ),
                         weights=torch.cat(local_idscore_weights),
                     )
                     if local_idscore_indices
@@ -353,15 +414,16 @@ def _validate_pooling_factor(
                         lengths=lengths,
                     )
                     for k, values, lengths in zip(
-                        idlist_features, local_idlist_indices, local_idlist_lengths
+                        idlist_features,
+                        local_idlist_indices,
+                        local_idlist_lengths,
                     )
                 }
                 local_idlist_input = TensorDict(source=dict_of_nt)
                 assert (
                     len(idscore_features) == 0
                 ), "TensorDict does not support weighted features"
                 local_idscore_input = None
-
             else:
                 raise ValueError(
                     f"For weighted features, unknown input type {input_type}"
```
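Both loops in the diff above derive per-feature offsets from lengths with the same pattern: `torch.cat([torch.tensor([0], device=device), lengths.cumsum(0)]).to(offsets_dtype)`. A pure-Python sketch of that transformation, using `itertools.accumulate` in place of `cumsum` so it runs standalone without torch:

```python
from itertools import accumulate
from typing import List

def lengths_to_offsets(lengths: List[int]) -> List[int]:
    """Prepend 0 and take the running sum of lengths, mirroring the
    torch.cat([torch.tensor([0]), lengths.cumsum(0)]) pattern in the diff."""
    return [0] + list(accumulate(lengths))

# Sample i's indices then live in values[offsets[i] : offsets[i + 1]],
# which is why a batch of B samples produces B + 1 offsets.
```

This also explains the `+ 1` in the local-batch slicing above: each rank's slice of the offsets tensor needs one more element than its slice of the lengths tensor.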

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -93,6 +93,7 @@ def _test_sharded_forward(
         dedup_tables: Optional[List[EmbeddingTableConfig]] = None,
         weighted_tables: Optional[List[EmbeddingTableConfig]] = None,
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        # pyre-ignore [9]
         generate: ModelInputCallable = ModelInput.generate,
     ) -> None:
         default_rank = 0
```

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -148,6 +148,7 @@ def gen_model_and_input(
     tables: List[EmbeddingTableConfig],
     embedding_groups: Dict[str, List[str]],
     world_size: int,
+    # pyre-ignore [9]
     generate: Union[
         ModelInputCallable, VariableBatchModelInputCallable
     ] = ModelInput.generate,
@@ -344,6 +345,7 @@ def sharding_single_rank_test(
     (global_model, inputs) = gen_model_and_input(
         model_class=model_class,
         tables=tables,
+        # pyre-ignore [6]
         generate=(
             cast(
                 VariableBatchModelInputCallable,
```
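Every KeyedJaggedTensor construction in this stack follows the same convention: exactly one of `lengths`/`offsets` is passed, selected by `use_offsets`, and the other is `None`. A hypothetical standalone helper (not part of the diff) illustrating that mutual exclusivity:

```python
from typing import Dict, List, Optional

def jagged_kwargs(
    lengths: List[int], offsets: List[int], use_offsets: bool
) -> Dict[str, Optional[List[int]]]:
    """Select exactly one of lengths/offsets, the way the diff does when
    constructing each KeyedJaggedTensor."""
    return {
        "offsets": offsets if use_offsets else None,
        "lengths": lengths if not use_offsets else None,
    }
```

Passing only one representation avoids ambiguity about which tensor is authoritative and lets the tests exercise both the lengths-based and offsets-based KJT code paths.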
