Skip to content

Commit 44b80ed

Browse files
yikaiMetafacebook-github-bot
authored andcommitted
Create session id and use it in table 'Pytorch Elastic Tsm Log' (#953)
Summary: Pull Request resolved: #953 Please read the doc to understand why we create the session id: https://docs.google.com/document/d/1WJBrqSHrNIc9J1W_1PMIQPu11y2fV_hU36aeiBTgN90/edit Reviewed By: andywag Differential Revision: D62087199
1 parent 66733b7 commit 44b80ed

File tree

5 files changed

+92
-12
lines changed

5 files changed

+92
-12
lines changed

torchx/runner/events/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from typing import Dict, Optional, Type
2828

2929
from torchx.runner.events.handlers import get_logging_handler
30+
from torchx.util.session import get_session_id_or_create_new
3031

3132
from .api import SourceType, TorchxEvent # noqa F401
3233

@@ -136,7 +137,7 @@ def _generate_torchx_event(
136137
workspace: Optional[str] = None,
137138
) -> TorchxEvent:
138139
return TorchxEvent(
139-
session=app_id or "",
140+
session=get_session_id_or_create_new(),
140141
scheduler=scheduler,
141142
api=api,
142143
app_id=app_id,

torchx/runner/events/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class TorchxEvent:
2525
The class represents the event produced by ``torchx.runner`` api calls.
2626
2727
Arguments:
28-
session: Session id that was used to execute request.
28+
session: Session id of the current run
2929
scheduler: Scheduler that is used to execute request
3030
api: Api name
3131
app_id: Unique id that is set by the underlying scheduler

torchx/runner/events/test/lib_test.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
TorchxEvent,
2020
)
2121

22+
SESSION_ID = "123"
23+
2224

2325
class TorchxEventLibTest(unittest.TestCase):
2426
def assert_event(
@@ -44,14 +46,14 @@ def test_get_or_create_logger(self, logging_handler_mock: MagicMock) -> None:
4446
def test_event_created(self) -> None:
4547
test_metadata = {"test_key": "test_value"}
4648
event = TorchxEvent(
47-
session="test_session",
49+
session=SESSION_ID,
4850
scheduler="test_scheduler",
4951
api="test_api",
5052
app_image="test_app_image",
5153
app_metadata=test_metadata,
5254
workspace="test_workspace",
5355
)
54-
self.assertEqual("test_session", event.session)
56+
self.assertEqual(SESSION_ID, event.session)
5557
self.assertEqual("test_scheduler", event.scheduler)
5658
self.assertEqual("test_api", event.api)
5759
self.assertEqual("test_app_image", event.app_image)
@@ -76,6 +78,7 @@ def test_event_deser(self) -> None:
7678

7779

7880
@patch("torchx.runner.events.record")
81+
@patch("torchx.runner.events.get_session_id_or_create_new")
7982
class LogEventTest(unittest.TestCase):
8083
def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> None:
8184
self.assertEqual(expected.session, actual.session)
@@ -86,7 +89,10 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non
8689
self.assertEqual(expected.workspace, actual.workspace)
8790
self.assertEqual(expected.app_metadata, actual.app_metadata)
8891

89-
def test_create_context(self, _) -> None:
92+
def test_create_context(
93+
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
94+
) -> None:
95+
get_session_id_or_create_new_mock.return_value = SESSION_ID
9096
test_dict = {"test_key": "test_value"}
9197
cfg = json.dumps(test_dict)
9298
context = log_event(
@@ -99,7 +105,7 @@ def test_create_context(self, _) -> None:
99105
workspace="test_workspace",
100106
)
101107
expected_torchx_event = TorchxEvent(
102-
"test_app_id",
108+
SESSION_ID,
103109
"local",
104110
"test_call",
105111
"test_app_id",
@@ -111,7 +117,10 @@ def test_create_context(self, _) -> None:
111117

112118
self.assert_torchx_event(expected_torchx_event, context._torchx_event)
113119

114-
def test_record_event(self, record_mock: MagicMock) -> None:
120+
def test_record_event(
121+
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
122+
) -> None:
123+
get_session_id_or_create_new_mock.return_value = SESSION_ID
115124
test_dict = {"test_key": "test_value"}
116125
cfg = json.dumps(test_dict)
117126
with log_event(
@@ -126,7 +135,7 @@ def test_record_event(self, record_mock: MagicMock) -> None:
126135
pass
127136

128137
expected_torchx_event = TorchxEvent(
129-
"test_app_id",
138+
SESSION_ID,
130139
"local",
131140
"test_call",
132141
"test_app_id",
@@ -139,7 +148,9 @@ def test_record_event(self, record_mock: MagicMock) -> None:
139148
)
140149
self.assert_torchx_event(expected_torchx_event, ctx._torchx_event)
141150

142-
def test_record_event_with_exception(self, record_mock: MagicMock) -> None:
151+
def test_record_event_with_exception(
152+
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
153+
) -> None:
143154
cfg = json.dumps({"test_key": "test_value"})
144155
with self.assertRaises(RuntimeError):
145156
with log_event("test_call", "local", "test_app_id", cfg) as ctx:

torchx/runner/test/api_test.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_full_path(name: str) -> str:
5151
return os.path.join(os.path.dirname(__file__), "resource", name)
5252

5353

54-
@patch("torchx.runner.api.log_event")
54+
@patch("torchx.runner.events.record")
5555
class RunnerTest(TestWithTmpDir):
5656
def setUp(self) -> None:
5757
super().setUp()
@@ -104,7 +104,35 @@ def test_validate_invalid_replicas(self, _) -> None:
104104
with self.assertRaises(ValueError):
105105
runner.run(app, scheduler="local_dir")
106106

107-
def test_run(self, _) -> None:
107+
def test_session_id(self, record_mock: MagicMock) -> None:
108+
test_file = self.tmpdir / "test_file"
109+
110+
with self.get_runner() as runner:
111+
self.assertEqual(1, len(runner.scheduler_backends()))
112+
role = Role(
113+
name="touch",
114+
image=str(self.tmpdir),
115+
resource=resource.SMALL,
116+
entrypoint="touch.sh",
117+
args=[str(test_file)],
118+
)
119+
app = AppDef("name", roles=[role])
120+
121+
app_handle_1 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
122+
none_throws(runner.wait(app_handle_1, wait_interval=0.1))
123+
124+
app_handle_2 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
125+
none_throws(runner.wait(app_handle_2, wait_interval=0.1))
126+
127+
from torchx.util.session import CURRENT_SESSION_ID
128+
129+
self.assertIsNotNone(CURRENT_SESSION_ID)
130+
record_mock.assert_called()
131+
for i in range(record_mock.call_count):
132+
event = record_mock.call_args_list[i].args[0]
133+
self.assertEqual(event.session, CURRENT_SESSION_ID)
134+
135+
def test_run(self, record_mock: MagicMock) -> None:
108136
test_file = self.tmpdir / "test_file"
109137

110138
with self.get_runner() as runner:
@@ -121,8 +149,15 @@ def test_run(self, _) -> None:
121149
app_handle = runner.run(app, scheduler="local_dir", cfg=self.cfg)
122150
app_status = none_throws(runner.wait(app_handle, wait_interval=0.1))
123151
self.assertEqual(AppState.SUCCEEDED, app_status.state)
152+
from torchx.util.session import CURRENT_SESSION_ID
124153

125-
def test_dryrun(self, _) -> None:
154+
self.assertIsNotNone(CURRENT_SESSION_ID)
155+
record_mock.assert_called()
156+
for i in range(record_mock.call_count):
157+
event = record_mock.call_args_list[i].args[0]
158+
self.assertEqual(event.session, CURRENT_SESSION_ID)
159+
160+
def test_dryrun(self, record_mock: MagicMock) -> None:
126161
scheduler_mock = MagicMock()
127162
scheduler_mock.run_opts.return_value.resolve.return_value = {
128163
**self.cfg,
@@ -145,6 +180,13 @@ def test_dryrun(self, _) -> None:
145180
app, {**self.cfg, "foo": "bar"}
146181
)
147182
scheduler_mock._validate.assert_called_once()
183+
from torchx.util.session import CURRENT_SESSION_ID
184+
185+
self.assertIsNotNone(CURRENT_SESSION_ID)
186+
record_mock.assert_called()
187+
for i in range(record_mock.call_count):
188+
event = record_mock.call_args_list[i].args[0]
189+
self.assertEqual(event.session, CURRENT_SESSION_ID)
148190

149191
def test_dryrun_env_variables(self, _) -> None:
150192
scheduler_mock = MagicMock()

torchx/util/session.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import uuid
11+
from typing import Optional
12+
13+
CURRENT_SESSION_ID: Optional[str] = None
14+
15+
16+
def get_session_id_or_create_new() -> str:
17+
"""
18+
Returns the current session ID, or creates a new one if none exists.
19+
The session ID remains the same as long as it is in the same process.
20+
"""
21+
global CURRENT_SESSION_ID
22+
if CURRENT_SESSION_ID:
23+
return CURRENT_SESSION_ID
24+
session_id = str(uuid.uuid4())
25+
CURRENT_SESSION_ID = session_id
26+
return session_id

0 commit comments

Comments
 (0)