Skip to content

Commit 77e3922

Browse files
committed
Improve the parsable results dump at end of train, stop excessive output, only display top-10.
1 parent 1897f86 commit 77e3922

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

train.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -930,20 +930,29 @@ def main():
930930
# step LR for next epoch
931931
lr_scheduler.step(epoch + 1, latest_metric)
932932

933-
results.append({
933+
latest_results = {
934934
'epoch': epoch,
935935
'train': train_metrics,
936-
'validation': eval_metrics,
937-
})
936+
}
937+
if eval_metrics is not None:
938+
latest_results['validation'] = eval_metrics
939+
results.append(latest_results)
938940

939941
except KeyboardInterrupt:
940942
pass
941943

942-
results = {'all': results}
943944
if best_metric is not None:
944-
results['best'] = results['all'][best_epoch - start_epoch]
945+
# log best metric as tracked by checkpoint saver
945946
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
946-
print(f'--result\n{json.dumps(results, indent=4)}')
947+
948+
if utils.is_primary(args):
949+
# for parsable results display, dump top-10 summaries to avoid excess console spam
950+
display_results = sorted(
951+
results,
952+
key=lambda x: x.get('validation', x.get('train')).get(eval_metric, 0),
953+
reverse=decreasing_metric,
954+
)
955+
print(f'--result\n{json.dumps(display_results[-10:], indent=4)}')
947956

948957

949958
def train_one_epoch(

0 commit comments

Comments
 (0)