@@ -115,7 +115,8 @@ def __init__(
 
     def _init_combined_embeddings(self) -> None:
         """
-        similar to CW sharding, but this time each CW shard is on a node and not rank
+        Initializes combined embeddings, similar to the CW sharding implementation,
+        but in this case the CW shard is treated on a per-node basis and not per rank.
         """
         embedding_names = []
         for grouped_embedding_configs in self._grouped_embedding_configs_per_node:
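
The per-node interpretation described above can be sanity-checked with plain arithmetic. The sketch below is purely illustrative (the helper name and numbers are invented, not torchrec code): per-rank CW splits a table's columns across every rank, while per-node CW splits them across nodes only, so each CW shard is wider.

def cw_shard_width(num_cols: int, num_shards: int) -> int:
    # Width of each column-wise shard, assuming num_cols divides evenly.
    return num_cols // num_shards


world_size, local_size = 8, 4
num_nodes = world_size // local_size
embedding_dim = 512

per_rank_width = cw_shard_width(embedding_dim, world_size)  # 64 columns per rank
per_node_width = cw_shard_width(embedding_dim, num_nodes)   # 256 columns per node
print(per_rank_width, per_node_width)
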
@@ -179,6 +180,17 @@ def _shard(
         self,
         sharding_infos: List[EmbeddingShardingInfo],
     ) -> List[List[ShardedEmbeddingTable]]:
+        """
+        Shards the embedding tables.
+        This method takes the sharding infos and returns a list of lists of
+        sharded embedding tables, where each inner list represents the tables
+        for a specific rank.
+
+        Args:
+            sharding_infos (List[EmbeddingShardingInfo]): The sharding infos.
+        Returns:
+            List[List[ShardedEmbeddingTable]]: The sharded embedding tables.
+        """
         world_size = self._world_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
             [] for i in range(world_size)
@@ -198,7 +210,7 @@ def _shard(
                 ),
             )
 
-            # expectation is planner CW shards across a node, so each CW shard will have local_size num row shards
+            # Expectation is that the planner CW shards across a node, so each CW shard will have local_size number of row shards
             # pyre-fixme [6]
             for i, rank in enumerate(info.param_sharding.ranks):
                 tables_per_rank[rank].append(
@@ -212,7 +224,6 @@ def _shard(
                         pooling=info.embedding_config.pooling,
                         is_weighted=info.embedding_config.is_weighted,
                         has_feature_processor=info.embedding_config.has_feature_processor,
-                        # sharding by row and col
                         local_rows=shards[i].shard_sizes[0],
                         local_cols=shards[i].shard_sizes[1],
                         compute_kernel=EmbeddingComputeKernel(
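
To make the per-rank bookkeeping in _shard concrete, here is a minimal sketch of the same pattern using plain dataclasses (ShardRecord, shard_table, and the argument shapes are hypothetical stand-ins, not the torchrec types): one list is created per rank, and the shard placed on ranks[i] is appended with its local row and column sizes.

from dataclasses import dataclass
from typing import List


@dataclass
class ShardRecord:
    table_name: str
    local_rows: int
    local_cols: int


def shard_table(
    table_name: str,
    shard_sizes: List[List[int]],  # [rows, cols] for each shard, in planner order
    ranks: List[int],              # target rank for each shard, same order
    world_size: int,
) -> List[List[ShardRecord]]:
    # One bucket per rank, mirroring tables_per_rank above.
    tables_per_rank: List[List[ShardRecord]] = [[] for _ in range(world_size)]
    for i, rank in enumerate(ranks):
        tables_per_rank[rank].append(
            ShardRecord(table_name, shard_sizes[i][0], shard_sizes[i][1])
        )
    return tables_per_rank


# Example: one CW shard row-sharded over the 4 ranks of node 0 in an 8-rank job.
print(shard_table("t0", [[256, 128]] * 4, [0, 1, 2, 3], world_size=8))
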
@@ -420,7 +431,7 @@ class GridPooledEmbeddingSharding(
     ]
 ):
     """
-    Shards embedding bags table-wise then row-wise.
+    Shards embedding bags into column-wise shards, then shards each CW shard table-wise row-wise within a node.
     """
 
     def create_input_dist(
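
As a rough worked example of the grid layout described in this docstring (the helper below is purely illustrative, assumes even divisibility, and is not part of torchrec): columns are split across nodes and rows across the ranks within each node, so every rank ends up holding one rectangular block of the table.

from typing import Dict, Tuple


def grid_placement(
    num_rows: int, num_cols: int, num_nodes: int, local_size: int
) -> Dict[int, Tuple[Tuple[int, int], Tuple[int, int]]]:
    # Maps rank -> ((row_offset, row_size), (col_offset, col_size)).
    col_size = num_cols // num_nodes     # CW shard width per node
    row_size = num_rows // local_size    # RW shard height per local rank
    placement = {}
    for node in range(num_nodes):
        for local_rank in range(local_size):
            rank = node * local_size + local_rank
            placement[rank] = (
                (local_rank * row_size, row_size),
                (node * col_size, col_size),
            )
    return placement


# 1024 x 512 table on 2 nodes with 4 ranks each:
# rank 5 (node 1, local rank 1) holds rows 256-511 of columns 256-511.
print(grid_placement(1024, 512, num_nodes=2, local_size=4)[5])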