File tree Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -265,8 +265,8 @@ either the application or the model ``forward()`` method.
265
265
setup(rank, world_size)
266
266
267
267
# setup mp_model and devices for this process
268
- dev0 = rank * 2
269
- dev1 = rank * 2 + 1
268
+ dev0 = ( rank * 2 ) % world_size
269
+ dev1 = ( rank * 2 + 1 ) % world_size
270
270
mp_model = ToyMpModel(dev0, dev1)
271
271
ddp_mp_model = DDP(mp_model)
272
272
@@ -285,9 +285,8 @@ either the application or the model ``forward()`` method.
285
285
286
286
if __name__ == " __main__" :
287
287
n_gpus = torch.cuda.device_count()
288
- if n_gpus < 8 :
289
- print (f " Requires at least 8 GPUs to run, but got { n_gpus} . " )
290
- else :
291
- run_demo(demo_basic, 8 )
292
- run_demo(demo_checkpoint, 8 )
293
- run_demo(demo_model_parallel, 4 )
288
+ assert n_gpus >= 2 , f " Requires at least 2 GPUs to run, but got { n_gpus} "
289
+ world_size = n_gpus
290
+ run_demo(demo_basic, world_size)
291
+ run_demo(demo_checkpoint, world_size)
292
+ run_demo(demo_model_parallel, world_size)
You can’t perform that action at this time.
0 commit comments