@@ -396,29 +396,42 @@ def _generate_variable_batch_local_features(
         strides_per_rank_per_feature: Dict[int, Dict[str, int]],
         inverse_indices_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]],
         weights_per_rank_per_feature: Optional[Dict[int, Dict[str, torch.Tensor]]],
+        use_offsets: bool,
+        indices_dtype: torch.dtype,
+        offsets_dtype: torch.dtype,
+        lengths_dtype: torch.dtype,
     ) -> List[KeyedJaggedTensor]:
         local_kjts = []
         keys = list(feature_num_embeddings.keys())
+
         for rank in range(world_size):
             lengths_per_rank_per_feature[rank] = {}
             values_per_rank_per_feature[rank] = {}
             strides_per_rank_per_feature[rank] = {}
             inverse_indices_per_rank_per_feature[rank] = {}
+
             if weights_per_rank_per_feature is not None:
                 weights_per_rank_per_feature[rank] = {}

             for key, num_embeddings in feature_num_embeddings.items():
                 batch_size = random.randint(1, average_batch_size * dedup_factor - 1)
-                lengths = torch.randint(low=0, high=5, size=(batch_size,))
+                lengths = torch.randint(
+                    low=0, high=5, size=(batch_size,), dtype=lengths_dtype
+                )
                 lengths_per_rank_per_feature[rank][key] = lengths
                 lengths_sum = sum(lengths.tolist())
-                values = torch.randint(0, num_embeddings, (lengths_sum,))
+                values = torch.randint(
+                    0, num_embeddings, (lengths_sum,), dtype=indices_dtype
+                )
                 values_per_rank_per_feature[rank][key] = values
                 if weights_per_rank_per_feature is not None:
                     weights_per_rank_per_feature[rank][key] = torch.rand(lengths_sum)
                 strides_per_rank_per_feature[rank][key] = batch_size
                 inverse_indices_per_rank_per_feature[rank][key] = torch.randint(
-                    0, batch_size, (dedup_factor * average_batch_size,)
+                    0,
+                    batch_size,
+                    (dedup_factor * average_batch_size,),
+                    dtype=indices_dtype,
                 )

             values = torch.cat(list(values_per_rank_per_feature[rank].values()))
@@ -428,23 +441,40 @@ def _generate_variable_batch_local_features(
                 if weights_per_rank_per_feature is not None
                 else None
             )
-            stride_per_key_per_rank = [
-                [stride] for stride in strides_per_rank_per_feature[rank].values()
-            ]
-            inverse_indices = (
-                keys,
-                torch.stack(list(inverse_indices_per_rank_per_feature[rank].values())),
-            )
-            local_kjts.append(
-                KeyedJaggedTensor(
-                    keys=keys,
-                    values=values,
-                    lengths=lengths,
-                    weights=weights,
-                    stride_per_key_per_rank=stride_per_key_per_rank,
-                    inverse_indices=inverse_indices,
+
+            if use_offsets:
+                offsets = torch.cat(
+                    [torch.tensor([0], dtype=offsets_dtype), lengths.cumsum(0)]
                 )
-            )
+                local_kjts.append(
+                    KeyedJaggedTensor(
+                        keys=keys,
+                        values=values,
+                        offsets=offsets,
+                        weights=weights,
+                    )
+                )
+            else:
+                stride_per_key_per_rank = [
+                    [stride] for stride in strides_per_rank_per_feature[rank].values()
+                ]
+                inverse_indices = (
+                    keys,
+                    torch.stack(
+                        list(inverse_indices_per_rank_per_feature[rank].values())
+                    ),
+                )
+                local_kjts.append(
+                    KeyedJaggedTensor(
+                        keys=keys,
+                        values=values,
+                        lengths=lengths,
+                        weights=weights,
+                        stride_per_key_per_rank=stride_per_key_per_rank,
+                        inverse_indices=inverse_indices,
+                    )
+                )
+
         return local_kjts

     @staticmethod
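
The core of this change is the offsets-based construction: when `use_offsets` is set, each KJT is built from `offsets` (the per-sample length prefix sums with a leading zero) instead of `lengths`. A minimal sketch of that equivalence, using a hypothetical single-key feature with hand-picked values and assuming the usual `KeyedJaggedTensor` import path:

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Hypothetical jagged batch for one key: per-sample sizes 2, 0, 3.
    lengths = torch.tensor([2, 0, 3], dtype=torch.int64)
    values = torch.tensor([10, 11, 20, 21, 22], dtype=torch.int64)

    # Offsets = lengths prefix-summed with a leading zero, mirroring the
    # torch.cat([torch.tensor([0], ...), lengths.cumsum(0)]) branch above.
    offsets = torch.cat([torch.tensor([0], dtype=torch.int64), lengths.cumsum(0)])
    # offsets is tensor([0, 2, 2, 5])

    kjt_from_lengths = KeyedJaggedTensor(keys=["f0"], values=values, lengths=lengths)
    kjt_from_offsets = KeyedJaggedTensor(keys=["f0"], values=values, offsets=offsets)

    # Both constructions describe the same jagged layout.
    assert kjt_from_lengths.offsets().tolist() == kjt_from_offsets.offsets().tolist()
    assert kjt_from_lengths.lengths().tolist() == kjt_from_offsets.lengths().tolist()
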
@@ -457,6 +487,10 @@ def _generate_variable_batch_global_features(
         strides_per_rank_per_feature: Dict[int, Dict[str, int]],
         inverse_indices_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]],
         weights_per_rank_per_feature: Optional[Dict[int, Dict[str, torch.Tensor]]],
+        use_offsets: bool,
+        indices_dtype: torch.dtype,
+        offsets_dtype: torch.dtype,
+        lengths_dtype: torch.dtype,
     ) -> KeyedJaggedTensor:
         global_values = []
         global_lengths = []
@@ -476,31 +510,41 @@ def _generate_variable_batch_global_features(
                 inverse_indices_per_feature_per_rank.append(
                     inverse_indices_per_rank_per_feature[rank][key]
                 )
+
             global_stride_per_key_per_rank.append([sum_stride])

         inverse_indices_list: List[torch.Tensor] = []
+
         for key in keys:
             accum_batch_size = 0
             inverse_indices = []
+
             for rank in range(world_size):
                 inverse_indices.append(
                     inverse_indices_per_rank_per_feature[rank][key] + accum_batch_size
                 )
                 accum_batch_size += strides_per_rank_per_feature[rank][key]
+
             inverse_indices_list.append(torch.cat(inverse_indices))
+
         global_inverse_indices = (keys, torch.stack(inverse_indices_list))

         if global_constant_batch:
             global_offsets = []
+
             for length in global_lengths:
                 global_offsets.append(_to_offsets(length))
+
             reindexed_lengths = []
+
             for length, indices in zip(
                 global_lengths, inverse_indices_per_feature_per_rank
             ):
                 reindexed_lengths.append(torch.index_select(length, 0, indices))
+
             lengths = torch.cat(reindexed_lengths)
             reindexed_values, reindexed_weights = [], []
+
             for i, (values, offsets, indices) in enumerate(
                 zip(global_values, global_offsets, inverse_indices_per_feature_per_rank)
             ):
@@ -510,25 +554,40 @@ def _generate_variable_batch_global_features(
                         reindexed_weights.append(
                             global_weights[i][offsets[idx] : offsets[idx + 1]]
                         )
+
             values = torch.cat(reindexed_values)
             weights = (
                 torch.cat(reindexed_weights) if global_weights is not None else None
             )
             global_stride_per_key_per_rank = None
             global_inverse_indices = None
+
         else:
             values = torch.cat(global_values)
             lengths = torch.cat(global_lengths)
             weights = torch.cat(global_weights) if global_weights is not None else None

-        return KeyedJaggedTensor(
-            keys=keys,
-            values=values,
-            lengths=lengths,
-            weights=weights,
-            stride_per_key_per_rank=global_stride_per_key_per_rank,
-            inverse_indices=global_inverse_indices,
-        )
+        if use_offsets:
+            offsets = torch.cat(
+                [torch.tensor([0], dtype=offsets_dtype), lengths.cumsum(0)]
+            )
+            return KeyedJaggedTensor(
+                keys=keys,
+                values=values,
+                offsets=offsets,
+                weights=weights,
+                stride_per_key_per_rank=global_stride_per_key_per_rank,
+                inverse_indices=global_inverse_indices,
+            )
+        else:
+            return KeyedJaggedTensor(
+                keys=keys,
+                values=values,
+                lengths=lengths,
+                weights=weights,
+                stride_per_key_per_rank=global_stride_per_key_per_rank,
+                inverse_indices=global_inverse_indices,
+            )

     @staticmethod
     def _generate_variable_batch_features(
@@ -539,11 +598,17 @@ def _generate_variable_batch_features(
         world_size: int,
         dedup_factor: int,
         global_constant_batch: bool,
+        use_offsets: bool,
+        indices_dtype: torch.dtype,
+        offsets_dtype: torch.dtype,
+        lengths_dtype: torch.dtype,
     ) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]:
         is_weighted = (
             True if tables and getattr(tables[0], "is_weighted", False) else False
         )
+
         feature_num_embeddings = {}
+
         for table in tables:
             for feature_name in table.feature_names:
                 feature_num_embeddings[feature_name] = (
@@ -553,33 +618,42 @@ def _generate_variable_batch_features(
                 )

         local_kjts = []
+
         values_per_rank_per_feature = {}
         lengths_per_rank_per_feature = {}
         strides_per_rank_per_feature = {}
         inverse_indices_per_rank_per_feature = {}
         weights_per_rank_per_feature = {} if is_weighted else None

         local_kjts = ModelInput._generate_variable_batch_local_features(
-            feature_num_embeddings,
-            average_batch_size,
-            world_size,
-            dedup_factor,
-            values_per_rank_per_feature,
-            lengths_per_rank_per_feature,
-            strides_per_rank_per_feature,
-            inverse_indices_per_rank_per_feature,
-            weights_per_rank_per_feature,
+            feature_num_embeddings=feature_num_embeddings,
+            average_batch_size=average_batch_size,
+            world_size=world_size,
+            dedup_factor=dedup_factor,
+            values_per_rank_per_feature=values_per_rank_per_feature,
+            lengths_per_rank_per_feature=lengths_per_rank_per_feature,
+            strides_per_rank_per_feature=strides_per_rank_per_feature,
+            inverse_indices_per_rank_per_feature=inverse_indices_per_rank_per_feature,
+            weights_per_rank_per_feature=weights_per_rank_per_feature,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
         )

         global_kjt = ModelInput._generate_variable_batch_global_features(
-            list(feature_num_embeddings.keys()),
-            world_size,
-            global_constant_batch,
-            values_per_rank_per_feature,
-            lengths_per_rank_per_feature,
-            strides_per_rank_per_feature,
-            inverse_indices_per_rank_per_feature,
-            weights_per_rank_per_feature,
+            keys=list(feature_num_embeddings.keys()),
+            world_size=world_size,
+            global_constant_batch=global_constant_batch,
+            values_per_rank_per_feature=values_per_rank_per_feature,
+            lengths_per_rank_per_feature=lengths_per_rank_per_feature,
+            strides_per_rank_per_feature=strides_per_rank_per_feature,
+            inverse_indices_per_rank_per_feature=inverse_indices_per_rank_per_feature,
+            weights_per_rank_per_feature=weights_per_rank_per_feature,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
        )

        return (global_kjt, local_kjts)
@@ -601,30 +675,51 @@ def generate_variable_batch_input(
         ] = None,
         pooling_avg: int = 10,
         global_constant_batch: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         torch.manual_seed(100)
         random.seed(100)
         dedup_factor = 2
+
         global_kjt, local_kjts = ModelInput._generate_variable_batch_features(
-            tables, average_batch_size, world_size, dedup_factor, global_constant_batch
+            tables=tables,
+            average_batch_size=average_batch_size,
+            world_size=world_size,
+            dedup_factor=dedup_factor,
+            global_constant_batch=global_constant_batch,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
         )
+
         if weighted_tables:
             global_score_kjt, local_score_kjts = (
                 ModelInput._generate_variable_batch_features(
-                    weighted_tables,
-                    average_batch_size,
-                    world_size,
-                    dedup_factor,
-                    global_constant_batch,
+                    tables=weighted_tables,
+                    average_batch_size=average_batch_size,
+                    world_size=world_size,
+                    dedup_factor=dedup_factor,
+                    global_constant_batch=global_constant_batch,
+                    use_offsets=use_offsets,
+                    indices_dtype=indices_dtype,
+                    offsets_dtype=offsets_dtype,
+                    lengths_dtype=lengths_dtype,
                 )
             )
         else:
             global_score_kjt, local_score_kjts = None, []
+
         global_float = torch.rand(
             (dedup_factor * average_batch_size * world_size, num_float_features)
         )
+
         local_model_input = []
         label_per_rank = []
+
         for rank in range(world_size):
             label_per_rank.append(torch.rand(dedup_factor * average_batch_size))
             local_float = global_float[
@@ -644,12 +739,14 @@ def generate_variable_batch_input(
                     float_features=local_float,
                 ),
             )
+
         global_model_input = ModelInput(
             idlist_features=global_kjt,
             idscore_features=global_score_kjt,
             label=torch.cat(label_per_rank),
             float_features=global_float,
         )
+
         return (global_model_input, local_model_input)

     def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
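
For reference, a sketch of how a test might exercise the extended `generate_variable_batch_input` signature end to end. The table config is illustrative only, the keyword names come from the diff above, and the example assumes the int32 dtypes are accepted by the downstream KJT path:

    import torch
    from torchrec.distributed.test_utils.test_model import ModelInput
    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    # Hypothetical single-table setup; names and sizes are illustrative.
    tables = [
        EmbeddingBagConfig(
            num_embeddings=100,
            embedding_dim=8,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]

    global_input, local_inputs = ModelInput.generate_variable_batch_input(
        average_batch_size=4,
        world_size=2,
        num_float_features=10,
        tables=tables,
        use_offsets=True,  # build KJTs from offsets instead of lengths
        indices_dtype=torch.int32,
        offsets_dtype=torch.int32,
        lengths_dtype=torch.int32,
    )

    assert len(local_inputs) == 2  # one ModelInput per rank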