
Commit 1897f86

In dist training, update loss running avg every step, only sync on log updates / final.
1 parent ae0737f

1 file changed — train.py: 12 additions (+), 6 deletions (−)

@@ -1042,8 +1042,7 @@ def _backward(_loss):
             loss = _forward()
             _backward(loss)
 
-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
+        losses_m.update(loss.item() * accum_steps, input.size(0))
         update_sample_count += input.size(0)
 
         if not need_update:
@@ -1068,16 +1067,18 @@ def _backward(_loss):
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
+            loss_avg, loss_now = losses_m.avg, losses_m.val
             if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+                # synchronize current step and avg loss, each process keeps its own running avg
+                loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item()
+                loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item()
                 update_sample_count *= args.world_size
 
             if utils.is_primary(args):
                 _logger.info(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
-                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
+                    f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
                     f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
                     f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
@@ -1106,7 +1107,12 @@ def _backward(_loss):
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('loss', losses_m.avg)])
+    loss_avg = losses_m.avg
+    if args.distributed:
+        # synchronize avg loss, each process keeps its own running avg
+        loss_avg = torch.tensor([loss_avg], device=device, dtype=torch.float32)
+        loss_avg = utils.reduce_tensor(loss_avg, args.world_size).item()
+    return OrderedDict([('loss', loss_avg)])
 
 
 def validate(
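
The pattern in this diff can be sketched in isolation: every rank updates its own running loss average at each step, and an all_reduce is issued only when a value is actually logged or returned. The sketch below is illustrative, not the repository's code; reduce_tensor is an assumed stand-in for utils.reduce_tensor (all_reduce sum followed by division by world_size), AverageMeter is stripped down to the val/avg fields the log line uses, and run_epoch, batches, and log_interval are hypothetical names.

import torch
import torch.distributed as dist


def reduce_tensor(tensor, world_size):
    # Assumed behaviour of utils.reduce_tensor: sum across ranks, then average.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt


class AverageMeter:
    # Stripped-down stand-in for timm's AverageMeter.
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def run_epoch(batches, distributed=False, world_size=1, log_interval=50):
    # batches yields (loss_tensor, batch_size) pairs; losses come from a real step in practice.
    losses_m = AverageMeter()
    for step, (loss, batch_size) in enumerate(batches):
        # Every rank updates its own running average every step; no collective here.
        losses_m.update(loss.item(), batch_size)

        if step % log_interval == 0:
            loss_avg, loss_now = losses_m.avg, losses_m.val
            if distributed:
                # Sync only when a value is actually reported: one all_reduce per logged scalar.
                loss_avg = reduce_tensor(loss.new([loss_avg]), world_size).item()
                loss_now = reduce_tensor(loss.new([loss_now]), world_size).item()
            print(f'Loss: {loss_now:#.3g} ({loss_avg:#.3g})')

    # One final sync so the returned epoch average reflects all ranks.
    loss_avg = losses_m.avg
    if distributed:
        loss_avg = reduce_tensor(torch.tensor([loss_avg]), world_size).item()
    return loss_avg

Compared with the old path, which all_reduce'd loss.data on every logged update and fed the reduced value into the meter only on log steps, the meter stays purely local here and synchronization is limited to the two logged scalars plus one final reduce for the epoch average.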
