Skip to content

Commit 4f0bafd

Browse files
committed
DataFrame: Flatten MultiIndex; Set include_wall_time=False (#3605)
* Motivation for features / changes * Improve DataFrame API based on the following user feedback * The MultiIndex DataFrame is hard to handle and not nice in terms of CSV round-trip compatibility * The wall_time column is usually not useful * Technical description of changes * When pivoting the DataFrame using `pivot_table()`, follow it with `reset_index()` to flatten the 2-level MultiIndex to a single level of index. * Then, call `.columns.name` to remove the `tag` columns name, which is an artifact of the tensorboard data model and not directly relevant to users. * Add `include_wall_time=False` kwarg to get_scalars() * Change the default value of `pivot` from `None` to `False` as per more standard and expected default value scheme.
1 parent 83bde95 commit 4f0bafd

File tree

4 files changed

+139
-57
lines changed

4 files changed

+139
-57
lines changed

tensorboard/data/experimental/base_experiment.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ class BaseExperiment(metaclass=abc.ABCMeta):
2828
# TODO(cais): Add list_scalar_tags().
2929

3030
@abc.abstractmethod
31-
def get_scalars(self, runs_filter=None, tags_filter=None, pivot=None):
31+
def get_scalars(
32+
self,
33+
runs_filter=None,
34+
tags_filter=None,
35+
pivot=True,
36+
include_wall_time=False,
37+
):
3238
"""Export scalar data as a pandas.DataFrame.
3339
3440
Args:
@@ -40,6 +46,9 @@ def get_scalars(self, runs_filter=None, tags_filter=None, pivot=None):
4046
`pivot_data()` method to a “wide” format wherein the tags of a
4147
given run and a given step are all collected in a single row.
4248
If not provided, defaults to `True`.
49+
include_wall_time: Include wall_time (timestamps in nanoseconds since
50+
the epoch in float64) as a column in the returned DataFrame.
51+
If not provided, defaults to `False`.
4352
4453
Returns:
4554
If `pivot` (default):

tensorboard/data/experimental/experiment_from_dev.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,13 @@ def __init__(self, experiment_id, api_endpoint=None):
7070
self._experiment_id = experiment_id
7171
self._api_client = get_api_client(api_endpoint=api_endpoint)
7272

73-
def get_scalars(self, runs_filter=None, tags_filter=None, pivot=None):
73+
def get_scalars(
74+
self,
75+
runs_filter=None,
76+
tags_filter=None,
77+
pivot=True,
78+
include_wall_time=False,
79+
):
7480
if runs_filter is not None:
7581
raise NotImplementedError(
7682
"runs_filter support for get_scalars() is not implemented yet."
@@ -79,7 +85,6 @@ def get_scalars(self, runs_filter=None, tags_filter=None, pivot=None):
7985
raise NotImplementedError(
8086
"tags_filter support for get_scalars() is not implemented yet."
8187
)
82-
pivot = True if pivot is None else pivot
8388

8489
request = export_service_pb2.StreamExperimentDataRequest()
8590
request.experiment_id = self._experiment_id
@@ -107,33 +112,48 @@ def get_scalars(self, runs_filter=None, tags_filter=None, pivot=None):
107112
)
108113
values.extend(list(response.points.values))
109114

110-
dataframe = pandas.DataFrame(
111-
{
112-
"run": runs,
113-
"tag": tags,
114-
"step": steps,
115-
"wall_time": wall_times,
116-
"value": values,
117-
}
118-
)
115+
data = {
116+
"run": runs,
117+
"tag": tags,
118+
"step": steps,
119+
"value": values,
120+
}
121+
if include_wall_time:
122+
data["wall_time"] = wall_times
123+
dataframe = pandas.DataFrame(data)
119124
if pivot:
120125
dataframe = self._pivot_dataframe(dataframe)
121126
return dataframe
122127

123128
def _pivot_dataframe(self, dataframe):
124129
num_missing_0 = np.count_nonzero(dataframe.isnull().values)
125130
dataframe = dataframe.pivot_table(
126-
["value", "wall_time"], ["run", "step"], "tag",
131+
values=(
132+
["value", "wall_time"]
133+
if "wall_time" in dataframe.columns
134+
else "value"
135+
),
136+
index=["run", "step"],
137+
columns="tag",
138+
dropna=False,
127139
)
128140
num_missing_1 = np.count_nonzero(dataframe.isnull().values)
129141
if num_missing_1 > num_missing_0:
130142
raise ValueError(
131-
"pivoted DataFrame contains %d missing value(s). "
143+
"pivoted DataFrame contains missing value(s). "
132144
"This is likely due to two timeseries having different "
133145
"sets of steps in your experiment. "
134146
"You can avoid this error by calling `get_scalars()` with "
135147
"`pivot=False` to disable the DataFrame pivoting."
136148
)
149+
# `reset_index()` removes the MultiIndex structure of the pivoted
150+
# DataFrame. Before the call, the DataFrame consists of two levels
151+
# of index: "run" and "step". After the call, the index becomes a
152+
# single range index (e.g., `dataframe[:2]` works).
153+
dataframe = dataframe.reset_index()
154+
# Remove the columns name "tag".
155+
dataframe.columns.name = None
156+
dataframe.columns.names = [None for name in dataframe.columns.names]
137157
return dataframe
138158

139159

tensorboard/data/experimental/experiment_from_dev_test.py

Lines changed: 90 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -79,45 +79,94 @@ def stream_experiment_data(request, **kwargs):
7979
lambda api_endpoint: mock_api_client,
8080
):
8181
experiment = experiment_from_dev.ExperimentFromDev("789")
82-
for pivot in (None, False):
83-
with self.subTest("pivot=%s" % pivot):
84-
dataframe = experiment.get_scalars(pivot=pivot)
85-
86-
expected = pandas.DataFrame(
87-
{
88-
"run": ["train"] * 20 + ["test"] * 20,
89-
"tag": (["accuracy"] * 10 + ["loss"] * 10) * 2,
90-
"step": list(np.arange(0, 10)) * 4,
91-
"wall_time": np.concatenate(
92-
[
93-
2.0 * np.arange(0, 10),
94-
1.0 * np.arange(0, 10),
95-
600.0 + 2.0 * np.arange(0, 10),
96-
600.0 + np.arange(0, 10),
97-
]
98-
),
99-
"value": np.concatenate(
100-
[
101-
1.0 / (10.0 - np.arange(0, 10)),
102-
1.0 / (1.0 + np.arange(0, 10)),
103-
-1.0 / (10.0 - np.arange(0, 10)),
104-
-1.0 / (1.0 + np.arange(0, 10)),
105-
]
106-
),
107-
}
108-
)
109-
110-
if pivot is None: # Default behavior: pivot_table.
111-
pandas.testing.assert_frame_equal(
112-
dataframe,
113-
expected.pivot_table(
114-
["value", "wall_time"], ["run", "step"], "tag"
115-
),
116-
check_names=True,
82+
for pivot in (True, False):
83+
for include_wall_time in (False, True):
84+
with self.subTest(
85+
"pivot=%s; include_wall_time=%s"
86+
% (pivot, include_wall_time)
87+
):
88+
dataframe = experiment.get_scalars(
89+
pivot=pivot, include_wall_time=include_wall_time
11790
)
118-
else: # pivot == False
91+
92+
if pivot:
93+
run_key = (
94+
("run", "") if include_wall_time else "run"
95+
)
96+
step_key = (
97+
("step", "") if include_wall_time else "step"
98+
)
99+
accuracy_value_key = (
100+
("value", "accuracy")
101+
if include_wall_time
102+
else "accuracy"
103+
)
104+
loss_value_key = (
105+
("value", "loss")
106+
if include_wall_time
107+
else "loss"
108+
)
109+
data = {
110+
run_key: ["test"] * 10 + ["train"] * 10,
111+
step_key: np.concatenate(
112+
[np.arange(0, 10), np.arange(0, 10)]
113+
),
114+
accuracy_value_key: np.concatenate(
115+
[
116+
-1.0 / (10.0 - np.arange(0, 10)),
117+
1.0 / (10.0 - np.arange(0, 10)),
118+
],
119+
),
120+
loss_value_key: np.concatenate(
121+
[
122+
-1.0 / (1.0 + np.arange(0, 10)),
123+
1.0 / (1.0 + np.arange(0, 10)),
124+
],
125+
),
126+
}
127+
if include_wall_time:
128+
data[
129+
("wall_time", "accuracy")
130+
] = np.concatenate(
131+
[
132+
600.0 + 2.0 * np.arange(0, 10),
133+
2.0 * np.arange(0, 10),
134+
]
135+
)
136+
data[("wall_time", "loss")] = np.concatenate(
137+
[
138+
600.0 + np.arange(0, 10),
139+
1.0 * np.arange(0, 10),
140+
]
141+
)
142+
expected = pandas.DataFrame(data)
143+
else: # No pivot_table.
144+
data = {
145+
"run": ["train"] * 20 + ["test"] * 20,
146+
"tag": (["accuracy"] * 10 + ["loss"] * 10) * 2,
147+
"step": list(np.arange(0, 10)) * 4,
148+
"value": np.concatenate(
149+
[
150+
1.0 / (10.0 - np.arange(0, 10)),
151+
1.0 / (1.0 + np.arange(0, 10)),
152+
-1.0 / (10.0 - np.arange(0, 10)),
153+
-1.0 / (1.0 + np.arange(0, 10)),
154+
]
155+
),
156+
}
157+
if include_wall_time:
158+
data["wall_time"] = np.concatenate(
159+
[
160+
2.0 * np.arange(0, 10),
161+
1.0 * np.arange(0, 10),
162+
600.0 + 2.0 * np.arange(0, 10),
163+
600.0 + np.arange(0, 10),
164+
]
165+
)
166+
expected = pandas.DataFrame(data)
167+
119168
pandas.testing.assert_frame_equal(
120-
dataframe, expected, check_names=True
169+
dataframe, expected, check_names=True,
121170
)
122171

123172
def test_get_scalars_with_pivot_table_with_missing_value(self):
@@ -156,7 +205,8 @@ def stream_experiment_data(request, **kwargs):
156205
experiment = experiment_from_dev.ExperimentFromDev("789")
157206
with self.assertRaisesRegexp(
158207
ValueError,
159-
r"missing value\(s\).*different sets of steps.*pivot=False",
208+
r"contains missing value\(s\).*different sets of "
209+
r"steps.*pivot=False",
160210
):
161211
experiment.get_scalars()
162212

@@ -193,12 +243,10 @@ def stream_experiment_data(request, **kwargs):
193243
expected = pandas.DataFrame(
194244
{
195245
"run": ["train"] * 2,
196-
"tag": ["batch_loss"] * 2,
197246
"step": [0, 1],
198-
"value": [np.nan, np.inf],
199-
"wall_time": [0.0, 10.0],
247+
"batch_loss": [np.nan, np.inf],
200248
}
201-
).pivot_table(["value", "wall_time"], ["run", "step"], "tag")
249+
)
202250
pandas.testing.assert_frame_equal(dataframe, expected, check_names=True)
203251

204252

tensorboard/data/experimental/test_binary.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,19 @@ def parse_args():
3737
default=None,
3838
help="Optional API endpoint used to override the default",
3939
)
40+
parser.add_argument(
41+
"--include_wall_time",
42+
action="store_true",
43+
help="Include wall_time column(s) in the DataFrame",
44+
)
4045
return parser.parse_args()
4146

4247

4348
def main(args):
4449
experiment = experiment_from_dev.ExperimentFromDev(
4550
args.experiment_id, api_endpoint=args.api_endpoint
4651
)
47-
dataframe = experiment.get_scalars()
52+
dataframe = experiment.get_scalars(include_wall_time=args.include_wall_time)
4853
print(dataframe)
4954

5055

0 commit comments

Comments
 (0)