Commit 856ff3c

basilwong authored and facebook-github-bot committed
Update the generate_variable_batch_input function to enable configuring index/offset/length type (#2765)
Summary:
Pull Request resolved: #2765

# Diff Specific Changes

Adds new parameters for the index/offset/length type, and modifies the [generate_variable_batch_input](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=589) function and its helper functions to use these parameters.

# Context

Doc: https://docs.google.com/document/d/1YVfxsafqXkxAAdRyXbjmSH4AEz3-6DBiTGjs1rT8ZHQ/edit?usp=sharing

This updates the TorchRec unit test suite to cover int32 and int64 indices/offsets support.

# Summary

Specifically for the [test_model_parallel](https://www.internalfb.com/code/fbsource/[3505ccb75a649a7d21218bcda126d1e8392afc5a]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) suite that I am looking at, the change is fairly straightforward.

1. The [ModelParallelTestShared](https://www.internalfb.com/code/fbsource/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=34) class defines a [test suite python library](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/test_utils/TARGETS?lines=65-69) referenced by multiple unit tests in the TorchRec codebase, including [test_model_parallel_nccl](https://www.internalfb.com/code/fbsource/[cbd0bd0020a7afbec4922d8abc0d88b7d45cba56]/fbcode/torchrec/distributed/tests/TARGETS?lines=85-100), which is the one of interest here. All of the unit tests in this class go through the [`_test_sharding`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_model_parallel.py?lines=132) method. Within `_test_sharding`, the "callable" argument passed to [`_run_multi_process_test`](https://www.internalfb.com/code/symbol/fbsource/py/fbcode/caffe2.torch.fb.hpc.tests.sparse_data_dist_test.SparseDataDistTest._run_multi_process_test) is [`sharding_single_rank_test`](https://www.internalfb.com/code/fbsource/[fa9508a29b62ce57681ee73cd6d4cac56f153a58]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=296), which shows how the input data/model is generated. Additional arguments need to be added to both `_test_sharding` and `_run_multi_process_test`.

2. The `sharding_single_rank_test` function is where the additional kwargs are defined. It leverages [`gen_model_and_input`](https://www.internalfb.com/code/fbsource/[f7e6a3281d924b465e0e90ff079aa9df83ae9530]/fbcode/torchrec/distributed/test_utils/test_sharding.py?lines=131) to define the test model and, more importantly for our purposes, the input tables:

```
generate=(
    cast(VariableBatchModelInputCallable, ModelInput.generate_variable_batch_input)
    if variable_batch_per_feature
    else ModelInput.generate
),
```

3. The [ModelInput](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=48) class' [`generate`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=55) and [`generate_variable_batch_input`](https://www.internalfb.com/code/fbsource/[4217c068fa966d569d2042a7263cefe1a06dc87a]/fbcode/torchrec/distributed/test_utils/test_model.py?lines=589) methods generate the input tensors used in the unit tests. All we need to do is add new arguments that enable configuring the index/offset/length type of the tables.

# Diff stack change summary:

a. Update `generate_variable_batch_input` to enable configuring index/offset/length type
b. Update `generate` to enable configuring index/offset/length type
c. Update the ModelInput callable protocol to enable configuring index/offset/length type
d. test_model_parallel: new test for different table index types
e. Deprecate the `long_indices` argument in favor of torch.dtype arguments

Reviewed By: TroyGarden

Differential Revision: D70054939

fbshipit-source-id: 9bee9478d21cfcba694ae866437e9d1ee4910c75
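For orientation, here is a sketch of how a test might exercise the new knobs end to end. This is illustrative only: the keyword names mirror the diff below, and the `EmbeddingBagConfig` table setup is an assumption about the caller's environment, not part of this commit.

```python
# Illustrative sketch (not part of this commit): calling the updated generator
# with the new dtype/offsets knobs. Keyword names mirror the diff below; the
# EmbeddingBagConfig table setup is an assumed caller environment.
import torch

from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=8,
        name="table_0",
        feature_names=["feature_0"],
    )
]

# Request int32 indices/offsets/lengths instead of the int64 defaults, and
# build the KJTs from offsets rather than lengths.
global_input, per_rank_inputs = ModelInput.generate_variable_batch_input(
    average_batch_size=4,
    world_size=2,
    num_float_features=1,
    tables=tables,
    use_offsets=True,
    indices_dtype=torch.int32,
    offsets_dtype=torch.int32,
    lengths_dtype=torch.int32,
)
```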
1 parent f66ba3f commit 856ff3c

File tree

1 file changed: +147 −50 lines

torchrec/distributed/test_utils/test_model.py

Lines changed: 147 additions & 50 deletions
```diff
@@ -396,29 +396,42 @@ def _generate_variable_batch_local_features(
         strides_per_rank_per_feature: Dict[int, Dict[str, int]],
         inverse_indices_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]],
         weights_per_rank_per_feature: Optional[Dict[int, Dict[str, torch.Tensor]]],
+        use_offsets: bool,
+        indices_dtype: torch.dtype,
+        offsets_dtype: torch.dtype,
+        lengths_dtype: torch.dtype,
     ) -> List[KeyedJaggedTensor]:
         local_kjts = []
         keys = list(feature_num_embeddings.keys())
+
         for rank in range(world_size):
             lengths_per_rank_per_feature[rank] = {}
             values_per_rank_per_feature[rank] = {}
             strides_per_rank_per_feature[rank] = {}
             inverse_indices_per_rank_per_feature[rank] = {}
+
             if weights_per_rank_per_feature is not None:
                 weights_per_rank_per_feature[rank] = {}
 
             for key, num_embeddings in feature_num_embeddings.items():
                 batch_size = random.randint(1, average_batch_size * dedup_factor - 1)
-                lengths = torch.randint(low=0, high=5, size=(batch_size,))
+                lengths = torch.randint(
+                    low=0, high=5, size=(batch_size,), dtype=lengths_dtype
+                )
                 lengths_per_rank_per_feature[rank][key] = lengths
                 lengths_sum = sum(lengths.tolist())
-                values = torch.randint(0, num_embeddings, (lengths_sum,))
+                values = torch.randint(
+                    0, num_embeddings, (lengths_sum,), dtype=indices_dtype
+                )
                 values_per_rank_per_feature[rank][key] = values
                 if weights_per_rank_per_feature is not None:
                     weights_per_rank_per_feature[rank][key] = torch.rand(lengths_sum)
                 strides_per_rank_per_feature[rank][key] = batch_size
                 inverse_indices_per_rank_per_feature[rank][key] = torch.randint(
-                    0, batch_size, (dedup_factor * average_batch_size,)
+                    0,
+                    batch_size,
+                    (dedup_factor * average_batch_size,),
+                    dtype=indices_dtype,
                 )
 
             values = torch.cat(list(values_per_rank_per_feature[rank].values()))
```
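A quick aside on the pattern in the hunk above (an illustration, not part of the diff): `torch.randint` accepts an integer `dtype` directly, so lengths and values come out in the configured types with no cast afterwards.

```python
# Minimal sketch of the dtype-threading pattern used in the hunk above.
import torch

lengths = torch.randint(low=0, high=5, size=(8,), dtype=torch.int32)
values = torch.randint(0, 100, (int(lengths.sum()),), dtype=torch.int32)
assert lengths.dtype == torch.int32 and values.dtype == torch.int32
```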
```diff
@@ -428,23 +441,40 @@ def _generate_variable_batch_local_features(
                 if weights_per_rank_per_feature is not None
                 else None
             )
-            stride_per_key_per_rank = [
-                [stride] for stride in strides_per_rank_per_feature[rank].values()
-            ]
-            inverse_indices = (
-                keys,
-                torch.stack(list(inverse_indices_per_rank_per_feature[rank].values())),
-            )
-            local_kjts.append(
-                KeyedJaggedTensor(
-                    keys=keys,
-                    values=values,
-                    lengths=lengths,
-                    weights=weights,
-                    stride_per_key_per_rank=stride_per_key_per_rank,
-                    inverse_indices=inverse_indices,
+
+            if use_offsets:
+                offsets = torch.cat(
+                    [torch.tensor([0], dtype=offsets_dtype), lengths.cumsum(0)]
                 )
-            )
+                local_kjts.append(
+                    KeyedJaggedTensor(
+                        keys=keys,
+                        values=values,
+                        offsets=offsets,
+                        weights=weights,
+                    )
+                )
+            else:
+                stride_per_key_per_rank = [
+                    [stride] for stride in strides_per_rank_per_feature[rank].values()
+                ]
+                inverse_indices = (
+                    keys,
+                    torch.stack(
+                        list(inverse_indices_per_rank_per_feature[rank].values())
+                    ),
+                )
+                local_kjts.append(
+                    KeyedJaggedTensor(
+                        keys=keys,
+                        values=values,
+                        lengths=lengths,
+                        weights=weights,
+                        stride_per_key_per_rank=stride_per_key_per_rank,
+                        inverse_indices=inverse_indices,
+                    )
+                )
+
         return local_kjts
 
     @staticmethod
```
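The new `use_offsets` branch builds the same jagged structure from offsets instead of lengths. A small sketch (assuming a working torchrec install) of why the two constructions are interchangeable:

```python
# Offsets are the exclusive prefix sum of lengths, so a KeyedJaggedTensor can
# be built from either representation of the same jagged data.
import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

lengths = torch.tensor([2, 0, 1])
values = torch.tensor([10, 20, 30])
offsets = torch.cat([torch.tensor([0]), lengths.cumsum(0)])  # [0, 2, 2, 3]

kjt_from_lengths = KeyedJaggedTensor(keys=["f0"], values=values, lengths=lengths)
kjt_from_offsets = KeyedJaggedTensor(keys=["f0"], values=values, offsets=offsets)
# Both describe the same per-example segments: [10, 20], [], [30].
assert torch.equal(kjt_from_lengths.lengths(), kjt_from_offsets.lengths())
```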
```diff
@@ -457,6 +487,10 @@ def _generate_variable_batch_global_features(
         strides_per_rank_per_feature: Dict[int, Dict[str, int]],
         inverse_indices_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]],
         weights_per_rank_per_feature: Optional[Dict[int, Dict[str, torch.Tensor]]],
+        use_offsets: bool,
+        indices_dtype: torch.dtype,
+        offsets_dtype: torch.dtype,
+        lengths_dtype: torch.dtype,
     ) -> KeyedJaggedTensor:
         global_values = []
         global_lengths = []
```
```diff
@@ -476,31 +510,41 @@ def _generate_variable_batch_global_features(
                 inverse_indices_per_feature_per_rank.append(
                     inverse_indices_per_rank_per_feature[rank][key]
                 )
+
             global_stride_per_key_per_rank.append([sum_stride])
 
         inverse_indices_list: List[torch.Tensor] = []
+
         for key in keys:
             accum_batch_size = 0
             inverse_indices = []
+
             for rank in range(world_size):
                 inverse_indices.append(
                     inverse_indices_per_rank_per_feature[rank][key] + accum_batch_size
                 )
                 accum_batch_size += strides_per_rank_per_feature[rank][key]
+
             inverse_indices_list.append(torch.cat(inverse_indices))
+
         global_inverse_indices = (keys, torch.stack(inverse_indices_list))
 
         if global_constant_batch:
             global_offsets = []
+
             for length in global_lengths:
                 global_offsets.append(_to_offsets(length))
+
             reindexed_lengths = []
+
             for length, indices in zip(
                 global_lengths, inverse_indices_per_feature_per_rank
             ):
                 reindexed_lengths.append(torch.index_select(length, 0, indices))
+
             lengths = torch.cat(reindexed_lengths)
             reindexed_values, reindexed_weights = [], []
+
             for i, (values, offsets, indices) in enumerate(
                 zip(global_values, global_offsets, inverse_indices_per_feature_per_rank)
             ):
```
```diff
@@ -510,25 +554,40 @@
                         reindexed_weights.append(
                             global_weights[i][offsets[idx] : offsets[idx + 1]]
                         )
+
             values = torch.cat(reindexed_values)
             weights = (
                 torch.cat(reindexed_weights) if global_weights is not None else None
             )
             global_stride_per_key_per_rank = None
             global_inverse_indices = None
+
         else:
             values = torch.cat(global_values)
             lengths = torch.cat(global_lengths)
             weights = torch.cat(global_weights) if global_weights is not None else None
 
-        return KeyedJaggedTensor(
-            keys=keys,
-            values=values,
-            lengths=lengths,
-            weights=weights,
-            stride_per_key_per_rank=global_stride_per_key_per_rank,
-            inverse_indices=global_inverse_indices,
-        )
+        if use_offsets:
+            offsets = torch.cat(
+                [torch.tensor([0], dtype=offsets_dtype), lengths.cumsum(0)]
+            )
+            return KeyedJaggedTensor(
+                keys=keys,
+                values=values,
+                offsets=offsets,
+                weights=weights,
+                stride_per_key_per_rank=global_stride_per_key_per_rank,
+                inverse_indices=global_inverse_indices,
+            )
+        else:
+            return KeyedJaggedTensor(
+                keys=keys,
+                values=values,
+                lengths=lengths,
+                weights=weights,
+                stride_per_key_per_rank=global_stride_per_key_per_rank,
+                inverse_indices=global_inverse_indices,
+            )
 
     @staticmethod
     def _generate_variable_batch_features(
```
```diff
@@ -539,11 +598,17 @@ def _generate_variable_batch_features(
         world_size: int,
         dedup_factor: int,
         global_constant_batch: bool,
+        use_offsets: bool,
+        indices_dtype: torch.dtype,
+        offsets_dtype: torch.dtype,
+        lengths_dtype: torch.dtype,
     ) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]:
         is_weighted = (
             True if tables and getattr(tables[0], "is_weighted", False) else False
         )
+
         feature_num_embeddings = {}
+
         for table in tables:
             for feature_name in table.feature_names:
                 feature_num_embeddings[feature_name] = (
```
```diff
@@ -553,33 +618,42 @@
                 )
 
         local_kjts = []
+
         values_per_rank_per_feature = {}
         lengths_per_rank_per_feature = {}
         strides_per_rank_per_feature = {}
         inverse_indices_per_rank_per_feature = {}
         weights_per_rank_per_feature = {} if is_weighted else None
 
         local_kjts = ModelInput._generate_variable_batch_local_features(
-            feature_num_embeddings,
-            average_batch_size,
-            world_size,
-            dedup_factor,
-            values_per_rank_per_feature,
-            lengths_per_rank_per_feature,
-            strides_per_rank_per_feature,
-            inverse_indices_per_rank_per_feature,
-            weights_per_rank_per_feature,
+            feature_num_embeddings=feature_num_embeddings,
+            average_batch_size=average_batch_size,
+            world_size=world_size,
+            dedup_factor=dedup_factor,
+            values_per_rank_per_feature=values_per_rank_per_feature,
+            lengths_per_rank_per_feature=lengths_per_rank_per_feature,
+            strides_per_rank_per_feature=strides_per_rank_per_feature,
+            inverse_indices_per_rank_per_feature=inverse_indices_per_rank_per_feature,
+            weights_per_rank_per_feature=weights_per_rank_per_feature,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
         )
 
         global_kjt = ModelInput._generate_variable_batch_global_features(
-            list(feature_num_embeddings.keys()),
-            world_size,
-            global_constant_batch,
-            values_per_rank_per_feature,
-            lengths_per_rank_per_feature,
-            strides_per_rank_per_feature,
-            inverse_indices_per_rank_per_feature,
-            weights_per_rank_per_feature,
+            keys=list(feature_num_embeddings.keys()),
+            world_size=world_size,
+            global_constant_batch=global_constant_batch,
+            values_per_rank_per_feature=values_per_rank_per_feature,
+            lengths_per_rank_per_feature=lengths_per_rank_per_feature,
+            strides_per_rank_per_feature=strides_per_rank_per_feature,
+            inverse_indices_per_rank_per_feature=inverse_indices_per_rank_per_feature,
+            weights_per_rank_per_feature=weights_per_rank_per_feature,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
        )
 
         return (global_kjt, local_kjts)
```
```diff
@@ -601,30 +675,51 @@ def generate_variable_batch_input(
         ] = None,
         pooling_avg: int = 10,
         global_constant_batch: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         torch.manual_seed(100)
         random.seed(100)
         dedup_factor = 2
+
         global_kjt, local_kjts = ModelInput._generate_variable_batch_features(
-            tables, average_batch_size, world_size, dedup_factor, global_constant_batch
+            tables=tables,
+            average_batch_size=average_batch_size,
+            world_size=world_size,
+            dedup_factor=dedup_factor,
+            global_constant_batch=global_constant_batch,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
         )
+
         if weighted_tables:
             global_score_kjt, local_score_kjts = (
                 ModelInput._generate_variable_batch_features(
-                    weighted_tables,
-                    average_batch_size,
-                    world_size,
-                    dedup_factor,
-                    global_constant_batch,
+                    tables=weighted_tables,
+                    average_batch_size=average_batch_size,
+                    world_size=world_size,
+                    dedup_factor=dedup_factor,
+                    global_constant_batch=global_constant_batch,
+                    use_offsets=use_offsets,
+                    indices_dtype=indices_dtype,
+                    offsets_dtype=offsets_dtype,
+                    lengths_dtype=lengths_dtype,
                 )
             )
         else:
             global_score_kjt, local_score_kjts = None, []
+
         global_float = torch.rand(
             (dedup_factor * average_batch_size * world_size, num_float_features)
         )
+
         local_model_input = []
         label_per_rank = []
+
         for rank in range(world_size):
             label_per_rank.append(torch.rand(dedup_factor * average_batch_size))
             local_float = global_float[
```
```diff
@@ -644,12 +739,14 @@ def generate_variable_batch_input(
                     float_features=local_float,
                 ),
             )
+
         global_model_input = ModelInput(
             idlist_features=global_kjt,
             idscore_features=global_score_kjt,
             label=torch.cat(label_per_rank),
             float_features=global_float,
         )
+
         return (global_model_input, local_model_input)
 
     def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
```
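Closing the loop on the stack's goal, a test-style sanity check (again a sketch, not code from this commit) that the generated global KJT carries the requested dtypes on the default lengths-based path:

```python
# Hypothetical dtype check (not in this commit): values/lengths of the
# generated global KJT should come out in the requested int32 types when
# use_offsets is left at its False default.
import torch

from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=64, embedding_dim=8, name="t0", feature_names=["f0"]
    )
]
global_input, _ = ModelInput.generate_variable_batch_input(
    average_batch_size=4,
    world_size=2,
    num_float_features=1,
    tables=tables,
    indices_dtype=torch.int32,
    lengths_dtype=torch.int32,
)
kjt = global_input.idlist_features
assert kjt.values().dtype == torch.int32  # indices_dtype
assert kjt.lengths().dtype == torch.int32  # lengths_dtype
```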
