1 parent ee27b73 commit 39eb56f
train.py
@@ -1142,7 +1142,10 @@ def _backward(_loss):
 if args.distributed:
     # scale gradient btw distributed ranks, each one can have different batch size
-    global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1)  # SUM
+    global_batch_size = utils.reduce_tensor(
+        torch.tensor(batch_size, device=device, dtype=torch.float32),
+        1  # SUM
+    )
     dist_scale = args.world_size * batch_size / global_batch_size
 else:
     dist_scale = None
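
For context, below is a minimal sketch of how a reduce_tensor helper and the resulting dist_scale could work with torch.distributed. This is an illustration under assumptions, not the repository's actual utils code: the (tensor, divisor) signature is inferred from the call site above, and the float32 cast is presumably there to keep the reduction and the subsequent division in floating point.

import torch
import torch.distributed as dist

def reduce_tensor(tensor, n):
    # Assumed behavior: sum the tensor across all ranks, then divide by n.
    # With n == 1 this is a plain SUM, matching the "# SUM" comment above.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / n

def compute_dist_scale(batch_size, world_size, device):
    # Hypothetical helper mirroring the diff: casting to float32 avoids
    # doing the all-reduce and division in integer arithmetic.
    global_batch_size = reduce_tensor(
        torch.tensor(batch_size, device=device, dtype=torch.float32),
        1  # SUM
    )
    # Ranks with different local batch sizes get a proportional weight:
    # a rank holding more samples contributes more to the averaged gradient.
    return world_size * batch_size / global_batch_size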