|
| 1 | +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""A data provider that dispatches to sub-providers based on prefix.""" |
| 16 | + |
| 17 | +import base64 |
| 18 | +import json |
| 19 | + |
| 20 | +from tensorboard import errors |
| 21 | +from tensorboard.data import provider |
| 22 | + |
| 23 | + |
# Separator between the sub-provider prefix and the sub-ID in a dispatched
# experiment ID (e.g., `foo:123`). Only the first occurrence is significant
# when parsing, so the sub-ID itself may contain further separators.
_SEPARATOR = ":"
| 26 | + |
| 27 | + |
class DispatchingDataProvider(provider.DataProvider):
    """Data provider that dispatches to sub-providers based on prefix.

    If you have one data provider that talks to the Foo service with IDs
    like `123` and another that talks to the Bar service with IDs like
    `3a28213a`, then this data provider lets you talk to both of them
    with IDs like `foo:123` and `bar:3a28213a`, respectively. The part
    before the colon is the *prefix*, and identifies the sub-provider to
    which to dispatch. The part after the colon is the *sub-ID*, and is
    passed verbatim to the sub-provider. The sub-ID may contain any
    characters, including colons, so the sub-providers may themselves be
    hierarchical.

    Optionally, an `unprefixed_provider` can be specified as default in
    the case where an experiment ID does not contain a colon. Note that
    this is not used as a fallback when the prefix is simply not one of
    the registered prefixes; that will always be an error.
    """

    # Implementation note: this data provider provides a simple
    # pass-through for most methods, but has extra logic for methods
    # related to blob keys, where we need to annotate or extract the
    # associated sub-provider.

    def __init__(self, providers, unprefixed_provider=None):
        """Initialize a `DispatchingDataProvider`.

        Args:
          providers: Dict mapping prefix (`str`) to sub-provider
            instance (`provider.DataProvider`). Keys will appear in
            experiment IDs and so must be URL-safe.
          unprefixed_provider: Optional `provider.DataProvider` instance
            to use with experiment IDs that do not have a prefix.

        Raises:
          ValueError: If any of the provider keys contains a colon,
            which would make it impossible to match.
        """
        # Copy so that later mutation of the caller's dict cannot change
        # which prefixes are registered.
        self._providers = dict(providers)
        invalid_names = sorted(k for k in self._providers if _SEPARATOR in k)
        if invalid_names:
            raise ValueError("Invalid provider key(s): %r" % invalid_names)
        self._unprefixed_provider = unprefixed_provider

    def _parse_eid(self, experiment_id):
        """Parse an experiment ID into prefix, sub-ID, and sub-provider.

        The returned prefix may be `None` if this instance has an
        unprefixed data provider registered. If the experiment ID is
        invalid, this method may raise an `errors.NotFoundError`.

        Args:
          experiment_id: The full experiment ID, as a `str`.

        Returns:
          A tuple `(prefix, sub_eid, sub_provider)`. For an ID with no
          separator, `prefix` is `None`, `sub_eid` is the whole ID, and
          `sub_provider` is the unprefixed provider.

        Raises:
          errors.NotFoundError: If the ID has no separator and there is
            no unprefixed provider, or if its prefix is not registered.
        """
        # Split on the first separator only: everything after it belongs
        # to the sub-provider, separators included.
        parts = experiment_id.split(_SEPARATOR, 1)
        if len(parts) == 1:
            if self._unprefixed_provider is None:
                raise errors.NotFoundError(
                    "No data provider found for unprefixed experiment ID: %r"
                    % experiment_id
                )
            return (None, experiment_id, self._unprefixed_provider)
        (prefix, sub_eid) = parts
        sub_provider = self._providers.get(prefix)
        if sub_provider is None:
            raise errors.NotFoundError(
                "Unknown prefix in experiment ID: %r" % experiment_id
            )
        return (prefix, sub_eid, sub_provider)

    def _simple_delegate(get_method):
        """Dispatch on experiment ID, forwarding args and result unchanged.

        This is called at class-definition time (not as a method) to
        build the pass-through methods below.

        Args:
          get_method: Callable that takes a sub-provider and returns the
            bound method to invoke on it.

        Returns:
          A function usable as a method of this class. It requires an
          `experiment_id` keyword argument, which is replaced with the
          sub-ID before forwarding to the sub-provider.
        """

        def wrapper(self, *args, experiment_id, **kwargs):
            (_, sub_eid, sub_provider) = self._parse_eid(experiment_id)
            method = get_method(sub_provider)
            return method(*args, experiment_id=sub_eid, **kwargs)

        return wrapper

    # Pure pass-through methods: only the experiment ID is rewritten.
    data_location = _simple_delegate(lambda p: p.data_location)
    experiment_metadata = _simple_delegate(lambda p: p.experiment_metadata)
    list_plugins = _simple_delegate(lambda p: p.list_plugins)
    list_runs = _simple_delegate(lambda p: p.list_runs)
    list_scalars = _simple_delegate(lambda p: p.list_scalars)
    read_scalars = _simple_delegate(lambda p: p.read_scalars)
    list_tensors = _simple_delegate(lambda p: p.list_tensors)
    read_tensors = _simple_delegate(lambda p: p.read_tensors)
    list_blob_sequences = _simple_delegate(lambda p: p.list_blob_sequences)

    def read_blob_sequences(self, *args, experiment_id, **kwargs):
        """Read blob sequences, re-keying blob references by sub-provider.

        Forwards to the sub-provider selected by `experiment_id`, then
        rewrites every blob key in the result so that it also records
        the sub-provider's prefix, letting `read_blob` dispatch later.

        Args:
          *args: Positional arguments forwarded to the sub-provider.
          experiment_id: The full (possibly prefixed) experiment ID.
          **kwargs: Keyword arguments forwarded to the sub-provider.

        Returns:
          A new nested dict with the same keys as the sub-provider's
          result, whose values are lists of `provider.BlobSequenceDatum`
          with re-encoded blob keys.

        Raises:
          errors.NotFoundError: If `experiment_id` cannot be dispatched.
        """
        (prefix, sub_eid, sub_provider) = self._parse_eid(experiment_id)
        result = sub_provider.read_blob_sequences(
            *args, experiment_id=sub_eid, **kwargs
        )
        # Build a fresh mapping rather than rewriting `result` in place:
        # the sub-provider owns that object, and if it is cached or
        # shared, mutating it would re-encode already-encoded keys on a
        # later call (or fail outright on an immutable mapping).
        return {
            run: {
                tag: [
                    provider.BlobSequenceDatum(
                        step=d.step,
                        wall_time=d.wall_time,
                        values=_convert_blob_references(prefix, d.values),
                    )
                    for d in old_data
                ]
                for (tag, old_data) in tag_to_data.items()
            }
            for (run, tag_to_data) in result.items()
        }

    def read_blob(self, ctx, blob_key):
        """Read a blob, dispatching on the prefix encoded in its key.

        Args:
          ctx: Passed through unchanged to the sub-provider.
          blob_key: A blob key in the form returned by
            `_encode_blob_key`.

        Returns:
          Whatever the selected sub-provider's `read_blob` returns for
          the decoded sub-key.

        Raises:
          errors.NotFoundError: If `_decode_blob_key` rejects the key,
            if the key has no prefix and there is no unprefixed
            provider, or if the key's prefix is not registered.
        """
        (prefix, sub_key) = _decode_blob_key(blob_key)
        if prefix is None:
            if self._unprefixed_provider is None:
                raise errors.NotFoundError(
                    "Invalid blob key: no unprefixed provider"
                )
            return self._unprefixed_provider.read_blob(ctx, blob_key=sub_key)
        sub_provider = self._providers.get(prefix)
        if sub_provider is None:
            raise errors.NotFoundError(
                "Invalid blob key: no such provider: %r; have: %r"
                % (prefix, sorted(self._providers))
            )
        return sub_provider.read_blob(ctx, blob_key=sub_key)
| 148 | + |
| 149 | + |
def _convert_blob_references(prefix, references):
    """Re-key a sequence of blob references for a given sub-provider.

    Args:
      prefix: The prefix of the sub-provider that produced the
        references, or `None` for the unprefixed provider.
      references: An iterable of `provider.BlobReference`s emitted by a
        sub-provider.

    Returns:
      A new list of `provider.BlobReference`s, in the same order, whose
      blob keys have been encoded per `_encode_blob_key` and whose URLs
      are carried over unchanged.
    """
    converted = []
    for ref in references:
        encoded_key = _encode_blob_key(prefix, ref.blob_key)
        converted.append(
            provider.BlobReference(blob_key=encoded_key, url=ref.url)
        )
    return converted
| 169 | + |
| 170 | + |
def _encode_blob_key(prefix, sub_key):
    """Pack an optional prefix and a sub-key into a single blob key.

    The pair is serialized as a compact two-element JSON array and then
    encoded as URL-safe base64 with the trailing `=` padding removed.

    Args:
      prefix: The prefix of the sub-provider that generated the sub-key,
        or `None` if this was generated by the unprefixed provider.
      sub_key: The opaque key from the sub-provider.

    Returns:
      A string encoding `prefix` and `sub_key` injectively.
    """
    # Compact separators keep the key short; `json.dumps` escapes
    # non-ASCII characters by default, so the ASCII encode is safe.
    compact_json = json.dumps([prefix, sub_key], separators=(",", ":"))
    encoded = base64.urlsafe_b64encode(compact_json.encode("ascii"))
    return encoded.rstrip(b"=").decode("ascii")
| 186 | + |
| 187 | + |
def _decode_blob_key(key):
    """Decode a prefix (optional) and sub-key from a blob key.

    Left inverse of `_encode_blob_key`.

    Args:
      key: A blob key in the form returned by `_encode_blob_key`.

    Returns:
      A tuple `(prefix, sub_key)`, where `prefix` is either `None` or a
      sub-provider prefix, and `sub_key` is an opaque key from a
      sub-provider.

    Raises:
      errors.NotFoundError: If `key` is invalid and has no preimage.
        This includes keys that are not valid base64, that do not
        decode to ASCII, that are not valid JSON, or whose JSON payload
        is not a two-element list of the expected types.
    """
    message = "Invalid blob key: %r" % key

    b64_str = key + "=="  # ensure adequate padding (overpadding is okay)
    try:
        json_str = base64.urlsafe_b64decode(b64_str).decode("ascii")
        payload = json.loads(json_str)
    except ValueError as e:
        # `binascii.Error`, `UnicodeDecodeError`, and
        # `json.JSONDecodeError` are all `ValueError` subclasses; report
        # each as an ordinary not-found condition rather than letting a
        # malformed, client-supplied key surface as an internal error.
        raise errors.NotFoundError(message) from e
    if not isinstance(payload, list) or len(payload) != 2:
        raise errors.NotFoundError(message)
    (prefix, sub_key) = payload
    if not (prefix is None or isinstance(prefix, str)):
        raise errors.NotFoundError(message)
    if not isinstance(sub_key, str):
        raise errors.NotFoundError(message)
    return (prefix, sub_key)
0 commit comments