33
33
from tensorboard import plugin_util
34
34
from tensorboard .backend import http_util
35
35
from tensorboard .compat import tf
36
+ from tensorboard .data import provider
36
37
from tensorboard .plugins import base_plugin
37
38
from tensorboard .plugins .histogram import metadata
38
39
from tensorboard .util import tensor_util
@@ -59,8 +60,12 @@ def __init__(self, context):
59
60
Args:
60
61
context: A base_plugin.TBContext instance.
61
62
"""
62
- self ._db_connection_provider = context .db_connection_provider
63
63
self ._multiplexer = context .multiplexer
64
+ self ._db_connection_provider = context .db_connection_provider
65
+ if context .flags and context .flags .generic_data == 'true' :
66
+ self ._data_provider = context .data_provider
67
+ else :
68
+ self ._data_provider = None
64
69
65
70
def get_plugin_apps (self ):
66
71
return {
@@ -70,6 +75,11 @@ def get_plugin_apps(self):
70
75
71
76
def is_active (self ):
72
77
"""This plugin is active iff any run has at least one histograms tag."""
78
+ if self ._data_provider :
79
+ # We don't have an experiment ID, and modifying the backend core
80
+ # to provide one would break backward compatibility. Hack for now.
81
+ return True
82
+
73
83
if self ._db_connection_provider :
74
84
# The plugin is active if one relevant tag can be found in the database.
75
85
db = self ._db_connection_provider ()
@@ -82,10 +92,28 @@ def is_active(self):
82
92
''' , (metadata .PLUGIN_NAME ,))
83
93
return bool (list (cursor ))
84
94
85
- return bool (self ._multiplexer ) and any (self .index_impl ().values ())
95
+ if self ._multiplexer :
96
+ return any (self .index_impl (experiment = '' ).values ())
97
+
98
+ return False
86
99
87
- def index_impl (self ):
100
+ def index_impl (self , experiment ):
88
101
"""Return {runName: {tagName: {displayName: ..., description: ...}}}."""
102
+ if self ._data_provider :
103
+ mapping = self ._data_provider .list_tensors (
104
+ experiment_id = experiment ,
105
+ plugin_name = metadata .PLUGIN_NAME ,
106
+ )
107
+ result = {run : {} for run in mapping }
108
+ for (run , tag_to_content ) in six .iteritems (mapping ):
109
+ for (tag , metadatum ) in six .iteritems (tag_to_content ):
110
+ description = plugin_util .markdown_to_safe_html (metadatum .description )
111
+ result [run ][tag ] = {
112
+ 'displayName' : metadatum .display_name ,
113
+ 'description' : description ,
114
+ }
115
+ return result
116
+
89
117
if self ._db_connection_provider :
90
118
# Read tags from the database.
91
119
db = self ._db_connection_provider ()
@@ -128,7 +156,7 @@ def index_impl(self):
128
156
def frontend_metadata(self):
    """Return the frontend metadata naming the `tf-histogram-dashboard` element."""
    element = 'tf-histogram-dashboard'
    return base_plugin.FrontendMetadata(element_name=element)
130
158
131
- def histograms_impl (self , tag , run , downsample_to = None ):
159
+ def histograms_impl (self , tag , run , experiment , downsample_to = None ):
132
160
"""Result of the form `(body, mime_type)`.
133
161
134
162
At most `downsample_to` events will be returned. If this value is
@@ -137,7 +165,33 @@ def histograms_impl(self, tag, run, downsample_to=None):
137
165
Raises:
138
166
tensorboard.errors.PublicError: On invalid request.
139
167
"""
140
- if self ._db_connection_provider :
168
+ if self ._data_provider :
169
+ # Downsample reads to 500 histograms per time series, which is
170
+ # the default size guidance for histograms under the multiplexer
171
+ # loading logic.
172
+ SAMPLE_COUNT = downsample_to if downsample_to is not None else 500
173
+ all_histograms = self ._data_provider .read_tensors (
174
+ experiment_id = experiment ,
175
+ plugin_name = metadata .PLUGIN_NAME ,
176
+ downsample = SAMPLE_COUNT ,
177
+ run_tag_filter = provider .RunTagFilter (runs = [run ], tags = [tag ]),
178
+ )
179
+ histograms = all_histograms .get (run , {}).get (tag , None )
180
+ if histograms is None :
181
+ raise errors .NotFoundError (
182
+ "No histogram tag %r for run %r" % (tag , run )
183
+ )
184
+ # Downsample again, even though the data provider is supposed to,
185
+ # because the multiplexer provider currently doesn't. (For
186
+ # well-behaved data providers, this is a no-op.)
187
+ if downsample_to is not None :
188
+ rng = random .Random (0 )
189
+ histograms = _downsample (rng , histograms , downsample_to )
190
+ events = [
191
+ (e .wall_time , e .step , e .numpy .tolist ())
192
+ for e in histograms
193
+ ]
194
+ elif self ._db_connection_provider :
141
195
# Serve data from the database.
142
196
db = self ._db_connection_provider ()
143
197
cursor = db .cursor ()
@@ -205,11 +259,9 @@ def histograms_impl(self, tag, run, downsample_to=None):
205
259
raise errors .NotFoundError (
206
260
'No histogram tag %r for run %r' % (tag , run )
207
261
)
208
- if downsample_to is not None and len (tensor_events ) > downsample_to :
209
- rand_indices = random .Random (0 ).sample (
210
- six .moves .xrange (len (tensor_events )), downsample_to )
211
- indices = sorted (rand_indices )
212
- tensor_events = [tensor_events [i ] for i in indices ]
262
+ if downsample_to is not None :
263
+ rng = random .Random (0 )
264
+ tensor_events = _downsample (rng , tensor_events , downsample_to )
213
265
events = [[e .wall_time , e .step , tensor_util .make_ndarray (e .tensor_proto ).tolist ()]
214
266
for e in tensor_events ]
215
267
return (events , 'application/json' )
@@ -228,14 +280,42 @@ def _get_values(self, data_blob, dtype_enum, shape_string):
228
280
229
281
@wrappers.Request.application
def tags_route(self, request):
    """Respond with the `index_impl` result for the request's experiment.

    The experiment ID is extracted from the WSGI environ of `request`,
    and the resulting index is sent back as `application/json`.
    """
    experiment_id = plugin_util.experiment_id(request.environ)
    return http_util.Respond(
        request,
        self.index_impl(experiment=experiment_id),
        'application/json',
    )
233
286
234
287
@wrappers.Request.application
def histograms_route(self, request):
    """Given a tag and single run, return array of histogram values.

    The `tag` and `run` are read from the request arguments and the
    experiment ID from the WSGI environ; the series is fetched through
    `histograms_impl` with `downsample_to=self.SAMPLE_SIZE`.
    """
    experiment_id = plugin_util.experiment_id(request.environ)
    query = request.args
    payload, content_type = self.histograms_impl(
        query.get('tag'),
        query.get('run'),
        experiment=experiment_id,
        downsample_to=self.SAMPLE_SIZE,
    )
    return http_util.Respond(request, payload, content_type)
296
+
297
+
298
def _downsample(rng, xs, k):
    """Uniformly choose a maximal at-most-`k`-subsequence of `xs`.

    If `k` is larger than `len(xs)`, then the contents of `xs` itself
    will be returned.

    This differs from `random.sample` in that it returns a subsequence
    (i.e., order is preserved) and that it permits `k > len(xs)`.

    Args:
      rng: A `random` interface.
      xs: A sequence (`collections.abc.Sequence`).
      k: A non-negative integer.

    Returns:
      A new list whose elements are a subsequence of `xs` of length
      `min(k, len(xs))`, uniformly selected among such subsequences.

    Raises:
      ValueError: If `k` is negative (as raised by
        `random.Random.sample`).
    """
    size = len(xs)
    if k > size:
        return list(xs)
    # Sample positions rather than elements: sorting the positions
    # restores the original order even if `xs` contains duplicates or
    # values that cannot be compared with each other.
    indices = sorted(rng.sample(range(size), k))
    return [xs[i] for i in indices]
0 commit comments