Skip to content

Commit 7b6b5df

Browse files
authored
data: add histograms plugin integration (#2981)
Summary: This patch teaches the histograms plugin about the generic data APIs and their newfound tensor support, guarded behind the `--generic_data=true` feature flag. The distributions plugin tags along for free. Testing: After this change, setting `--generic_data` to either `false` or `true` does not change the visual appearance of the distributions or histograms dashboards nor the outputs of either the `/data/tags` routes or the `/data/histograms` or `/data/distributions` routes. (Replacing `e.numpy` with `(e.numpy + 100)` in `histograms_plugin.py` does cause these checks to fail, so everything is wired up properly.) Unit tests updated for the histograms plugin following the form of those in the scalars plugin. Tests for the distributions plugin tests are left unchanged, as it never actually touches its context or multiplexer. wchargin-branch: generic-histograms
1 parent a69d76c commit 7b6b5df

File tree

6 files changed

+193
-66
lines changed

6 files changed

+193
-66
lines changed

tensorboard/plugins/distribution/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ py_library(
1515
visibility = ["//visibility:public"],
1616
deps = [
1717
":compressor",
18+
"//tensorboard:plugin_util",
1819
"//tensorboard/backend:http_util",
1920
"//tensorboard/plugins:base_plugin",
2021
"//tensorboard/plugins/histogram:histograms_plugin",

tensorboard/plugins/distribution/distributions_plugin.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from werkzeug import wrappers
2626

27+
from tensorboard import plugin_util
2728
from tensorboard.backend import http_util
2829
from tensorboard.plugins import base_plugin
2930
from tensorboard.plugins.distribution import compressor
@@ -52,7 +53,6 @@ def __init__(self, context):
5253
context: A base_plugin.TBContext instance.
5354
"""
5455
self._histograms_plugin = histograms_plugin.HistogramsPlugin(context)
55-
self._multiplexer = context.multiplexer
5656

5757
def get_plugin_apps(self):
5858
return {
@@ -73,14 +73,14 @@ def frontend_metadata(self):
7373
element_name='tf-distribution-dashboard',
7474
)
7575

76-
def distributions_impl(self, tag, run):
76+
def distributions_impl(self, tag, run, experiment):
7777
"""Result of the form `(body, mime_type)`.
7878
7979
Raises:
8080
tensorboard.errors.PublicError: On invalid request.
8181
"""
8282
(histograms, mime_type) = self._histograms_plugin.histograms_impl(
83-
tag, run, downsample_to=self.SAMPLE_SIZE)
83+
tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE)
8484
return ([self._compress(histogram) for histogram in histograms],
8585
mime_type)
8686

@@ -89,18 +89,20 @@ def _compress(self, histogram):
8989
converted_buckets = compressor.compress_histogram(buckets)
9090
return [wall_time, step, converted_buckets]
9191

92-
def index_impl(self):
93-
return self._histograms_plugin.index_impl()
92+
def index_impl(self, experiment):
93+
return self._histograms_plugin.index_impl(experiment=experiment)
9494

9595
@wrappers.Request.application
9696
def tags_route(self, request):
97-
index = self.index_impl()
97+
experiment = plugin_util.experiment_id(request.environ)
98+
index = self.index_impl(experiment=experiment)
9899
return http_util.Respond(request, index, 'application/json')
99100

100101
@wrappers.Request.application
101102
def distributions_route(self, request):
102103
"""Given a tag and single run, return an array of compressed histograms."""
104+
experiment = plugin_util.experiment_id(request.environ)
103105
tag = request.args.get('tag')
104106
run = request.args.get('run')
105-
(body, mime_type) = self.distributions_impl(tag, run)
107+
(body, mime_type) = self.distributions_impl(tag, run, experiment=experiment)
106108
return http_util.Respond(request, body, mime_type)

tensorboard/plugins/distribution/distributions_plugin_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,16 @@ def test_index(self):
121121
'description': self._HTML_DESCRIPTION,
122122
},
123123
},
124-
}, self.plugin.index_impl())
124+
}, self.plugin.index_impl(experiment='exp'))
125125

126126
def _test_distributions(self, run_name, tag_name, should_work=True):
127127
self.set_up_with_runs([self._RUN_WITH_SCALARS,
128128
self._RUN_WITH_LEGACY_DISTRIBUTION,
129129
self._RUN_WITH_DISTRIBUTION])
130130
if should_work:
131-
(data, mime_type) = self.plugin.distributions_impl(tag_name, run_name)
131+
(data, mime_type) = self.plugin.distributions_impl(
132+
tag_name, run_name, experiment='exp'
133+
)
132134
self.assertEqual('application/json', mime_type)
133135
self.assertEqual(len(data), self._STEPS)
134136
for i in xrange(self._STEPS):
@@ -138,7 +140,9 @@ def _test_distributions(self, run_name, tag_name, should_work=True):
138140
self.assertEqual(bps, compressor.NORMAL_HISTOGRAM_BPS)
139141
else:
140142
with self.assertRaises(errors.NotFoundError):
141-
self.plugin.distributions_impl(self._DISTRIBUTION_TAG, run_name)
143+
self.plugin.distributions_impl(
144+
self._DISTRIBUTION_TAG, run_name, experiment='exp'
145+
)
142146

143147
def test_distributions_with_scalars(self):
144148
self._test_distributions(self._RUN_WITH_SCALARS, self._DISTRIBUTION_TAG,

tensorboard/plugins/histogram/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ py_library(
2121
"//tensorboard:expect_numpy_installed",
2222
"//tensorboard:plugin_util",
2323
"//tensorboard/backend:http_util",
24+
"//tensorboard/data:provider",
2425
"//tensorboard/plugins:base_plugin",
2526
"//tensorboard/util:tensor_util",
2627
"@org_pocoo_werkzeug",
@@ -42,6 +43,7 @@ py_test(
4243
"//tensorboard/backend:application",
4344
"//tensorboard/backend/event_processing:event_accumulator",
4445
"//tensorboard/backend/event_processing:event_multiplexer",
46+
"//tensorboard/data:provider",
4547
"//tensorboard/plugins:base_plugin",
4648
"//tensorboard/util:test_util",
4749
"@org_pocoo_werkzeug",

tensorboard/plugins/histogram/histograms_plugin.py

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tensorboard import plugin_util
3434
from tensorboard.backend import http_util
3535
from tensorboard.compat import tf
36+
from tensorboard.data import provider
3637
from tensorboard.plugins import base_plugin
3738
from tensorboard.plugins.histogram import metadata
3839
from tensorboard.util import tensor_util
@@ -59,8 +60,12 @@ def __init__(self, context):
5960
Args:
6061
context: A base_plugin.TBContext instance.
6162
"""
62-
self._db_connection_provider = context.db_connection_provider
6363
self._multiplexer = context.multiplexer
64+
self._db_connection_provider = context.db_connection_provider
65+
if context.flags and context.flags.generic_data == 'true':
66+
self._data_provider = context.data_provider
67+
else:
68+
self._data_provider = None
6469

6570
def get_plugin_apps(self):
6671
return {
@@ -70,6 +75,11 @@ def get_plugin_apps(self):
7075

7176
def is_active(self):
7277
"""This plugin is active iff any run has at least one histograms tag."""
78+
if self._data_provider:
79+
# We don't have an experiment ID, and modifying the backend core
80+
# to provide one would break backward compatibility. Hack for now.
81+
return True
82+
7383
if self._db_connection_provider:
7484
# The plugin is active if one relevant tag can be found in the database.
7585
db = self._db_connection_provider()
@@ -82,10 +92,28 @@ def is_active(self):
8292
''', (metadata.PLUGIN_NAME,))
8393
return bool(list(cursor))
8494

85-
return bool(self._multiplexer) and any(self.index_impl().values())
95+
if self._multiplexer:
96+
return any(self.index_impl(experiment='').values())
97+
98+
return False
8699

87-
def index_impl(self):
100+
def index_impl(self, experiment):
88101
"""Return {runName: {tagName: {displayName: ..., description: ...}}}."""
102+
if self._data_provider:
103+
mapping = self._data_provider.list_tensors(
104+
experiment_id=experiment,
105+
plugin_name=metadata.PLUGIN_NAME,
106+
)
107+
result = {run: {} for run in mapping}
108+
for (run, tag_to_content) in six.iteritems(mapping):
109+
for (tag, metadatum) in six.iteritems(tag_to_content):
110+
description = plugin_util.markdown_to_safe_html(metadatum.description)
111+
result[run][tag] = {
112+
'displayName': metadatum.display_name,
113+
'description': description,
114+
}
115+
return result
116+
89117
if self._db_connection_provider:
90118
# Read tags from the database.
91119
db = self._db_connection_provider()
@@ -128,7 +156,7 @@ def index_impl(self):
128156
def frontend_metadata(self):
129157
return base_plugin.FrontendMetadata(element_name='tf-histogram-dashboard')
130158

131-
def histograms_impl(self, tag, run, downsample_to=None):
159+
def histograms_impl(self, tag, run, experiment, downsample_to=None):
132160
"""Result of the form `(body, mime_type)`.
133161
134162
At most `downsample_to` events will be returned. If this value is
@@ -137,7 +165,33 @@ def histograms_impl(self, tag, run, downsample_to=None):
137165
Raises:
138166
tensorboard.errors.PublicError: On invalid request.
139167
"""
140-
if self._db_connection_provider:
168+
if self._data_provider:
169+
# Downsample reads to 500 histograms per time series, which is
170+
# the default size guidance for histograms under the multiplexer
171+
# loading logic.
172+
SAMPLE_COUNT = downsample_to if downsample_to is not None else 500
173+
all_histograms = self._data_provider.read_tensors(
174+
experiment_id=experiment,
175+
plugin_name=metadata.PLUGIN_NAME,
176+
downsample=SAMPLE_COUNT,
177+
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
178+
)
179+
histograms = all_histograms.get(run, {}).get(tag, None)
180+
if histograms is None:
181+
raise errors.NotFoundError(
182+
"No histogram tag %r for run %r" % (tag, run)
183+
)
184+
# Downsample again, even though the data provider is supposed to,
185+
# because the multiplexer provider currently doesn't. (For
186+
# well-behaved data providers, this is a no-op.)
187+
if downsample_to is not None:
188+
rng = random.Random(0)
189+
histograms = _downsample(rng, histograms, downsample_to)
190+
events = [
191+
(e.wall_time, e.step, e.numpy.tolist())
192+
for e in histograms
193+
]
194+
elif self._db_connection_provider:
141195
# Serve data from the database.
142196
db = self._db_connection_provider()
143197
cursor = db.cursor()
@@ -205,11 +259,9 @@ def histograms_impl(self, tag, run, downsample_to=None):
205259
raise errors.NotFoundError(
206260
'No histogram tag %r for run %r' % (tag, run)
207261
)
208-
if downsample_to is not None and len(tensor_events) > downsample_to:
209-
rand_indices = random.Random(0).sample(
210-
six.moves.xrange(len(tensor_events)), downsample_to)
211-
indices = sorted(rand_indices)
212-
tensor_events = [tensor_events[i] for i in indices]
262+
if downsample_to is not None:
263+
rng = random.Random(0)
264+
tensor_events = _downsample(rng, tensor_events, downsample_to)
213265
events = [[e.wall_time, e.step, tensor_util.make_ndarray(e.tensor_proto).tolist()]
214266
for e in tensor_events]
215267
return (events, 'application/json')
@@ -228,14 +280,42 @@ def _get_values(self, data_blob, dtype_enum, shape_string):
228280

229281
@wrappers.Request.application
230282
def tags_route(self, request):
231-
index = self.index_impl()
283+
experiment = plugin_util.experiment_id(request.environ)
284+
index = self.index_impl(experiment=experiment)
232285
return http_util.Respond(request, index, 'application/json')
233286

234287
@wrappers.Request.application
235288
def histograms_route(self, request):
236289
"""Given a tag and single run, return array of histogram values."""
290+
experiment = plugin_util.experiment_id(request.environ)
237291
tag = request.args.get('tag')
238292
run = request.args.get('run')
239293
(body, mime_type) = self.histograms_impl(
240-
tag, run, downsample_to=self.SAMPLE_SIZE)
294+
tag, run, experiment=experiment, downsample_to=self.SAMPLE_SIZE)
241295
return http_util.Respond(request, body, mime_type)
296+
297+
298+
def _downsample(rng, xs, k):
299+
"""Uniformly choose a maximal at-most-`k`-subsequence of `xs`.
300+
301+
If `k` is larger than `xs`, then the contents of `xs` itself will be
302+
returned.
303+
304+
This differs from `random.sample` in that it returns a subsequence
305+
(i.e., order is preserved) and that it permits `k > len(xs)`.
306+
307+
Args:
308+
rng: A `random` interface.
309+
xs: A sequence (`collections.abc.Sequence`).
310+
k: A non-negative integer.
311+
312+
Returns:
313+
A new list whose elements are a subsequence of `xs` of length
314+
`min(k, len(xs))`, uniformly selected among such subsequences.
315+
"""
316+
317+
if k > len(xs):
318+
return list(xs)
319+
indices = rng.sample(six.moves.xrange(len(xs)), k)
320+
indices.sort()
321+
return [xs[i] for i in indices]

0 commit comments

Comments
 (0)