Skip to content

Commit 3634b55

Browse files
lminerWindQAQ
authored andcommitted
tqdm status bar custom metrics formatting (#916)
This is a simple change that will allow custom formatting of metrics for the tqdm status bar. This is especially useful if you have small loss values and want to see it in scientific notation. Replace `metrics_format` with `'{name}: {value:e}'`
1 parent 9bdf0c2 commit 3634b55

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

tensorflow_addons/callbacks/tqdm_progress_bar.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class TQDMProgressBar(Callback):
3737
update_per_second (int): Maximum number of updates in the epochs bar
3838
per second, this is to prevent small batches from slowing down
3939
training. Defaults to 10.
40+
metrics_format (string): Custom format for how metrics are formatted.
41+
See https://github.com/tqdm/tqdm#parameters for more detail.
4042
leave_epoch_progress (bool): True to leave epoch progress bars
4143
leave_overall_progress (bool): True to leave overall progress bar
4244
show_epoch_progress (bool): False to hide epoch progress bars
@@ -49,6 +51,7 @@ def __init__(self,
4951
'{remaining}s, {rate_fmt}{postfix}',
5052
epoch_bar_format='{n_fmt}/{total_fmt}{bar} ETA: '
5153
'{remaining}s - {desc}',
54+
metrics_format='{name}: {value:0.4f}',
5255
update_per_second=10,
5356
leave_epoch_progress=True,
5457
leave_overall_progress=True,
@@ -75,6 +78,7 @@ def __init__(self,
7578
self.leave_overall_progress = leave_overall_progress
7679
self.show_epoch_progress = show_epoch_progress
7780
self.show_overall_progress = show_overall_progress
81+
self.metrics_format = metrics_format
7882

7983
# compute update interval (inverse of update per second)
8084
self.update_interval = 1 / update_per_second
@@ -196,7 +200,7 @@ def format_metrics(self, logs={}, factor=1):
196200
for metric in self.metrics:
197201
if metric in logs:
198202
value = logs[metric] / factor
199-
pair = '{name}: {value:0.4f}'.format(name=metric, value=value)
203+
pair = self.metrics_format.format(name=metric, value=value)
200204
metric_value_pairs.append(pair)
201205
metrics_string = self.metrics_separator.join(metric_value_pairs)
202206
return metrics_string

0 commit comments

Comments
 (0)