Skip to content

Commit 1f245b2

Browse files
authored
data: add tensor support to data provider interface (#2979)
Summary: This commit specifies the `list_tensors` and `read_tensors` methods on the data provider interface. These methods are optional for now (i.e., not decorated with `abc.abstractmethod`) for compatibility reasons, but we’ll make them required soon. Test Plan: Unit tests included. wchargin-branch: data-tensors-interface
1 parent 374d202 commit 1f245b2

File tree

3 files changed

+222
-4
lines changed

3 files changed

+222
-4
lines changed

tensorboard/data/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ py_library(
1212
srcs = ["provider.py"],
1313
srcs_version = "PY2AND3",
1414
deps = [
15+
"//tensorboard:expect_numpy_installed",
1516
"@org_pythonhosted_six",
1617
],
1718
)
@@ -24,6 +25,7 @@ py_test(
2425
tags = ["support_notf"],
2526
deps = [
2627
":provider",
28+
"//tensorboard:expect_numpy_installed",
2729
"//tensorboard:test",
2830
"@org_pythonhosted_six",
2931
],

tensorboard/data/provider.py

Lines changed: 153 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import collections
2323

2424
import six
25+
import numpy as np
2526

2627

2728
@six.add_metaclass(abc.ABCMeta)
@@ -120,12 +121,56 @@ def read_scalars(
120121
"""
121122
pass
122123

123-
def list_tensors(self):
124-
"""Not yet specified."""
124+
def list_tensors(self, experiment_id, plugin_name, run_tag_filter=None):
    """Enumerate metadata for the tensor time series of one plugin.

    Only run-tag combinations that actually exist appear as keys in the
    result, so entries named in `run_tag_filter` may be missing from it.

    This base implementation is a stub: it performs no work and returns
    `None`.

    Args:
        experiment_id: ID of enclosing experiment.
        plugin_name: String name of the TensorBoard plugin that created
          the data to be queried. Required.
        run_tag_filter: Optional `RunTagFilter` value. When omitted, all
          runs and tags are included.

    Returns:
        A nested map `d` such that `d[run][tag]` is a `TensorTimeSeries`
        value.

    Raises:
        tensorboard.errors.PublicError: See `DataProvider` class
          docstring.
    """
    pass
126146

127-
def read_tensors(self):
128-
"""Not yet specified."""
147+
def read_tensors(
    self, experiment_id, plugin_name, downsample=None, run_tag_filter=None
):
    """Fetch the values of the tensor time series of one plugin.

    Only run-tag combinations that actually exist appear as keys in the
    result, so entries named in `run_tag_filter` may be missing from it.

    This base implementation is a stub: it performs no work and returns
    `None`.

    Args:
        experiment_id: ID of enclosing experiment.
        plugin_name: String name of the TensorBoard plugin that created
          the data to be queried. Required.
        downsample: Integer number of steps to which to downsample the
          results (e.g., `1000`). Required, despite the `None` default
          in the signature.
        run_tag_filter: Optional `RunTagFilter` value. When provided, a
          time series is included only if both its run and its tag pass
          the filter. When `None`, all time series are included.

    Returns:
        A nested map `d` such that `d[run][tag]` is a list of
        `TensorDatum` values sorted by step.

    Raises:
        tensorboard.errors.PublicError: See `DataProvider` class
          docstring.
    """
    pass
130175

131176
def list_blob_sequences(
@@ -392,6 +437,110 @@ def __repr__(self):
392437
))
393438

394439

440+
class TensorTimeSeries(_TimeSeries):
    """Metadata about a tensor time series for a particular run and tag.

    This class defines no `__init__` of its own; the underscore-prefixed
    fields read below are expected to be populated by `_TimeSeries`.

    Attributes:
        max_step: The largest step value of any datum in this tensor time
          series; a nonnegative integer.
        max_wall_time: The largest wall time of any datum in this time
          series, as `float` seconds since epoch.
        plugin_content: A bytestring of arbitrary plugin-specific metadata
          for this time series, as provided to `tf.summary.write` in the
          `plugin_data.content` field of the `metadata` argument.
        description: An optional long-form Markdown description, as a
          `str` that is empty if no description was specified.
        display_name: An optional display name, as a `str` that is empty
          if no display name was specified. Deprecated; may be removed
          soon.
    """

    def __eq__(self, other):
        """Return True iff `other` is a `TensorTimeSeries` with equal fields."""
        if not isinstance(other, TensorTimeSeries):
            return False
        return (
            self._max_step == other._max_step
            and self._max_wall_time == other._max_wall_time
            and self._plugin_content == other._plugin_content
            and self._description == other._description
            and self._display_name == other._display_name
        )

    def __ne__(self, other):
        """Return the negation of `__eq__`.

        Python 2 does not derive `!=` from `__eq__`, so without this
        method two equal values would still compare as unequal there.
        """
        return not self == other

    def __hash__(self):
        """Hash over exactly the fields that `__eq__` compares."""
        return hash((
            self._max_step,
            self._max_wall_time,
            self._plugin_content,
            self._description,
            self._display_name,
        ))

    def __repr__(self):
        """Return a `TensorTimeSeries(name=value, ...)` debugging string."""
        fields = (
            ("max_step", self._max_step),
            ("max_wall_time", self._max_wall_time),
            ("plugin_content", self._plugin_content),
            ("description", self._description),
            ("display_name", self._display_name),
        )
        return "TensorTimeSeries(%s)" % ", ".join(
            "%s=%r" % field for field in fields
        )
489+
490+
491+
class TensorDatum(object):
    """A single datum in a tensor time series for a run and tag.

    Instances are unhashable because the wrapped numpy array is mutable.

    Attributes:
        step: The global step at which this datum occurred; an integer.
          This is a unique key among data of this time series.
        wall_time: The real-world time at which this datum occurred, as
          `float` seconds since epoch.
        numpy: The `numpy.ndarray` value with the tensor contents of this
          datum.
    """

    __slots__ = ("_step", "_wall_time", "_numpy")

    def __init__(self, step, wall_time, numpy):
        """Store the given step, wall time, and array without copying."""
        self._step = step
        self._wall_time = wall_time
        self._numpy = numpy

    @property
    def step(self):
        return self._step

    @property
    def wall_time(self):
        return self._wall_time

    @property
    def numpy(self):
        return self._numpy

    def __eq__(self, other):
        """Return True iff `other` is a `TensorDatum` with equal contents.

        Arrays are compared with `np.array_equal`, i.e. same shape and
        elementwise-equal values.
        """
        if not isinstance(other, TensorDatum):
            return False
        return (
            self._step == other._step
            and self._wall_time == other._wall_time
            and bool(np.array_equal(self._numpy, other._numpy))
        )

    def __ne__(self, other):
        """Return the negation of `__eq__`.

        Python 2 does not derive `!=` from `__eq__`, so without this
        method two equal data would still compare as unequal there.
        """
        return not self == other

    # Unhashable type: numpy arrays are mutable.
    __hash__ = None

    def __repr__(self):
        """Return a `TensorDatum(name=value, ...)` debugging string."""
        fields = (
            ("step", self._step),
            ("wall_time", self._wall_time),
            ("numpy", self._numpy),
        )
        return "TensorDatum(%s)" % ", ".join(
            "%s=%r" % field for field in fields
        )
542+
543+
395544
class BlobSequenceTimeSeries(_TimeSeries):
396545
"""Metadata about a blob sequence time series for a particular run and tag.
397546

tensorboard/data/provider_test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import numpy as np
2122
import six
2223

2324
from tensorboard import test as tb_test
@@ -109,6 +110,72 @@ def test_hash(self):
109110
self.assertNotEqual(hash(x1), hash(x3))
110111

111112

113+
class TensorTimeSeriesTest(tb_test.TestCase):
    """Checks `repr`, equality, and hashing of `provider.TensorTimeSeries`."""

    def test_repr(self):
        series = provider.TensorTimeSeries(
            max_step=77,
            max_wall_time=1234.5,
            plugin_content=b"AB\xCD\xEF!\x00",
            description="test test",
            display_name="one two",
        )
        rendered = repr(series)
        # Every field's own repr must show up somewhere in the output.
        field_values = (
            series.max_step,
            series.max_wall_time,
            series.plugin_content,
            series.description,
            series.display_name,
        )
        for value in field_values:
            self.assertIn(repr(value), rendered)

    def test_eq(self):
        first = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
        same = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
        different = provider.TensorTimeSeries(
            66, 4321.0, b"\x7F", "hmm", "hum"
        )
        self.assertEqual(first, same)
        self.assertNotEqual(first, different)
        self.assertNotEqual(first, object())

    def test_hash(self):
        first = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
        same = provider.TensorTimeSeries(77, 1234.5, b"\x12", "one", "two")
        different = provider.TensorTimeSeries(
            66, 4321.0, b"\x7F", "hmm", "hum"
        )
        self.assertEqual(hash(first), hash(same))
        # The `__hash__` contract does not strictly require distinct
        # values to hash differently, but a collision here would be
        # suspicious enough to deserve a closer look.
        self.assertNotEqual(hash(first), hash(different))
146+
147+
148+
class TensorDatumTest(tb_test.TestCase):
    """Checks `repr`, equality, and unhashability of `provider.TensorDatum`."""

    def test_repr(self):
        x = provider.TensorDatum(
            step=123, wall_time=234.5, numpy=np.array(-0.25)
        )
        repr_ = repr(x)
        self.assertIn(repr(x.step), repr_)
        self.assertIn(repr(x.wall_time), repr_)
        self.assertIn(repr(x.numpy), repr_)

    def test_eq(self):
        nd = np.array
        x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0]))
        x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.0]))
        x3 = provider.TensorDatum(
            step=23, wall_time=3.25, numpy=nd([-0.5, -2.5])
        )
        # Same step and wall time as `x1`; only the tensor contents
        # differ, so this isolates the array comparison in `__eq__`.
        x4 = provider.TensorDatum(step=12, wall_time=0.25, numpy=nd([1.0, 2.5]))
        self.assertEqual(x1, x2)
        self.assertNotEqual(x1, x3)
        self.assertNotEqual(x1, x4)
        self.assertNotEqual(x1, object())

    def test_eq_with_rank0_tensor(self):
        # `np.array(1.25)` (no brackets) is a genuine rank-0 array;
        # `np.array([1.25])` would be rank 1 and miss the scalar case.
        x1 = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array(1.25))
        x2 = provider.TensorDatum(step=12, wall_time=0.25, numpy=np.array(1.25))
        x3 = provider.TensorDatum(step=23, wall_time=3.25, numpy=np.array(1.25))
        self.assertEqual(x1.numpy.shape, ())
        self.assertEqual(x1, x2)
        self.assertNotEqual(x1, x3)
        self.assertNotEqual(x1, object())

    def test_hash(self):
        x = provider.TensorDatum(
            step=12, wall_time=0.25, numpy=np.array([1.25])
        )
        with six.assertRaisesRegex(self, TypeError, "unhashable type"):
            hash(x)
177+
178+
112179
class BlobSequenceTimeSeriesTest(tb_test.TestCase):
113180

114181
def test_repr(self):

0 commit comments

Comments
 (0)