Skip to content

Commit 9cb8ab5

Browse files
authored
data: add tensor support to multiplexer provider (#2980)
Summary: This commit implements the new `list_tensors` and `read_tensors` methods for the data provider implementation backed by the event multiplexer.

Test Plan: Unit tests included.

wchargin-branch: data-tensors-mux
1 parent 1f245b2 commit 9cb8ab5

File tree

3 files changed

+146
-24
lines changed

3 files changed

+146
-24
lines changed

tensorboard/backend/event_processing/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ py_test(
235235
deps = [
236236
":event_accumulator",
237237
":event_multiplexer",
238+
"//tensorboard:expect_numpy_installed",
238239
"//tensorboard:expect_tensorflow_installed",
239240
],
240241
)

tensorboard/backend/event_processing/data_provider.py

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _get_first_event_timestamp(self, run_name):
5757
return None
5858

5959
def data_location(self, experiment_id):
    """Return the multiplexer's log directory as a string.

    The experiment ID is not consulted: this provider always serves the
    single logdir it was constructed with.
    """
    del experiment_id  # unused; provider is scoped to one logdir
    return str(self._logdir)
6262

6363
def list_runs(self, experiment_id):
@@ -72,8 +72,69 @@ def list_runs(self, experiment_id):
7272
]
7373

7474
def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
    """List scalar time series for a plugin, keyed by run and then tag.

    Returns:
      A nested dict of `provider.ScalarTimeSeries`, as built by `_list`.
    """
    # `experiment_id` is unused: the multiplexer serves a single logdir.
    mapping = self._multiplexer.PluginRunToTagToContent(plugin_name)
    return self._list(provider.ScalarTimeSeries, mapping, run_tag_filter)
