43
43
from tensorboard .compat .proto import summary_pb2
44
44
45
45
46
+ def _make_experiments_response (eids ):
47
+ """Make a `StreamExperimentsResponse` with experiments with only IDs."""
48
+ response = export_service_pb2 .StreamExperimentsResponse ()
49
+ for eid in eids :
50
+ response .experiments .add (experiment_id = eid )
51
+ return response
52
+
53
+
46
54
class TensorBoardExporterTest (tb_test .TestCase ):
47
55
def _create_mock_api_client (self ):
48
56
return _create_mock_api_client ()
49
57
50
- def _make_experiments_response (self , eids ):
51
- return export_service_pb2 .StreamExperimentsResponse (experiment_ids = eids )
52
-
53
58
def test_e2e_success_case (self ):
54
59
mock_api_client = self ._create_mock_api_client ()
55
60
mock_api_client .StreamExperiments .return_value = iter (
56
- [
57
- export_service_pb2 .StreamExperimentsResponse (
58
- experiment_ids = ["789" ]
59
- ),
60
- ]
61
+ [_make_experiments_response (["789" ])]
61
62
)
62
63
63
64
def stream_experiments (request , ** kwargs ):
64
65
del request # unused
65
66
self .assertEqual (kwargs ["metadata" ], grpc_util .version_metadata ())
66
- yield export_service_pb2 .StreamExperimentsResponse (
67
- experiment_ids = ["123" , "456" ]
68
- )
69
- yield export_service_pb2 .StreamExperimentsResponse (
70
- experiment_ids = ["789" ]
71
- )
67
+ yield _make_experiments_response (["123" , "456" ])
68
+ yield _make_experiments_response (["789" ])
72
69
73
70
def stream_experiment_data (request , ** kwargs ):
74
71
self .assertEqual (kwargs ["metadata" ], grpc_util .version_metadata ())
@@ -200,9 +197,7 @@ def test_rejects_dangerous_experiment_ids(self):
200
197
201
198
def stream_experiments (request , ** kwargs ):
202
199
del request # unused
203
- yield export_service_pb2 .StreamExperimentsResponse (
204
- experiment_ids = ["../authorized_keys" ]
205
- )
200
+ yield _make_experiments_response (["../authorized_keys" ])
206
201
207
202
mock_api_client .StreamExperiments = stream_experiments
208
203
@@ -229,9 +224,7 @@ def test_fails_nicely_on_stream_experiment_data_timeout(self):
229
224
230
225
def stream_experiments (request , ** kwargs ):
231
226
del request # unused
232
- yield export_service_pb2 .StreamExperimentsResponse (
233
- experiment_ids = [experiment_id ]
234
- )
227
+ yield _make_experiments_response ([experiment_id ])
235
228
236
229
def stream_experiment_data (request , ** kwargs ):
237
230
raise test_util .grpc_error (
@@ -260,9 +253,7 @@ def test_stream_experiment_data_passes_through_unexpected_exception(self):
260
253
261
254
def stream_experiments (request , ** kwargs ):
262
255
del request # unused
263
- yield export_service_pb2 .StreamExperimentsResponse (
264
- experiment_ids = [experiment_id ]
265
- )
256
+ yield _make_experiments_response ([experiment_id ])
266
257
267
258
def stream_experiment_data (request , ** kwargs ):
268
259
del request # unused
@@ -288,11 +279,7 @@ def test_handles_outdir_with_no_slash(self):
288
279
os .chdir (self .get_temp_dir ())
289
280
mock_api_client = self ._create_mock_api_client ()
290
281
mock_api_client .StreamExperiments .return_value = iter (
291
- [
292
- export_service_pb2 .StreamExperimentsResponse (
293
- experiment_ids = ["123" ]
294
- ),
295
- ]
282
+ [_make_experiments_response (["123" ])]
296
283
)
297
284
mock_api_client .StreamExperimentData .return_value = iter (
298
285
[export_service_pb2 .StreamExperimentDataResponse ()]
@@ -335,6 +322,7 @@ def test_propagates_mkdir_errors(self):
335
322
336
323
class ListExperimentsTest (tb_test .TestCase ):
337
324
def test_experiment_ids_only (self ):
325
+ # Legacy server behavior; should raise an error.
338
326
mock_api_client = _create_mock_api_client ()
339
327
340
328
def stream_experiments (request , ** kwargs ):
@@ -347,45 +335,48 @@ def stream_experiments(request, **kwargs):
347
335
)
348
336
349
337
mock_api_client .StreamExperiments = mock .Mock (wraps = stream_experiments )
350
- gen = exporter_lib . list_experiments ( mock_api_client )
351
- mock_api_client . StreamExperiments . assert_not_called ( )
352
- self .assertEqual ( list ( gen ), ["123" , "456" , "789" ] )
338
+ with self . assertRaises ( RuntimeError ) as cm :
339
+ list ( exporter_lib . list_experiments ( mock_api_client ) )
340
+ self .assertIn ( repr ( ["123" , "456" ]), str ( cm . exception ) )
353
341
354
342
def test_mixed_experiments_and_ids (self ):
355
343
mock_api_client = _create_mock_api_client ()
356
344
357
345
def stream_experiments (request , ** kwargs ):
358
346
del request # unused
359
347
360
- # Should include `experiment_ids` when no `experiments` given.
361
- response = export_service_pb2 .StreamExperimentsResponse ()
362
- response .experiment_ids .append ("123" )
363
- response .experiment_ids .append ("456" )
364
- yield response
365
-
366
348
# Should ignore `experiment_ids` in the presence of `experiments`.
367
349
response = export_service_pb2 .StreamExperimentsResponse ()
368
350
response .experiment_ids .append ("999" ) # will be omitted
369
351
response .experiments .add (experiment_id = "789" )
370
352
response .experiments .add (experiment_id = "012" )
371
353
yield response
372
354
373
- # Should include `experiments` even when no `experiment_ids` are given.
355
+ mock_api_client .StreamExperiments = mock .Mock (wraps = stream_experiments )
356
+ gen = exporter_lib .list_experiments (mock_api_client )
357
+ mock_api_client .StreamExperiments .assert_not_called ()
358
+ expected = [
359
+ experiment_pb2 .Experiment (experiment_id = "789" ),
360
+ experiment_pb2 .Experiment (experiment_id = "012" ),
361
+ ]
362
+ self .assertEqual (list (gen ), expected )
363
+
364
+ def test_experiments_only (self ):
365
+ mock_api_client = _create_mock_api_client ()
366
+
367
+ def stream_experiments (request , ** kwargs ):
368
+ del request # unused
374
369
response = export_service_pb2 .StreamExperimentsResponse ()
375
- response .experiments .add (experiment_id = "345 " )
376
- response .experiments .add (experiment_id = "678 " )
370
+ response .experiments .add (experiment_id = "789" , name = "one " )
371
+ response .experiments .add (experiment_id = "012" , description = "two " )
377
372
yield response
378
373
379
374
mock_api_client .StreamExperiments = mock .Mock (wraps = stream_experiments )
380
375
gen = exporter_lib .list_experiments (mock_api_client )
381
376
mock_api_client .StreamExperiments .assert_not_called ()
382
377
expected = [
383
- "123" ,
384
- "456" ,
385
- experiment_pb2 .Experiment (experiment_id = "789" ),
386
- experiment_pb2 .Experiment (experiment_id = "012" ),
387
- experiment_pb2 .Experiment (experiment_id = "345" ),
388
- experiment_pb2 .Experiment (experiment_id = "678" ),
378
+ experiment_pb2 .Experiment (experiment_id = "789" , name = "one" ),
379
+ experiment_pb2 .Experiment (experiment_id = "012" , description = "two" ),
389
380
]
390
381
self .assertEqual (list (gen ), expected )
391
382
0 commit comments