We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 48fb199 commit d954720Copy full SHA for d954720
torchrec/distributed/embedding_types.py
@@ -414,9 +414,16 @@ def __init__(
414
self._fused_params = fused_params
415
416
def sharding_types(self, compute_device_type: str) -> List[str]:
417
- # For MTIA, sharding types are restricted to TW, CW.
+
418
if compute_device_type in {"mtia"}:
419
- return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value]
+ return [
420
+ ShardingType.TABLE_WISE.value,
421
+ ShardingType.COLUMN_WISE.value,
422
+ ShardingType.TABLE_COLUMN_WISE.value,
423
+ ShardingType.ROW_WISE.value,
424
+ ShardingType.TABLE_ROW_WISE.value,
425
+ ShardingType.GRID_SHARD.value,
426
+ ]
427
428
types = [
429
ShardingType.DATA_PARALLEL.value,
0 commit comments