79+
80+
def read_scalars(
    self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
    """Read scalar data for every time series matched by `run_tag_filter`."""
    # TODO(@wchargin): Downsampling not implemented, as the multiplexer
    # is already downsampled. We could downsample on top of the existing
    # sampling, which would be nice for testing.
    del downsample  # ignored for now

    def to_scalar_datum(event):
        # Scalars are stored as rank-0 tensors; `.item()` unwraps the value.
        value = tensor_util.make_ndarray(event.tensor_proto).item()
        return provider.ScalarDatum(
            step=event.step, wall_time=event.wall_time, value=value
        )

    index = self.list_scalars(
        experiment_id, plugin_name, run_tag_filter=run_tag_filter
    )
    return self._read(to_scalar_datum, index)
99+
100+
def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
    """List tensor time series for a plugin, keyed by run and then tag.

    Returns:
      A nested dict of `provider.TensorTimeSeries`, as built by `_list`.
    """
    # `experiment_id` is unused: the multiplexer serves a single logdir.
    mapping = self._multiplexer.PluginRunToTagToContent(plugin_name)
    return self._list(provider.TensorTimeSeries, mapping, run_tag_filter)
105+
106+
def read_tensors(
    self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
    """Read tensor data for every time series matched by `run_tag_filter`."""
    # TODO(@wchargin): Downsampling not implemented, as the multiplexer
    # is already downsampled. We could downsample on top of the existing
    # sampling, which would be nice for testing.
    del downsample  # ignored for now

    def to_tensor_datum(event):
        return provider.TensorDatum(
            step=event.step,
            wall_time=event.wall_time,
            numpy=tensor_util.make_ndarray(event.tensor_proto),
        )

    index = self.list_tensors(
        experiment_id, plugin_name, run_tag_filter=run_tag_filter
    )
    return self._read(to_tensor_datum, index)
125+
126+
def _list(self, construct_time_series, run_tag_content, run_tag_filter):
127+
"""Helper to list scalar or tensor time series.
128+
129+
Args:
130+
construct_time_series: `ScalarTimeSeries` or `TensorTimeSeries`.
131+
run_tag_content: Result of `_multiplexer.PluginRunToTagToContent(...)`.
132+
run_tag_filter: As given by the client; may be `None`.
133+
134+
Returns:
135+
A list of objects of type given by `construct_time_series`,
136+
suitable to be returned from `list_scalars` or `list_tensors`.
137+
"""
77138
result = {}
78139
if run_tag_filter is None:
79140
run_tag_filter = provider.RunTagFilter(runs=None, tags=None)
@@ -91,7 +152,7 @@ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
91152
if max_wall_time is None or max_wall_time < event.wall_time:
92153
max_wall_time = event.wall_time
93154
summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
94-
result_for_run[tag] = provider.ScalarTimeSeries(
155+
result_for_run[tag] = construct_time_series(
95156
max_step=max_step,
96157
max_wall_time=max_wall_time,
97158
plugin_content=summary_metadata.plugin_data.content,
@@ -100,28 +161,23 @@ def list_scalars(self, experiment_id, plugin_name, run_tag_filter=None):
100161
)
101162
return result
102163

103-
def read_scalars(
104-
self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
105-
):
106-
# TODO(@wchargin): Downsampling not implemented, as the multiplexer
107-
# is already downsampled. We could downsample on top of the existing
108-
# sampling, which would be nice for testing.
109-
del downsample # ignored for now
110-
index = self.list_scalars(
111-
experiment_id, plugin_name, run_tag_filter=run_tag_filter
112-
)
164+
def _read(self, convert_event, index):
    """Helper to read scalar or tensor data from the multiplexer.

    Args:
      convert_event: Takes `plugin_event_accumulator.TensorEvent` to
        either `provider.ScalarDatum` or `provider.TensorDatum`.
      index: The result of `list_scalars` or `list_tensors`.

    Returns:
      A dict of dicts of values returned by `convert_event` calls,
      suitable to be returned from `read_scalars` or `read_tensors`.
    """
    # Mirror the run/tag structure of `index`, converting each event.
    # The per-tag metadata in `index` is not needed here; only the keys
    # drive the multiplexer lookups.
    return {
        run: {
            tag: [
                convert_event(event)
                for event in self._multiplexer.Tensors(run, tag)
            ]
            for (tag, _metadata) in six.iteritems(tags_for_run)
        }
        for (run, tags_for_run) in six.iteritems(index)
    }
121-
122-
def _convert_scalar_event(self, event):
123-
return provider.ScalarDatum(
124-
step=event.step,
125-
wall_time=event.wall_time,
126-
value=tensor_util.make_ndarray(event.tensor_proto).item(),
127-
)

tensorboard/backend/event_processing/data_provider_test.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import six
2424
from six.moves import xrange # pylint: disable=redefined-builtin
25+
import numpy as np
2526

2627
from tensorboard.backend.event_processing import data_provider
2728
from tensorboard.backend.event_processing import (
@@ -64,9 +65,15 @@ def setUp(self):
6465

6566
logdir = os.path.join(self.logdir, "pictures")
6667
with tf.summary.create_file_writer(logdir).as_default():
67-
purple = tf.constant([[[255, 0, 255]]], dtype=tf.uint8)
68-
for i in xrange(1, 11):
69-
image_summary.image("purple", [tf.tile(purple, [i, i, 1])], step=i)
68+
colors = [
69+
("`#F0F`", (255, 0, 255), "purple"),
70+
("`#0F0`", (255, 0, 255), "green"),
71+
]
72+
for (description, rgb, name) in colors:
73+
pixel = tf.constant([[list(rgb)]], dtype=tf.uint8)
74+
for i in xrange(1, 11):
75+
pixels = [tf.tile(pixel, [i, i, 1])]
76+
image_summary.image(name, pixels, step=i, description=description)
7077

7178
def create_multiplexer(self):
7279
multiplexer = event_multiplexer.EventMultiplexer()
@@ -211,6 +218,64 @@ def test_read_scalars_but_not_rank_0(self):
211218
run_tag_filter=run_tag_filter,
212219
)
213220

221+
def test_list_tensors_all(self):
    """Without a filter, every image run/tag should be listed."""
    provider = self.create_provider()
    result = provider.list_tensors(
        experiment_id="unused",
        plugin_name=image_metadata.PLUGIN_NAME,
        run_tag_filter=None,
    )
    self.assertItemsEqual(list(result), ["pictures"])
    self.assertItemsEqual(list(result["pictures"]), ["purple", "green"])
    series = result["pictures"]["purple"]
    self.assertIsInstance(series, base_provider.TensorTimeSeries)
    self.assertEqual(series.max_step, 10)
    # Wall time is real clock time, so there is nothing stable to assert.
    self.assertEqual(series.plugin_content, b"")
    # V2 summary ops do not write a display name.
    self.assertEqual(series.display_name, "")
    self.assertEqual(series.description, "`#F0F`")
237+
238+
def test_list_tensors_filters(self):
    """Spot-check filtering; scalars and tensors share the implementation."""
    provider = self.create_provider()
    only_green = base_provider.RunTagFilter(["pictures"], ["green"])
    result = provider.list_tensors(
        experiment_id="unused",
        plugin_name=image_metadata.PLUGIN_NAME,
        run_tag_filter=only_green,
    )
    self.assertItemsEqual(list(result), ["pictures"])
    self.assertItemsEqual(list(result["pictures"]), ["green"])
250+
251+
def test_read_tensors(self):
    """Read back image tensors and compare them against the multiplexer."""
    multiplexer = self.create_multiplexer()
    provider = data_provider.MultiplexerDataProvider(multiplexer, self.logdir)

    result = provider.read_tensors(
        experiment_id="unused",
        plugin_name=image_metadata.PLUGIN_NAME,
        run_tag_filter=base_provider.RunTagFilter(
            runs=["pictures"],
            tags=["purple", "green"],
        ),
        downsample=None,  # not yet implemented
    )

    self.assertItemsEqual(list(result), ["pictures"])
    self.assertItemsEqual(list(result["pictures"]), ["purple", "green"])
    for run in result:
        for tag in result[run]:
            expected_events = multiplexer.Tensors(run, tag)
            actual_data = result[run][tag]
            self.assertLen(actual_data, len(expected_events))
            for (datum, event) in zip(actual_data, expected_events):
                self.assertEqual(datum.step, event.step)
                self.assertEqual(datum.wall_time, event.wall_time)
                np.testing.assert_equal(
                    datum.numpy, tensor_util.make_ndarray(event.tensor_proto)
                )
278+
214279

215280
if __name__ == "__main__":
216281
tf.test.main()

0 commit comments

Comments (0)