@@ -27,6 +27,9 @@

 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -50,6 +53,10 @@
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
         self._env = env
         # output parameters as DTensor in state dict
         self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
         )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]

         self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                 lookup = lookup.module
             lookup.purge()

-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor

-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
         self.reset_parameters()

     def reset_parameters(self) -> None:
@@ -1164,6 +1175,7 @@ def _create_output_dist(self) -> None:
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
         self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
@@ -1179,6 +1191,38 @@ def _create_output_dist(self) -> None:
                 embedding_shard_offsets[i],
             ),
         )
+
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        permute_indices = sorted(
+            range(len(self._uncombined_embedding_names)),
+            key=lambda i: (
+                embedding_name_order[self._uncombined_embedding_names[i]],
+                embedding_shard_offsets[i],
+            ),
+        )
+
         self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
             self._uncombined_embedding_dims, permute_indices, self._device
         )
@@ -1396,13 +1440,119 @@ def compute_and_output_dist(
1396
1440
1397
1441
return awaitable
1398
1442
1443
+ def update_shards (
1444
+ self ,
1445
+ changed_sharding_params : Dict [str , ParameterSharding ], # NOTE: only delta
1446
+ env : ShardingEnv ,
1447
+ device : Optional [torch .device ],
1448
+ ) -> None :
1449
+ """
1450
+ Update shards for this module based on the changed_sharding_params. This will:
1451
+ 1. Move current lookup tensors to CPU
1452
+ 2. Purge lookups
1453
+ 3. Call shards_all_2_all containing collective to redistribute tensors
1454
+ 4. Update state_dict and other attributes to reflect new placements and shards
1455
+ 5. Create new lookups, and load in updated state_dict
1456
+
1457
+ Args:
1458
+ changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
1459
+ table names to their new parameter sharding configs. This should only
1460
+ contain shards/table names that need to be moved.
1461
+ env (ShardingEnv): The sharding environment for the module.
1462
+ device (Optional[torch.device]): The device to place the updated module on.
1463
+ """
1464
+
1465
+ if env .output_dtensor :
1466
+ raise RuntimeError ("We do not yet support DTensor for resharding yet" )
1467
+ return
1468
+
1469
+ current_state = self .state_dict ()
1470
+ # TODO: Save Optimizers
1471
+
1472
+ saved_weights = {}
1473
+ # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
1474
+ for i , lookup in enumerate (self ._lookups ):
1475
+ for attribute , tbe_module in lookup .named_modules ():
1476
+ if type (tbe_module ) is DenseTableBatchedEmbeddingBagsCodegen :
1477
+ saved_weights [str (i ) + "." + attribute ] = tbe_module .weights .cpu ()
1478
+ # Note: lookup.purge should delete tbe_module and weights
1479
+ # del tbe_module.weights
1480
+ # del tbe_module
1481
+ # pyre-ignore
1482
+ lookup .purge ()
1483
+
1484
+ # Deleting all lookups
1485
+ self ._lookups .clear ()
1486
+
1487
+ local_shard_names_by_src_rank , local_output_tensor = shards_all_to_all (
1488
+ module = self ,
1489
+ state_dict = current_state ,
1490
+ device = device , # pyre-ignore
1491
+ changed_sharding_params = changed_sharding_params ,
1492
+ env = env ,
1493
+ extend_shard_name = self .extend_shard_name ,
1494
+ )
1495
+
1496
+ current_state = update_state_dict_post_resharding (
1497
+ state_dict = current_state ,
1498
+ shard_names_by_src_rank = local_shard_names_by_src_rank ,
1499
+ output_tensor = local_output_tensor ,
1500
+ new_sharding_params = changed_sharding_params ,
1501
+ curr_rank = dist .get_rank (),
1502
+ extend_shard_name = self .extend_shard_name ,
1503
+ )
1504
+
1505
+ for name , param in changed_sharding_params .items ():
1506
+ self .module_sharding_plan [name ] = param
1507
+ # TODO: Support detecting old sharding type when sharding type is changing
1508
+ for sharding_info in self .sharding_type_to_sharding_infos [
1509
+ param .sharding_type
1510
+ ]:
1511
+ if sharding_info .embedding_config .name == name :
1512
+ sharding_info .param_sharding = param
1513
+
1514
+ self ._sharding_types : List [str ] = list (
1515
+ self .sharding_type_to_sharding_infos .keys ()
1516
+ )
1517
+ # TODO: Optimize to update only the changed embedding shardings
1518
+ self ._embedding_shardings : List [
1519
+ EmbeddingSharding [
1520
+ EmbeddingShardingContext ,
1521
+ KeyedJaggedTensor ,
1522
+ torch .Tensor ,
1523
+ torch .Tensor ,
1524
+ ]
1525
+ ] = [
1526
+ create_embedding_bag_sharding (
1527
+ embedding_configs ,
1528
+ env ,
1529
+ device ,
1530
+ permute_embeddings = True ,
1531
+ qcomm_codecs_registry = self .qcomm_codecs_registry ,
1532
+ )
1533
+ for embedding_configs in self .sharding_type_to_sharding_infos .values ()
1534
+ ]
1535
+
1536
+ self ._create_lookups ()
1537
+ self ._update_output_dist ()
1538
+
1539
+ if env .process_group and dist .get_backend (env .process_group ) != "fake" :
1540
+ self ._initialize_torch_state (skip_registering = True )
1541
+
1542
+ self .load_state_dict (current_state )
1543
+ return
1544
+
1399
1545
@property
1400
1546
def fused_optimizer (self ) -> KeyedOptimizer :
1401
1547
return self ._optim
1402
1548
1403
1549
def create_context (self ) -> EmbeddingBagCollectionContext :
1404
1550
return EmbeddingBagCollectionContext ()
1405
1551
1552
+ @staticmethod
1553
+ def extend_shard_name (shard_name : str ) -> str :
1554
+ return f"embedding_bags.{ shard_name } .weight"
1555
+
1406
1556
1407
1557
class EmbeddingBagCollectionSharder (BaseEmbeddingSharder [EmbeddingBagCollection ]):
1408
1558
"""
@@ -1435,6 +1585,33 @@ def shardable_parameters(
             for name, param in module.embedding_bags.named_parameters()
         }

+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on the changed_shard_to_params
+        which contains the new ParameterSharding with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved
+            env (ShardingEnv): The sharding environment
+            device (Optional[torch.device]): The device to place the updated module on
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
     @property
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection
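
For orientation, a hedged sketch of how a caller might drive the new reshard() entry point. Everything outside the reshard() call itself is an assumption for illustration: reshard_tables, tables_to_move, and build_new_sharding_for are hypothetical names, and the helper stands in for whatever planner logic produces the updated ParameterSharding objects; it is not part of this diff.

from typing import Dict, List, Optional

import torch
from torchrec.distributed.embeddingbag import (
    EmbeddingBagCollectionSharder,
    ShardedEmbeddingBagCollection,
)
from torchrec.distributed.types import ParameterSharding, ShardingEnv


def build_new_sharding_for(table_name: str) -> ParameterSharding:
    # Hypothetical placeholder: produce a ParameterSharding describing the new
    # ranks/placements for `table_name` (e.g. from re-running the planner).
    # Not part of this diff.
    raise NotImplementedError


def reshard_tables(
    sharder: EmbeddingBagCollectionSharder,
    sharded_ebc: ShardedEmbeddingBagCollection,
    tables_to_move: List[str],
    env: ShardingEnv,
    device: Optional[torch.device] = None,
) -> ShardedEmbeddingBagCollection:
    # Per the reshard() docstring, pass only the delta: tables whose
    # placement actually changes.
    changed_shard_to_params: Dict[str, ParameterSharding] = {
        name: build_new_sharding_for(name) for name in tables_to_move
    }
    # reshard() updates the sharded module in place (all-to-all of shards,
    # rebuilt lookups/output dists, reloaded state_dict) and returns it.
    return sharder.reshard(sharded_ebc, changed_shard_to_params, env, device)

Since reshard() is a no-op when changed_shard_to_params is empty, a caller can always pass the full delta computed by its re-planning step.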
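
Separately, the permute order built in _update_output_dist (mirroring _create_output_dist) groups pooled outputs by table, in the order each uncombined embedding name first appears, and then by column shard offset within the table. A small self-contained sketch with made-up table names and offsets (not taken from this diff) reproduces that sort key:

from typing import Dict, List

# Toy stand-ins for self._uncombined_embedding_names and the per-shard
# column offsets; the real values come from the embedding shardings above.
uncombined_embedding_names: List[str] = ["t0", "t1", "t0", "t1"]
embedding_shard_offsets: List[int] = [128, 0, 0, 64]

# First occurrence of each name defines the table order, mirroring
# embedding_name_order.setdefault(name, i) in _update_output_dist.
embedding_name_order: Dict[str, int] = {}
for i, name in enumerate(uncombined_embedding_names):
    embedding_name_order.setdefault(name, i)

# Sort shard outputs by (table order, column offset within the table).
permute_indices = sorted(
    range(len(uncombined_embedding_names)),
    key=lambda i: (
        embedding_name_order[uncombined_embedding_names[i]],
        embedding_shard_offsets[i],
    ),
)
print(permute_indices)  # [2, 0, 1, 3]

Here both column shards of t0 come first, ordered by offset, followed by the shards of t1; in the diff these indices are then handed to PermutePooledEmbeddings together with the uncombined embedding dims.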