Commit 2268800

aporialiao authored and facebook-github-bot committed
1/n Dynamic Sharding API + Test for EBC, TW, ShardedTensor (#2852)
Summary:
Pull Request resolved: #2852

Add initial dynamic sharding API and test. The current version supports EBC, TW, and ShardedTensor; other variants beyond these configurations (e.g. CW, RW, DTensor, etc.) will be added in the next few diffs.

What's added here:
1. A `reshard` API, which implements the `update_shards` API for `ShardedEmbeddingBagCollection`.
2. Util functions for dynamic sharding, used by the `update_shards` API:
   1. `extend_shard_name`: extends `table_i` to `embedding_bags.table_i.weight`.
   2. `shards_all_to_all`: contains the all-to-all collective call that redistributes shards in a distributed environment, based on the `changed_sharding_params`.
   3. `update_state_dict_post_resharding`: updates a given `state_dict` with new shard `placements` and `local_shards`.
3. A multi-process unit test, `test_dynamic_sharding_ebc_tw`, which tests TW-sharded EBCs calling the `reshard` API, sampling from various `world_sizes`, `num_tables`, and `data_types`.
   1. This unit test also uses a few utils to generate random inputs and rank placements. A future todo is to merge this input generation with the generate call in D71703434.

Future work items (features not yet supported in this diff):
* CW, RW, and many other sharding types
* Optimizer saving
* DTensor implementation

Reviewed By: iamzainhuda

Differential Revision: D69095169

fbshipit-source-id: f3cb3081ef7c6ca12ca01f4a7e340e3520ee6cab
1 parent ad42b3e commit 2268800
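
For orientation, a minimal usage sketch of the new sharder-level entry point, assuming an already-sharded `EmbeddingBagCollection` and an initialized distributed environment. The wrapper function name and the way the delta plan is produced are illustrative, not part of this PR; only `EmbeddingBagCollectionSharder.reshard` and its arguments come from the diff below.

```python
from typing import Dict, Optional

import torch
from torchrec.distributed.embeddingbag import (
    EmbeddingBagCollectionSharder,
    ShardedEmbeddingBagCollection,
)
from torchrec.distributed.types import ParameterSharding, ShardingEnv


def move_tables(
    sharded_ebc: ShardedEmbeddingBagCollection,
    sharder: EmbeddingBagCollectionSharder,
    changed_shard_to_params: Dict[str, ParameterSharding],  # delta only
    env: ShardingEnv,
    device: Optional[torch.device] = None,
) -> ShardedEmbeddingBagCollection:
    # `changed_shard_to_params` maps only the tables whose placement changes
    # to their new ParameterSharding; unchanged tables are omitted.
    # `reshard` is a no-op when the delta is empty and otherwise calls
    # `update_shards` on the sharded module in place.
    return sharder.reshard(
        sharded_module=sharded_ebc,
        changed_shard_to_params=changed_shard_to_params,
        env=env,
        device=device,
    )
```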

File tree

3 files changed: +818, -14 lines


torchrec/distributed/embeddingbag.py

Lines changed: 191 additions & 14 deletions
@@ -27,6 +27,9 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -50,6 +53,10 @@
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
         self._env = env
         # output parameters as DTensor in state dict
         self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
         )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]
 
         self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                 lookup = lookup.module
             lookup.purge()
 
-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor
 
-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -1164,6 +1175,7 @@ def _create_output_dist(self) -> None:
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
         self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
@@ -1179,6 +1191,38 @@ def _create_output_dist(self) -> None:
                 embedding_shard_offsets[i],
             ),
         )
+
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        permute_indices = sorted(
+            range(len(self._uncombined_embedding_names)),
+            key=lambda i: (
+                embedding_name_order[self._uncombined_embedding_names[i]],
+                embedding_shard_offsets[i],
+            ),
+        )
+
         self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
             self._uncombined_embedding_dims, permute_indices, self._device
         )
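
A note on the hunk above: `_update_output_dist` rebuilds `permute_indices` the same way `_create_output_dist` does, ordering uncombined embedding shards first by the first appearance of their table name and then by shard offset. A standalone sketch of that ordering; the table names and offsets below are made up for illustration.

```python
# Toy reproduction of the permute-index ordering used above; the table names
# and shard offsets are invented for illustration.
uncombined_embedding_names = ["table_0", "table_1", "table_0", "table_1"]
embedding_shard_offsets = [256, 0, 0, 128]  # column offset of each shard

embedding_name_order = {}
for i, name in enumerate(uncombined_embedding_names):
    embedding_name_order.setdefault(name, i)  # first appearance wins

permute_indices = sorted(
    range(len(uncombined_embedding_names)),
    key=lambda i: (
        embedding_name_order[uncombined_embedding_names[i]],
        embedding_shard_offsets[i],
    ),
)
print(permute_indices)  # [2, 0, 1, 3]: table_0's shards by offset, then table_1's
```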
@@ -1396,13 +1440,119 @@ def compute_and_output_dist(
 
         return awaitable
 
+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call shards_all_to_all, which contains the collective call to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups, and load in updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("We do not yet support DTensor for resharding")
+            return
+
+        current_state = self.state_dict()
+        # TODO: Save Optimizers
+
+        saved_weights = {}
+        # TODO: Save lookup tensors to CPU to eventually avoid recreating them from scratch
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Deleting all lookups
+        self._lookups.clear()
+
+        local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            state_dict=current_state,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+            extend_shard_name=self.extend_shard_name,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            state_dict=current_state,
+            shard_names_by_src_rank=local_shard_names_by_src_rank,
+            output_tensor=local_output_tensor,
+            new_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name=self.extend_shard_name,
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting old sharding type when sharding type is changing
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+        return
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
 
     def create_context(self) -> EmbeddingBagCollectionContext:
         return EmbeddingBagCollectionContext()
 
+    @staticmethod
+    def extend_shard_name(shard_name: str) -> str:
+        return f"embedding_bags.{shard_name}.weight"
+
 
 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """
@@ -1435,6 +1585,33 @@ def shardable_parameters(
             for name, param in module.embedding_bags.named_parameters()
         }
 
+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on the changed_shard_to_params
+        which contains the new ParameterSharding with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved
+            env (ShardingEnv): The sharding environment
+            device (Optional[torch.device]): The device to place the updated module on
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
     @property
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection
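
One practical question the diff leaves implicit is how the `changed_shard_to_params` delta gets built. Below is a hypothetical helper sketched under the assumption of single-host table-wise (TW) placements of the form `rank:r/cuda:r`; the actual unit test in this PR generates its plans and rank placements with its own utilities, and other `ParameterSharding` fields may need adjusting in other setups.

```python
import copy
from typing import Dict

from torch.distributed import _remote_device
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
from torchrec.distributed.types import ParameterSharding


def make_tw_delta(
    sharded_ebc: ShardedEmbeddingBagCollection,
    table_to_new_rank: Dict[str, int],
) -> Dict[str, ParameterSharding]:
    # Hypothetical: copy each moving table's existing TW ParameterSharding from
    # the module sharding plan and retarget its single shard at the new rank.
    delta: Dict[str, ParameterSharding] = {}
    for table_name, new_rank in table_to_new_rank.items():
        param = copy.deepcopy(sharded_ebc.module_sharding_plan[table_name])
        param.ranks = [new_rank]
        if param.sharding_spec is not None:
            for shard in param.sharding_spec.shards:
                # TW tables have exactly one shard; assumes one GPU per rank.
                shard.placement = _remote_device(f"rank:{new_rank}/cuda:{new_rank}")
        delta[table_name] = param
    return delta
```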
