Commit a1ad9ed

bowangbj and brianjo authored
Update ddp_tutorial.rst (pytorch#1618)
Relax GPU count to support devices with 2 GPUs

Co-authored-by: Brian Johnson <[email protected]>
1 parent efb676f commit a1ad9ed


intermediate_source/ddp_tutorial.rst

Lines changed: 7 additions & 8 deletions
@@ -265,8 +265,8 @@ either the application or the model ``forward()`` method.
     setup(rank, world_size)
 
     # setup mp_model and devices for this process
-    dev0 = rank * 2
-    dev1 = rank * 2 + 1
+    dev0 = (rank * 2) % world_size
+    dev1 = (rank * 2 + 1) % world_size
     mp_model = ToyMpModel(dev0, dev1)
     ddp_mp_model = DDP(mp_model)
 
@@ -285,9 +285,8 @@ either the application or the model ``forward()`` method.
 
 if __name__ == "__main__":
     n_gpus = torch.cuda.device_count()
-    if n_gpus < 8:
-        print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
-    else:
-        run_demo(demo_basic, 8)
-        run_demo(demo_checkpoint, 8)
-        run_demo(demo_model_parallel, 4)
+    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
+    world_size = n_gpus
+    run_demo(demo_basic, world_size)
+    run_demo(demo_checkpoint, world_size)
+    run_demo(demo_model_parallel, world_size)
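
As a side note (not part of the commit), here is a minimal sketch of how the relaxed device mapping behaves; dev0, dev1, and world_size are the names from the diff above, and world_size is hard-coded to 2 here purely to illustrate the newly supported 2-GPU case:

    # Sketch of the updated dev0/dev1 arithmetic from the diff above.
    # world_size is assumed to equal torch.cuda.device_count(); it is
    # hard-coded to 2 here so the example runs without any GPUs.
    world_size = 2

    for rank in range(world_size):
        dev0 = (rank * 2) % world_size
        dev1 = (rank * 2 + 1) % world_size
        print(f"rank {rank}: ToyMpModel would be placed on cuda:{dev0} and cuda:{dev1}")

    # Output:
    #   rank 0: ToyMpModel would be placed on cuda:0 and cuda:1
    #   rank 1: ToyMpModel would be placed on cuda:0 and cuda:1
    # The modulo keeps every device index in range, so the model-parallel demo
    # can run on a 2-GPU machine (with ranks sharing devices) instead of
    # requiring the 8 GPUs assumed by the old hard-coded values.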
