@@ -1042,8 +1042,7 @@ def _backward(_loss):
             loss = _forward()
             _backward(loss)
 
-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
+        losses_m.update(loss.item() * accum_steps, input.size(0))
         update_sample_count += input.size(0)
 
         if not need_update:
@@ -1068,16 +1067,18 @@ def _backward(_loss):
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
+            loss_avg, loss_now = losses_m.avg, losses_m.val
             if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+                # synchronize current step and avg loss, each process keeps its own running avg
+                loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item()
+                loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item()
                 update_sample_count *= args.world_size
 
             if utils.is_primary(args):
                 _logger.info(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
-                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
+                    f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
                     f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
                     f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
@@ -1106,7 +1107,12 @@ def _backward(_loss):
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('loss', losses_m.avg)])
+    loss_avg = losses_m.avg
+    if args.distributed:
+        # synchronize avg loss, each process keeps its own running avg
+        loss_avg = torch.tensor([loss_avg], device=device, dtype=torch.float32)
+        loss_avg = utils.reduce_tensor(loss_avg, args.world_size).item()
+    return OrderedDict([('loss', loss_avg)])
 
 
 def validate(
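
For reference, a minimal sketch of the all-reduce averaging that `utils.reduce_tensor` is assumed to perform in the new code paths above (the real helper lives in `timm.utils`; this is an illustration, not the PR's code): sum a scalar tensor across ranks, then divide by the world size so every process reads back the same mean before logging or returning it.

import torch
import torch.distributed as dist

def reduce_tensor_sketch(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    # Sketch only; assumes the process group is already initialized.
    # Clone so the caller's tensor is not modified in place by all_reduce.
    rt = tensor.clone()
    # Elementwise sum of `rt` across all processes in the default group.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    # Turn the summed value into a mean over ranks.
    rt /= world_size
    return rt

# Usage mirroring the diff: wrap the Python float in a tensor on the loss's
# device, reduce it, and read it back with .item() for logging.
# loss_now = reduce_tensor_sketch(loss.new([loss_now]), args.world_size).item()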