
Commit b4d6a40

Raahul Kalyaan Jakka authored and meta-codesync[bot] committed
torchrec related changes for APF Integration (#3648)
Summary:
X-link: pytorch/FBGEMM#5286
Pull Request resolved: #3648
X-link: https://github.com/facebookresearch/FBGEMM/pull/2279

**Overview**

This diff introduces key changes to enable APF compatibility and column-wise sharding for SSD Offloading in TorchRec. The updates focus on sharding logic, configuration, and integration with SSD-backed embedding tables.

**Key Changes**

1. Column-wise sharding for SSD Offloading: added a new device-partitioning path for column-wise sharding with SSD-backed tables, ensuring unique rank assignment and load balancing.
2. Configuration and planner updates: get_sharding_planner, get_sharding_plan, and related helpers now accept and process sharding constraints, including those specific to SSD Offloading, and the sharding helpers can now handle key-value storage parameters, enabling SSD integration.

Reviewed By: TroyGarden

Differential Revision: D89049866

fbshipit-source-id: 25fb259c442c376ae6b07dd07222a20c2bece830
1 parent 29a42da commit b4d6a40
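
For orientation, the planner-facing side of these changes can be exercised roughly as below. This is a hedged sketch, not code from this commit: the table name `large_table` and the topology numbers are made up, and the exact constraint fields available depend on the TorchRec version.

```python
# Hedged sketch: requesting column-wise sharding on SSD-backed (key_value)
# tables through planner constraints. "large_table" and the topology numbers
# are illustrative, not taken from this commit.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

constraints = {
    "large_table": ParameterConstraints(
        sharding_types=[ShardingType.COLUMN_WISE.value],
        compute_kernels=[EmbeddingComputeKernel.KEY_VALUE.value],
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=8, compute_device="cuda"),
    constraints=constraints,
)
# plan = planner.plan(module, sharders) would then route "large_table" through
# the new column-wise, key_value partitioning path shown in the diffs below.
```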

File tree: 2 files changed (+91, -6 lines)


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 0 additions & 3 deletions
@@ -2983,9 +2983,6 @@ def __init__(
         assert (
             len(config.embedding_tables) > 0
         ), "Expected to see at least one table in SSD TBE, but found 0."
-        assert (
-            len({table.embedding_dim for table in config.embedding_tables}) == 1
-        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
         for table in config.embedding_tables:
             assert table.local_cols % 4 == 0, (
                 f"table {table.name} has local_cols={table.local_cols} "

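The deleted assertion previously forced every table grouped into one SSD TBE to share a single embedding dimension. Below is a minimal, hypothetical illustration of the configurations this unblocks; the table names and sizes are made up, and whether two tables actually land in the same TBE group still depends on the usual grouping rules.

```python
# Hypothetical example: tables with different embedding_dim that an
# SSD-backed TBE may now serve together; previously the removed assert
# would reject this grouping. local_cols must still be a multiple of 4.
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        name="table_a",
        embedding_dim=128,
        num_embeddings=1_000_000,
        feature_names=["f_a"],
    ),
    EmbeddingBagConfig(
        name="table_b",
        embedding_dim=256,
        num_embeddings=4_000_000,
        feature_names=["f_b"],
    ),
]
```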
torchrec/distributed/planner/partitioners.py

Lines changed: 91 additions & 3 deletions
@@ -309,10 +309,23 @@ def partition(
                 assert (
                     len(sharding_option_group.sharding_options) == 1
                 ), f"Unexpected length for sharding options: {len(sharding_option_group.sharding_options)}"
-                self._device_partition(
-                    sharding_option_group.sharding_options[0],
-                    minheap_devices,
+
+                key_value = any(
+                    obj.compute_kernel == "key_value"
+                    for obj in sharding_option_group.sharding_options
+                )
+
+                is_column_wise = any(
+                    opt.sharding_type == ShardingType.COLUMN_WISE.value
+                    for opt in sharding_option_group.sharding_options
                 )
+
+                sharding_option = sharding_option_group.sharding_options[0]
+
+                if is_column_wise and key_value:
+                    self._column_wise_device_partition(sharding_option, minheap_devices)
+                else:
+                    self._device_partition(sharding_option, minheap_devices)
             else:
                 raise RuntimeError(
                     f"Unexpected sharding option group {sharding_option_group}"
@@ -329,6 +342,81 @@ def _establish_minheap(
         heapq.heapify(minheap_devices)
         return minheap_devices
 
+    @classmethod
+    def _column_wise_device_partition(
+        cls,
+        sharding_option: ShardingOption,
+        minheap_devices: List[OrderedDeviceHardware],
+    ) -> None:
+        """
+        Specialized device partitioning for COLUMN_WISE sharding that ensures:
+        1. No multiple shards per rank (unique rank constraint)
+        2. Load-balanced distribution using the existing greedy approach
+        3. Efficient rank assignment with proper topology awareness
+        """
+        num_shards = sharding_option.num_shards
+        total_devices = len(minheap_devices)
+
+        if num_shards > total_devices:
+            raise PlannerError(
+                error_type=PlannerErrorType.PARTITION,
+                message=f"COLUMN_WISE sharding requires num_shards ({num_shards}) <= num_devices ({total_devices})",
+            )
+
+        used_ranks = set()
+
+        for shard in sharding_option.shards:
+            found_device = False
+
+            # Find the best available device that hasn't been used for CWS
+            available_devices = [
+                od for od in minheap_devices if od.device.rank not in used_ranks
+            ]
+
+            if not available_devices:
+                raise PlannerError(
+                    error_type=PlannerErrorType.PARTITION,
+                    message=(
+                        f"COLUMN_WISE partition failed. No available ranks for shard {shard} of table {sharding_option.name}. "
+                        f"Used ranks: {used_ranks}, total devices: {total_devices}"
+                    ),
+                )
+
+            # Sort available devices by the same criteria as the heap (load balancing)
+            available_devices.sort(
+                key=lambda od: (
+                    od.device.perf.total,
+                    od.device.rank % od.local_world_size,
+                    od.device.rank,
+                )
+            )
+
+            # Try to place shard on the least-loaded available device
+            for ordered_device in available_devices:
+                device = ordered_device.device
+                storage = cast(Storage, shard.storage)
+
+                if storage.fits_in(device.storage):
+                    # Successfully place the shard
+                    shard.rank = device.rank
+                    device.storage -= storage
+                    device.perf += cast(Perf, shard.perf)
+                    used_ranks.add(device.rank)
+                    found_device = True
+                    break
+
+            if not found_device:
+                raise PlannerError(
+                    error_type=PlannerErrorType.PARTITION,
+                    message=(
+                        f"COLUMN_WISE partition failed. Couldn't find a suitable rank for shard {shard} of table {sharding_option.name}. "
+                        f"Storage required: {shard.storage}, available devices: {len(available_devices)}"
+                    ),
+                )
+
+        # Re-heapify the devices after all changes
+        heapq.heapify(minheap_devices)
+
     @classmethod
     def _device_partition(
         cls,

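The new `_column_wise_device_partition` above is a greedy, unique-rank placement. The standalone sketch below restates the same idea with simplified stand-in types (`Device` and `Shard` here are illustrative, not TorchRec planner classes), which may help when reading the diff.

```python
# Standalone sketch of the unique-rank greedy placement; Device and Shard
# are simplified stand-ins, not the TorchRec planner types used above.
from dataclasses import dataclass
from typing import List, Set


@dataclass
class Device:
    rank: int
    free_storage: int
    load: float = 0.0


@dataclass
class Shard:
    storage: int
    perf: float
    rank: int = -1


def place_column_wise(shards: List[Shard], devices: List[Device]) -> None:
    # Each shard of a column-wise sharded table must land on a distinct rank.
    if len(shards) > len(devices):
        raise RuntimeError("column-wise sharding needs num_shards <= num_devices")
    used: Set[int] = set()
    for shard in shards:
        # Candidates: least-loaded devices whose rank is still unused and
        # that have enough free storage for this shard.
        candidates = sorted(
            (
                d
                for d in devices
                if d.rank not in used and d.free_storage >= shard.storage
            ),
            key=lambda d: (d.load, d.rank),
        )
        if not candidates:
            raise RuntimeError("no feasible rank left for shard")
        best = candidates[0]
        shard.rank = best.rank
        best.free_storage -= shard.storage
        best.load += shard.perf
        used.add(best.rank)
```

In the actual implementation the same greedy criterion is expressed through the existing device ordering and `Storage.fits_in`, and the device min-heap is rebuilt once all shards are placed.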