Skip to content

Commit c5985a8

Browse files
daquexiansoumith
authored andcommitted
Divide args.workers by ngpus_per_node (#485)
1 parent 44053c5 commit c5985a8

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

imagenet/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def main_worker(gpu, ngpus_per_node, args):
148148
# DistributedDataParallel, we need to divide the batch size
149149
# ourselves based on the total number of GPUs we have
150150
args.batch_size = int(args.batch_size / ngpus_per_node)
151+
args.workers = int(args.workers / ngpus_per_node)
151152
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
152153
else:
153154
model.cuda()

0 commit comments

Comments
 (0)