Skip to content

Commit 8c16e96

Browse files
authored
Fix 2D example to pass in data parallel pg (#1160)
1 parent 55c663f commit 8c16e96

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

distributed/tensor_parallelism/two_d_parallel_example.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def demo_2d(rank, args):
8383
assert (
8484
enable_2d_with_fsdp()
8585
), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0"
86-
model = FSDP(model)
86+
dp_pg = device_mesh.get_dim_groups()[0]
87+
model = FSDP(model, process_group=dp_pg)
8788

8889
# Perform a num of iterations of forward/backward
8990
# and optimizations for the sharded module.
@@ -94,7 +95,7 @@ def demo_2d(rank, args):
9495
dp_rank = (
9596
rank
9697
if args.run_seq_parallel
97-
else dist.get_rank(device_mesh.get_dim_groups()[0])
98+
else dist.get_rank(dp_pg)
9899
)
99100
torch.manual_seed(i + dp_rank)
100101
inp = torch.rand(20, 10).cuda(rank)

0 commit comments

Comments
 (0)