Skip to content

Commit 909bf85

Browse files
wcharginnfelt
authored andcommitted
uploader: require experiment metadata from server (#3310)
Summary: The `StreamExperiments` RPC response used to send only `experiment_ids`, but now sends `experiments` with additional metadata. This server-side change has been live since late November 2019, so we’re confident that we won’t need to roll it back. Thus, we can drop compatibility for the old code path in new uploader clients. Test Plan: Unit tests updated; verified that the `list` and `export` subcommands still work. wchargin-branch: uploader-require-experiment-metadata
1 parent 06e58a6 commit 909bf85

File tree

4 files changed

+58
-67
lines changed

4 files changed

+58
-67
lines changed

tensorboard/uploader/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ py_library(
2020
"//tensorboard:expect_grpc_installed",
2121
"//tensorboard/uploader/proto:protos_all_py_pb2",
2222
"//tensorboard/util:grpc_util",
23+
"//tensorboard/util:tb_logging",
2324
"@org_pythonhosted_six",
2425
],
2526
)

tensorboard/uploader/exporter.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from tensorboard.uploader.proto import export_service_pb2
3333
from tensorboard.uploader import util
3434
from tensorboard.util import grpc_util
35+
from tensorboard.util import tb_logging
3536

3637
# Characters that are assumed to be safe in filenames. Note that the
3738
# server's experiment IDs are base64 encodings of 16-byte blobs, so they
@@ -47,6 +48,8 @@
4748
# Output filename for scalar data within an experiment directory.
4849
_FILENAME_SCALARS = "scalars.json"
4950

51+
logger = tb_logging.get_logger()
52+
5053

5154
class TensorBoardExporter(object):
5255
"""Exports all of the user's experiment data from TensorBoard.dev.
@@ -115,7 +118,8 @@ def export(self, read_time=None):
115118
"""
116119
if read_time is None:
117120
read_time = time.time()
118-
for experiment_id in self._request_experiment_ids(read_time):
121+
for experiment in list_experiments(self._api, read_time=read_time):
122+
experiment_id = experiment.experiment_id
119123
experiment_dir = _experiment_directory(self._outdir, experiment_id)
120124
os.mkdir(experiment_dir)
121125

@@ -134,18 +138,6 @@ def export(self, read_time=None):
134138
else:
135139
raise
136140

