Skip to content

Commit fd46c42

Browse files
committed
uploader: request ServerInfo from frontend (#2879)
Summary:
This commit integrates the new `ServerInfo` RPC with the uploader. It’s not currently enabled by default: the resulting behavior is the same as the existing behavior, except that experiment URLs now properly have a trailing slash. We’ll soon remove the hard-coded API backend endpoint behavior to enable this by default.

Test Plan:
Running a test frontend and a test backend, we observe the following behavior with different arguments:

| `--origin` | `--api_endpoint` | → | URL origin | Backend |
|------------|------------------|---|------------|---------|
| empty      | empty            |   | prod       | prod    |
| empty      | prod             |   | prod       | prod    |
| empty      | test             |   | prod       | test    |
| test       | empty            |   | test       | test    |
| test       | test             |   | test       | test    |
| test       | prod             |   | test       | prod    |

Here, “test” in the `--origin` column is like `http://localhost:8080`, and “test” in the `--api_endpoint` column is like `localhost:10000`. Note that the no-argument case is equivalent to the explicitly-empty argument case because both arguments have empty default values.

Explicitly specifying `--origin https://tensorboard.dev`, with any value of `--api_endpoint`, fails with “Corrupt response from backend” because server-side support has not yet been rolled out. This is expected.

Specifying `--origin http://localhost:0` or any other unreachable host fails with `ECONNREFUSED` and a nice message.

My test frontend is configured to reject clients below version 2.0.0 and warn on clients below version 2.0.1. Changing the local `version.py` to `2.0.0a0` or `2.0.1a0` exercises these cases.

Finally, double-checked that building the Pip package, installing it, and running `tensorboard dev list` properly uses the production backend and prints URLs that resolve to the production frontend.

wchargin-branch: uploader-serverinfo-request
1 parent a554538 commit fd46c42

File tree

6 files changed

+101
-19
lines changed

6 files changed

+101
-19
lines changed

tensorboard/uploader/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ py_library(
5656
":auth",
5757
":dev_creds",
5858
":exporter_lib",
59+
":server_info",
5960
":uploader_lib",
6061
"//tensorboard:expect_absl_app_installed",
6162
"//tensorboard:expect_absl_flags_argparse_flags_installed",

tensorboard/uploader/server_info.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def create_server_info(frontend_origin, api_endpoint):
9797
return result
9898

9999

100+
def experiment_url(server_info, experiment_id):
    """Formats a URL that will resolve to the provided experiment.

    The URL is built by substituting `experiment_id` for every occurrence
    of `server_info.url_format.id_placeholder` within
    `server_info.url_format.template`.

    Args:
      server_info: A `server_info_pb2.ServerInfoResponse` message.
      experiment_id: A string; the ID of the experiment to link to.

    Returns:
      A URL resolving to the given experiment, as a string. If the
      server info defines no ID placeholder (empty string), the URL
      template is returned unchanged.
    """
    url_format = server_info.url_format
    placeholder = url_format.id_placeholder
    if not placeholder:
        # `str.replace` with an empty search string inserts the
        # replacement between every character of the template, which
        # would silently produce a corrupted URL.
        return url_format.template
    return url_format.template.replace(placeholder, experiment_id)
112+
113+
100114
class CommunicationError(RuntimeError):
101115
"""Raised upon failure to communicate with the server."""
102116

tensorboard/uploader/server_info_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,17 @@ def test(self):
147147
self.assertEqual(actual_url, expected_url)
148148

149149

150+
class ExperimentUrlTest(tb_test.TestCase):
    """Tests for `experiment_url`."""

    def test(self):
        """Checks that the placeholder in the template is replaced by the ID."""
        response = server_info_pb2.ServerInfoResponse()
        url_format = response.url_format
        url_format.template = "https://unittest.tensorboard.dev/x/???"
        url_format.id_placeholder = "???"
        expected_url = "https://unittest.tensorboard.dev/x/123"
        self.assertEqual(
            server_info.experiment_url(response, "123"), expected_url)
159+
160+
150161
def _localhost():
151162
"""Gets family and nodename for a loopback address."""
152163
s = socket

tensorboard/uploader/uploader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ def __init__(self, writer_client, logdir, rate_limiter=None):
8989
self._logdir, directory_loader_factory)
9090

9191
def create_experiment(self):
92-
"""Creates an Experiment for this upload session and returns the URL."""
92+
"""Creates an Experiment for this upload session and returns the ID."""
9393
logger.info("Creating experiment")
9494
request = write_service_pb2.CreateExperimentRequest()
9595
response = grpc_util.call_with_retries(self._api.CreateExperiment, request)
9696
self._request_builder = _RequestBuilder(response.experiment_id)
97-
return response.url
97+
return response.experiment_id
9898

9999
def start_uploading(self):
100100
"""Blocks forever to continuously upload data from the logdir.

tensorboard/uploader/uploader_main.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
from tensorboard.uploader.proto import write_service_pb2_grpc
3636
from tensorboard.uploader import auth
3737
from tensorboard.uploader import exporter as exporter_lib
38+
from tensorboard.uploader import server_info as server_info_lib
3839
from tensorboard.uploader import uploader as uploader_lib
40+
from tensorboard.uploader.proto import server_info_pb2
3941
from tensorboard import program
4042
from tensorboard.plugins import base_plugin
4143

@@ -65,6 +67,11 @@
6567
_AUTH_SUBCOMMAND_FLAG = '_uploader__subcommand_auth'
6668
_AUTH_SUBCOMMAND_KEY_REVOKE = 'REVOKE'
6769

70+
_DEFAULT_ORIGIN = "https://tensorboard.dev"
71+
# Compatibility measure until server-side /api/uploader support is
72+
# rolled out and stable.
73+
_HARDCODED_API_ENDPOINT = "api.tensorboard.dev:443"
74+
6875

6976
def _prompt_for_user_ack(intent):
7077
"""Prompts for user consent, exiting the program if they decline."""
@@ -91,10 +98,19 @@ def _define_flags(parser):
9198
subparsers = parser.add_subparsers()
9299

93100
parser.add_argument(
94-
'--endpoint',
101+
'--origin',
95102
type=str,
96-
default='api.tensorboard.dev:443',
97-
help='URL for the API server accepting write requests.')
103+
default='',
104+
help='Experimental. Origin for TensorBoard.dev service to which '
105+
'to connect. If not set, defaults to %r.' % _DEFAULT_ORIGIN)
106+
107+
parser.add_argument(
108+
'--api_endpoint',
109+
type=str,
110+
default='',
111+
help='Experimental. Direct URL for the API server accepting '
112+
'write requests. If set, will skip initial server handshake '
113+
'unless `--origin` is also set.')
98114

99115
parser.add_argument(
100116
'--grpc_creds_type',
@@ -222,15 +238,26 @@ def _run(flags):
222238
msg = 'Invalid --grpc_creds_type %s' % flags.grpc_creds_type
223239
raise base_plugin.FlagsError(msg)
224240

241+
try:
242+
server_info = _get_server_info(flags)
243+
except server_info_lib.CommunicationError as e:
244+
_die(str(e))
245+
_handle_server_info(server_info)
246+
247+
if not server_info.api_server.endpoint:
248+
logging.error('Server info response: %s', server_info)
249+
_die('Internal error: frontend did not specify an API server')
225250
composite_channel_creds = grpc.composite_channel_credentials(
226251
channel_creds, auth.id_token_call_credentials(credentials))
227252

228253
# TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
229254
# logdir exists to open channel.
230255
channel = grpc.secure_channel(
231-
flags.endpoint, composite_channel_creds, options=channel_options)
256+
server_info.api_server.endpoint,
257+
composite_channel_creds,
258+
options=channel_options)
232259
with channel:
233-
intent.execute(channel)
260+
intent.execute(server_info, channel)
234261

235262

236263
@six.add_metaclass(abc.ABCMeta)
@@ -254,10 +281,11 @@ def get_ack_message_body(self):
254281
pass
255282

256283
@abc.abstractmethod
257-
def execute(self, channel):
284+
def execute(self, server_info, channel):
258285
"""Carries out this intent with the specified gRPC channel.
259286
260287
Args:
288+
server_info: A `server_info_pb2.ServerInfoResponse` value.
261289
channel: A connected gRPC channel whose server provides the TensorBoard
262290
reader and writer services.
263291
"""
@@ -271,7 +299,7 @@ def get_ack_message_body(self):
271299
"""Must not be called."""
272300
raise AssertionError('No user ack needed to revoke credentials')
273301

274-
def execute(self, channel):
302+
def execute(self, server_info, channel):
275303
"""Execute handled specially by `main`. Must not be called."""
276304
raise AssertionError('_AuthRevokeIntent should not be directly executed')
277305

@@ -296,7 +324,7 @@ def __init__(self, experiment_id):
296324
def get_ack_message_body(self):
297325
return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id)
298326

299-
def execute(self, channel):
327+
def execute(self, server_info, channel):
300328
api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel)
301329
experiment_id = self.experiment_id
302330
if not experiment_id:
@@ -329,14 +357,13 @@ class _ListIntent(_Intent):
329357
def get_ack_message_body(self):
330358
return self._MESSAGE
331359

332-
def execute(self, channel):
360+
def execute(self, server_info, channel):
333361
api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)
334362
gen = exporter_lib.list_experiments(api_client)
335363
count = 0
336364
for experiment_id in gen:
337365
count += 1
338-
# TODO(@wchargin): Once #2879 is in, remove this hard-coded URL pattern.
339-
url = 'https://tensorboard.dev/experiment/%s/' % experiment_id
366+
url = server_info_lib.experiment_url(server_info, experiment_id)
340367
print(url)
341368
sys.stdout.flush()
342369
if not count:
@@ -366,10 +393,11 @@ def __init__(self, logdir):
366393
def get_ack_message_body(self):
367394
return self._MESSAGE_TEMPLATE.format(logdir=self.logdir)
368395

369-
def execute(self, channel):
396+
def execute(self, server_info, channel):
370397
api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel)
371398
uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir)
372-
url = uploader.create_experiment()
399+
experiment_id = uploader.create_experiment()
400+
url = server_info_lib.experiment_url(server_info, experiment_id)
373401
print("Upload started and will continue reading any new data as it's added")
374402
print("to the logdir. To stop uploading, press Ctrl-C.")
375403
print("View your TensorBoard live at: %s" % url)
@@ -407,7 +435,7 @@ def __init__(self, output_dir):
407435
def get_ack_message_body(self):
408436
return self._MESSAGE_TEMPLATE.format(output_dir=self.output_dir)
409437

410-
def execute(self, channel):
438+
def execute(self, server_info, channel):
411439
api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)
412440
outdir = self.output_dir
413441
try:
@@ -476,6 +504,34 @@ def _get_intent(flags):
476504
raise AssertionError('Unknown subcommand %r' % (cmd,))
477505

478506

507+
def _get_server_info(flags):
    """Builds or fetches the server info to use for this invocation.

    Args:
      flags: Parsed command-line flags; the `origin` and `api_endpoint`
        attributes are read.

    Returns:
      The value returned by `server_info_lib.create_server_info` when
      `--origin` is empty, or by `server_info_lib.fetch_server_info`
      otherwise (with its API endpoint possibly overridden by
      `--api_endpoint`).
    """
    if flags.origin:
        info = server_info_lib.fetch_server_info(flags.origin)
        # An explicit `--api_endpoint` takes precedence over the endpoint
        # reported by the server, but only when the server reported one,
        # i.e., only if it accepted our initial handshake.
        if flags.api_endpoint and info.api_server.endpoint:
            info.api_server.endpoint = flags.api_endpoint
        return info

    # No origin given: use the default origin and, as a temporary
    # fallback, the hardcoded API endpoint unless one was specified.
    if flags.api_endpoint:
        endpoint = flags.api_endpoint
    else:
        endpoint = _HARDCODED_API_ENDPOINT
    return server_info_lib.create_server_info(_DEFAULT_ORIGIN, endpoint)
519+
520+
521+
def _handle_server_info(info):
    """Reports the server's compatibility verdict to the user.

    A warning verdict is written to stderr with a "Warning" prefix, and an
    error verdict is passed to `_die` (presumably terminating the program;
    confirm against `_die`). Any other verdict is treated as OK, and its
    details, if non-empty, are written to stderr as-is.

    Args:
      info: A server info message whose `compatibility` field carries
        `verdict` and `details`.
    """
    compatibility = info.compatibility
    verdict, details = compatibility.verdict, compatibility.details

    if verdict == server_info_pb2.VERDICT_ERROR:
        _die('Error [from server]: %s' % details)
        return

    if verdict == server_info_pb2.VERDICT_WARN:
        notice = 'Warning [from server]: %s\n' % details
    elif details:
        # OK or unknown verdict; assume OK and relay the server's note.
        notice = '%s\n' % details
    else:
        return

    sys.stderr.write(notice)
    sys.stderr.flush()
533+
534+
479535
def _die(message):
480536
sys.stderr.write('%s\n' % (message,))
481537
sys.stderr.flush()

tensorboard/uploader/uploader_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,16 @@ def _create_mock_client(self):
6161
stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel)
6262
mock_client = mock.create_autospec(stub)
6363
fake_exp_response = write_service_pb2.CreateExperimentResponse(
64-
experiment_id="123", url="https://example.com/123")
64+
experiment_id="123", url="should not be used!")
6565
mock_client.CreateExperiment.return_value = fake_exp_response
6666
return mock_client
6767

6868
def test_create_experiment(self):
6969
logdir = "/logs/foo"
7070
mock_client = self._create_mock_client()
7171
uploader = uploader_lib.TensorBoardUploader(mock_client, logdir)
72-
url = uploader.create_experiment()
73-
self.assertEqual(url, "https://example.com/123")
72+
eid = uploader.create_experiment()
73+
self.assertEqual(eid, "123")
7474

7575
def test_start_uploading_without_create_experiment_fails(self):
7676
mock_client = self._create_mock_client()

0 commit comments

Comments
 (0)