Skip to content

Commit d954720

Browse files
jvandebonfacebook-github-bot
authored andcommitted
Update list for supported MTIA sharding types (#2789)
Summary: Pull Request resolved: #2789 As title. Previous diffs in stack add unit tests for TWCW, TWRW, and GRID sharding. bypass-github-pytorch-ci-checks Reviewed By: kausv, nautsimon, egienvalue, rexyl Differential Revision: D69924446 fbshipit-source-id: cf9da78f1492e22d0f01ab327f4aeeeb5398cc86
1 parent 48fb199 commit d954720

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

torchrec/distributed/embedding_types.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,9 +414,16 @@ def __init__(
414414
self._fused_params = fused_params
415415

416416
def sharding_types(self, compute_device_type: str) -> List[str]:
417-
# For MTIA, sharding types are restricted to TW, CW.
417+
418418
if compute_device_type in {"mtia"}:
419-
return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]
419+
return [
420+
ShardingType.TABLE_WISE.value,
421+
ShardingType.COLUMN_WISE.value,
422+
ShardingType.TABLE_COLUMN_WISE.value,
423+
ShardingType.ROW_WISE.value,
424+
ShardingType.TABLE_ROW_WISE.value,
425+
ShardingType.GRID_SHARD.value,
426+
]
420427

421428
types = [
422429
ShardingType.DATA_PARALLEL.value,

0 commit comments

Comments
 (0)