137-
def _request_experiment_ids(self, read_time):
138-
"""Yields all of the calling user's experiment IDs, as strings."""
139-
for experiment in list_experiments(self._api, read_time=read_time):
140-
if isinstance(experiment, experiment_pb2.Experiment):
141-
yield experiment.experiment_id
142-
elif isinstance(experiment, six.string_types):
143-
yield experiment
144-
else:
145-
raise AssertionError(
146-
"Unexpected experiment type: %r" % (experiment,)
147-
)
148-
149141
def _request_scalar_data(self, experiment_id, read_time):
150142
"""Yields JSON-serializable blocks of scalar data."""
151143
request = export_service_pb2.StreamExperimentDataRequest()
@@ -191,7 +183,11 @@ def list_experiments(api_client, fieldmask=None, read_time=None):
191183
192184
Yields:
193185
For each experiment owned by the user, an `experiment_pb2.Experiment`
194-
value, or a simple string experiment ID for older servers.
186+
value.
187+
188+
Raises:
189+
RuntimeError: If the server returns experiment IDs but no experiments,
190+
as in an old, unsupported version of the protocol.
195191
"""
196192
if read_time is None:
197193
read_time = time.time()
@@ -206,10 +202,17 @@ def list_experiments(api_client, fieldmask=None, read_time=None):
206202
if response.experiments:
207203
for experiment in response.experiments:
208204
yield experiment
205+
elif response.experiment_ids:
206+
raise RuntimeError(
207+
"Server sent experiment_ids without experiments: <%r>"
208+
% (list(response.experiment_ids),)
209+
)
209210
else:
210-
# Old servers.
211-
for experiment_id in response.experiment_ids:
212-
yield experiment_id
211+
# No data: not technically a problem, but not expected.
212+
logger.warn(
213+
"StreamExperiments RPC returned response with no experiments: <%r>",
214+
response,
215+
)
213216

214217

215218
class OutputDirectoryExistsError(ValueError):

tensorboard/uploader/exporter_test.py

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,29 @@
4343
from tensorboard.compat.proto import summary_pb2
4444

4545

46+
def _make_experiments_response(eids):
47+
"""Make a `StreamExperimentsResponse` with experiments with only IDs."""
48+
response = export_service_pb2.StreamExperimentsResponse()
49+
for eid in eids:
50+
response.experiments.add(experiment_id=eid)
51+
return response
52+
53+
4654
class TensorBoardExporterTest(tb_test.TestCase):
4755
def _create_mock_api_client(self):
4856
return _create_mock_api_client()
4957

50-
def _make_experiments_response(self, eids):
51-
return export_service_pb2.StreamExperimentsResponse(experiment_ids=eids)
52-
5358
def test_e2e_success_case(self):
5459
mock_api_client = self._create_mock_api_client()
5560
mock_api_client.StreamExperiments.return_value = iter(
56-
[
57-
export_service_pb2.StreamExperimentsResponse(
58-
experiment_ids=["789"]
59-
),
60-
]
61+
[_make_experiments_response(["789"])]
6162
)
6263

6364
def stream_experiments(request, **kwargs):
6465
del request # unused
6566
self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
66-
yield export_service_pb2.StreamExperimentsResponse(
67-
experiment_ids=["123", "456"]
68-
)
69-
yield export_service_pb2.StreamExperimentsResponse(
70-
experiment_ids=["789"]
71-
)
67+
yield _make_experiments_response(["123", "456"])
68+
yield _make_experiments_response(["789"])
7269

7370
def stream_experiment_data(request, **kwargs):
7471
self.assertEqual(kwargs["metadata"], grpc_util.version_metadata())
@@ -200,9 +197,7 @@ def test_rejects_dangerous_experiment_ids(self):
200197

201198
def stream_experiments(request, **kwargs):
202199
del request # unused
203-
yield export_service_pb2.StreamExperimentsResponse(
204-
experiment_ids=["../authorized_keys"]
205-
)
200+
yield _make_experiments_response(["../authorized_keys"])
206201

207202
mock_api_client.StreamExperiments = stream_experiments
208203

@@ -229,9 +224,7 @@ def test_fails_nicely_on_stream_experiment_data_timeout(self):
229224

230225
def stream_experiments(request, **kwargs):
231226
del request # unused
232-
yield export_service_pb2.StreamExperimentsResponse(
233-
experiment_ids=[experiment_id]
234-
)
227+
yield _make_experiments_response([experiment_id])
235228

236229
def stream_experiment_data(request, **kwargs):
237230
raise test_util.grpc_error(
@@ -260,9 +253,7 @@ def test_stream_experiment_data_passes_through_unexpected_exception(self):
260253

261254
def stream_experiments(request, **kwargs):
262255
del request # unused
263-
yield export_service_pb2.StreamExperimentsResponse(
264-
experiment_ids=[experiment_id]
265-
)
256+
yield _make_experiments_response([experiment_id])
266257

267258
def stream_experiment_data(request, **kwargs):
268259
del request # unused
@@ -288,11 +279,7 @@ def test_handles_outdir_with_no_slash(self):
288279
os.chdir(self.get_temp_dir())
289280
mock_api_client = self._create_mock_api_client()
290281
mock_api_client.StreamExperiments.return_value = iter(
291-
[
292-
export_service_pb2.StreamExperimentsResponse(
293-
experiment_ids=["123"]
294-
),
295-
]
282+
[_make_experiments_response(["123"])]
296283
)
297284
mock_api_client.StreamExperimentData.return_value = iter(
298285
[export_service_pb2.StreamExperimentDataResponse()]
@@ -335,6 +322,7 @@ def test_propagates_mkdir_errors(self):
335322

336323
class ListExperimentsTest(tb_test.TestCase):
337324
def test_experiment_ids_only(self):
325+
# Legacy server behavior; should raise an error.
338326
mock_api_client = _create_mock_api_client()
339327

340328
def stream_experiments(request, **kwargs):
@@ -347,45 +335,48 @@ def stream_experiments(request, **kwargs):
347335
)
348336

349337
mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
350-
gen = exporter_lib.list_experiments(mock_api_client)
351-
mock_api_client.StreamExperiments.assert_not_called()
352-
self.assertEqual(list(gen), ["123", "456", "789"])
338+
with self.assertRaises(RuntimeError) as cm:
339+
list(exporter_lib.list_experiments(mock_api_client))
340+
self.assertIn(repr(["123", "456"]), str(cm.exception))
353341

354342
def test_mixed_experiments_and_ids(self):
355343
mock_api_client = _create_mock_api_client()
356344

357345
def stream_experiments(request, **kwargs):
358346
del request # unused
359347

360-
# Should include `experiment_ids` when no `experiments` given.
361-
response = export_service_pb2.StreamExperimentsResponse()
362-
response.experiment_ids.append("123")
363-
response.experiment_ids.append("456")
364-
yield response
365-
366348
# Should ignore `experiment_ids` in the presence of `experiments`.
367349
response = export_service_pb2.StreamExperimentsResponse()
368350
response.experiment_ids.append("999") # will be omitted
369351
response.experiments.add(experiment_id="789")
370352
response.experiments.add(experiment_id="012")
371353
yield response
372354

373-
# Should include `experiments` even when no `experiment_ids` are given.
355+
mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
356+
gen = exporter_lib.list_experiments(mock_api_client)
357+
mock_api_client.StreamExperiments.assert_not_called()
358+
expected = [
359+
experiment_pb2.Experiment(experiment_id="789"),
360+
experiment_pb2.Experiment(experiment_id="012"),
361+
]
362+
self.assertEqual(list(gen), expected)
363+
364+
def test_experiments_only(self):
365+
mock_api_client = _create_mock_api_client()
366+
367+
def stream_experiments(request, **kwargs):
368+
del request # unused
374369
response = export_service_pb2.StreamExperimentsResponse()
375-
response.experiments.add(experiment_id="345")
376-
response.experiments.add(experiment_id="678")
370+
response.experiments.add(experiment_id="789", name="one")
371+
response.experiments.add(experiment_id="012", description="two")
377372
yield response
378373

379374
mock_api_client.StreamExperiments = mock.Mock(wraps=stream_experiments)
380375
gen = exporter_lib.list_experiments(mock_api_client)
381376
mock_api_client.StreamExperiments.assert_not_called()
382377
expected = [
383-
"123",
384-
"456",
385-
experiment_pb2.Experiment(experiment_id="789"),
386-
experiment_pb2.Experiment(experiment_id="012"),
387-
experiment_pb2.Experiment(experiment_id="345"),
388-
experiment_pb2.Experiment(experiment_id="678"),
378+
experiment_pb2.Experiment(experiment_id="789", name="one"),
379+
experiment_pb2.Experiment(experiment_id="012", description="two"),
389380
]
390381
self.assertEqual(list(gen), expected)
391382

tensorboard/uploader/uploader_main.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,10 +515,6 @@ def execute(self, server_info, channel):
515515
count = 0
516516
for experiment in gen:
517517
count += 1
518-
if not isinstance(experiment, experiment_pb2.Experiment):
519-
url = server_info_lib.experiment_url(server_info, experiment)
520-
print(url)
521-
continue
522518
experiment_id = experiment.experiment_id
523519
url = server_info_lib.experiment_url(server_info, experiment_id)
524520
print(url)

0 commit comments

Comments
 (0)