1 parent 55c663f commit 8c16e96
distributed/tensor_parallelism/two_d_parallel_example.py
@@ -83,7 +83,8 @@ def demo_2d(rank, args):
     assert (
         enable_2d_with_fsdp()
     ), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0"
-    model = FSDP(model)
+    dp_pg = device_mesh.get_dim_groups()[0]
+    model = FSDP(model, process_group=dp_pg)
 
     # Perform a num of iterations of forward/backward
     # and optimizations for the sharded module.
@@ -94,7 +95,7 @@ def demo_2d(rank, args):
         dp_rank = (
             rank
             if args.run_seq_parallel
-            else dist.get_rank(device_mesh.get_dim_groups()[0])
+            else dist.get_rank(dp_pg)
         )
         torch.manual_seed(i + dp_rank)
         inp = torch.rand(20, 10).cuda(rank)
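For context, this change restricts FSDP's sharding collectives to the data-parallel dimension of the 2D device mesh, leaving the other mesh dimension free for tensor-parallel communication. Below is a minimal sketch of that pattern, assuming a PyTorch 2.0-era API surface (torch.distributed._tensor.DeviceMesh and its get_dim_groups() method); the helper name build_2d_fsdp and the dp_size/tp_size names are illustrative, not part of the commit.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_2d_fsdp(model, world_size, tp_size):
    # Requires dist.init_process_group() to have been called already.
    assert dist.is_initialized()
    dp_size = world_size // tp_size
    # Rows of the mesh are data-parallel replicas; columns are
    # tensor-parallel shards.
    device_mesh = DeviceMesh(
        "cuda", torch.arange(world_size).view(dp_size, tp_size)
    )
    # get_dim_groups() returns one ProcessGroup per mesh dimension;
    # index 0 is the data-parallel dimension in this layout.
    dp_pg = device_mesh.get_dim_groups()[0]
    # Hand FSDP only the DP group so its all-gather/reduce-scatter
    # traffic does not clash with tensor-parallel collectives on dim 1.
    return FSDP(model, process_group=dp_pg), device_mesh

Reusing the cached dp_pg handle, as the second hunk does for dist.get_rank(dp_pg), also avoids re-querying the mesh for its dimension groups on every iteration.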