
Commit d317c0b

iamzainhuda authored and facebook-github-bot committed
update grid sharding doc strings (#2488)
Summary:
Pull Request resolved: #2488

tsia - updated docstrings to be more useful/accurate

Differential Revision: D64423503

fbshipit-source-id: 7f1a92b259fb815571d5238a9238d447543cdb02
1 parent afd5726 · commit d317c0b

1 file changed: +15 −4 lines changed

torchrec/distributed/sharding/grid_sharding.py

Lines changed: 15 additions & 4 deletions
@@ -115,7 +115,8 @@ def __init__(

     def _init_combined_embeddings(self) -> None:
         """
-        similar to CW sharding, but this time each CW shard is on a node and not rank
+        Initializes combined embeddings, similar to the CW sharding implementation,
+        but in this case the CW shard is treated on a per node basis and not per rank.
         """
         embedding_names = []
         for grouped_embedding_configs in self._grouped_embedding_configs_per_node:
@@ -179,6 +180,17 @@ def _shard(
         self,
         sharding_infos: List[EmbeddingShardingInfo],
     ) -> List[List[ShardedEmbeddingTable]]:
+        """
+        Shards the embedding tables.
+        This method takes the sharding infos and returns a list of lists of
+        sharded embedding tables, where each inner list represents the tables
+        for a specific rank.
+
+        Args:
+            sharding_infos (List[EmbeddingShardingInfo]): The sharding infos.
+        Returns:
+            List[List[ShardedEmbeddingTable]]: The sharded embedding tables.
+        """
         world_size = self._world_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
             [] for i in range(world_size)
@@ -198,7 +210,7 @@ def _shard(
                 ),
             )

-            # expectation is planner CW shards across a node, so each CW shard will have local_size num row shards
+            # Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards
             # pyre-fixme [6]
             for i, rank in enumerate(info.param_sharding.ranks):
                 tables_per_rank[rank].append(
@@ -212,7 +224,6 @@ def _shard(
                     pooling=info.embedding_config.pooling,
                     is_weighted=info.embedding_config.is_weighted,
                     has_feature_processor=info.embedding_config.has_feature_processor,
-                    # sharding by row and col
                     local_rows=shards[i].shard_sizes[0],
                     local_cols=shards[i].shard_sizes[1],
                     compute_kernel=EmbeddingComputeKernel(
@@ -420,7 +431,7 @@ class GridPooledEmbeddingSharding(
     ]
 ):
     """
-    Shards embedding bags table-wise then row-wise.
+    Shards embedding bags into column wise shards and shards each CW shard table wise row wise within a node
     """

     def create_input_dist(
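
Taken together, the updated docstrings describe the grid layout: the planner places each column-wise (CW) shard on its own node, and each CW shard is then split row-wise across that node's local_size ranks, so every rank holds a (rows / local_size) x cw_shard_dim slice. A minimal standalone sketch of that layout arithmetic follows; grid_shard_sizes is a hypothetical helper, not a TorchRec API, and it assumes even divisibility and contiguous rank numbering within each node.

from typing import List, Tuple

def grid_shard_sizes(
    rows: int,
    cols: int,
    world_size: int,
    local_size: int,
    cw_shard_dim: int,
) -> List[List[Tuple[int, int]]]:
    """Return per-rank (local_rows, local_cols) for one grid-sharded table."""
    num_cw_shards = cols // cw_shard_dim
    assert world_size % local_size == 0, "world must divide into whole nodes"
    assert num_cw_shards <= world_size // local_size, "one node per CW shard"
    # Each CW shard is split row-wise across the node's local ranks,
    # mirroring the "local_size number of row shards" comment in the diff.
    rows_per_rank = rows // local_size  # assumes even divisibility
    tables_per_rank: List[List[Tuple[int, int]]] = [[] for _ in range(world_size)]
    for cw_idx in range(num_cw_shards):
        # Assumes node cw_idx owns this CW shard and its ranks are contiguous.
        for local_rank in range(local_size):
            rank = cw_idx * local_size + local_rank
            tables_per_rank[rank].append((rows_per_rank, cw_shard_dim))
    return tables_per_rank

# Example: a 1M-row x 512-col table over 8 ranks (2 nodes x 4 ranks),
# CW shard dim 256 -> each rank holds one 250_000 x 256 slice.
print(grid_shard_sizes(1_000_000, 512, world_size=8, local_size=4, cw_shard_dim=256))

The returned list-of-lists shape matches what the new _shard docstring describes: one inner list per rank, with the real implementation storing ShardedEmbeddingTable objects (carrying local_rows and local_cols from shards[i].shard_sizes) rather than bare tuples.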
