Skip to content

Commit 68c140f

Browse files
yikaiMeta authored and facebook-github-bot committed
Create session id and use it in table 'Pytorch Elastic Tsm Log' (#953)
Summary: Pull Request resolved: #953. Please read the doc to understand why we create the session id: https://docs.google.com/document/d/1WJBrqSHrNIc9J1W_1PMIQPu11y2fV_hU36aeiBTgN90/edit. Differential Revision: D62087199.
1 parent 66733b7 commit 68c140f

File tree

5 files changed

+85
-10
lines changed

5 files changed

+85
-10
lines changed

torchx/runner/events/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from typing import Dict, Optional, Type
2828

2929
from torchx.runner.events.handlers import get_logging_handler
30+
from torchx.util.session import get_session_id_or_create_new
3031

3132
from .api import SourceType, TorchxEvent # noqa F401
3233

@@ -136,7 +137,7 @@ def _generate_torchx_event(
136137
workspace: Optional[str] = None,
137138
) -> TorchxEvent:
138139
return TorchxEvent(
139-
session=app_id or "",
140+
session=get_session_id_or_create_new(),
140141
scheduler=scheduler,
141142
api=api,
142143
app_id=app_id,

torchx/runner/events/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class TorchxEvent:
2525
The class represents the event produced by ``torchx.runner`` api calls.
2626
2727
Arguments:
28-
session: Session id that was used to execute request.
28+
session: Session id of the current run
2929
scheduler: Scheduler that is used to execute request
3030
api: Api name
3131
app_id: Unique id that is set by the underlying scheduler

torchx/runner/events/test/lib_test.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
TorchxEvent,
2020
)
2121

22+
SESSION_ID = "123"
23+
2224

2325
class TorchxEventLibTest(unittest.TestCase):
2426
def assert_event(
@@ -44,14 +46,14 @@ def test_get_or_create_logger(self, logging_handler_mock: MagicMock) -> None:
4446
def test_event_created(self) -> None:
4547
test_metadata = {"test_key": "test_value"}
4648
event = TorchxEvent(
47-
session="test_session",
49+
session=SESSION_ID,
4850
scheduler="test_scheduler",
4951
api="test_api",
5052
app_image="test_app_image",
5153
app_metadata=test_metadata,
5254
workspace="test_workspace",
5355
)
54-
self.assertEqual("test_session", event.session)
56+
self.assertEqual(SESSION_ID, event.session)
5557
self.assertEqual("test_scheduler", event.scheduler)
5658
self.assertEqual("test_api", event.api)
5759
self.assertEqual("test_app_image", event.app_image)
@@ -76,6 +78,7 @@ def test_event_deser(self) -> None:
7678

7779

7880
@patch("torchx.runner.events.record")
81+
@patch("torchx.runner.events.get_session_id_or_create_new")
7982
class LogEventTest(unittest.TestCase):
8083
def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> None:
8184
self.assertEqual(expected.session, actual.session)
@@ -86,7 +89,10 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non
8689
self.assertEqual(expected.workspace, actual.workspace)
8790
self.assertEqual(expected.app_metadata, actual.app_metadata)
8891

89-
def test_create_context(self, _) -> None:
92+
def test_create_context(
93+
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
94+
) -> None:
95+
get_session_id_or_create_new_mock.return_value = SESSION_ID
9096
test_dict = {"test_key": "test_value"}
9197
cfg = json.dumps(test_dict)
9298
context = log_event(
@@ -99,7 +105,7 @@ def test_create_context(self, _) -> None:
99105
workspace="test_workspace",
100106
)
101107
expected_torchx_event = TorchxEvent(
102-
"test_app_id",
108+
SESSION_ID,
103109
"local",
104110
"test_call",
105111
"test_app_id",
@@ -111,7 +117,10 @@ def test_create_context(self, _) -> None:
111117

112118
self.assert_torchx_event(expected_torchx_event, context._torchx_event)
113119

114-
def test_record_event(self, record_mock: MagicMock) -> None:
120+
def test_record_event(
121+
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
122+
) -> None:
123+
get_session_id_or_create_new_mock.return_value = SESSION_ID
115124
test_dict = {"test_key": "test_value"}
116125
cfg = json.dumps(test_dict)
117126
with log_event(
@@ -126,7 +135,7 @@ def test_record_event(self, record_mock: MagicMock) -> None:
126135
pass
127136

128137
expected_torchx_event = TorchxEvent(
129-
"test_app_id",
138+
SESSION_ID,
130139
"local",
131140
"test_call",
132141
"test_app_id",
@@ -139,7 +148,9 @@ def test_record_event(self, record_mock: MagicMock) -> None:
139148
)
140149
self.assert_torchx_event(expected_torchx_event, ctx._torchx_event)
141150

142-
def test_record_event_with_exception(self, record_mock: MagicMock) -> None:
151+
def test_record_event_with_exception(
152+
self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock
153+
) -> None:
143154
cfg = json.dumps({"test_key": "test_value"})
144155
with self.assertRaises(RuntimeError):
145156
with log_event("test_call", "local", "test_app_id", cfg) as ctx:

torchx/runner/test/api_test.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torchx.specs.finder import ComponentNotFoundException
3333
from torchx.test.fixtures import TestWithTmpDir
3434
from torchx.tracker.api import ENV_TORCHX_JOB_ID, ENV_TORCHX_PARENT_RUN_ID
35+
from torchx.util.session import get_session_id
3536
from torchx.util.types import none_throws
3637
from torchx.workspace import WorkspaceMixin
3738

@@ -51,7 +52,7 @@ def get_full_path(name: str) -> str:
5152
return os.path.join(os.path.dirname(__file__), "resource", name)
5253

5354

54-
@patch("torchx.runner.api.log_event")
55+
@patch("torchx.runner.events.record")
5556
class RunnerTest(TestWithTmpDir):
5657
def setUp(self) -> None:
5758
super().setUp()
@@ -61,6 +62,13 @@ def setUp(self) -> None:
6162

6263
self.cfg = {}
6364

65+
self.uuid4_mock = patch("uuid.uuid4").start()
66+
self.uuid4_mock.return_value = "test_session_id"
67+
68+
def tearDown(self) -> None:
    """Undo the ``uuid.uuid4`` patch started in ``setUp``, then run base cleanup."""
    # ``patch(...).start()`` returns the mock object rather than the patcher,
    # so calling ``self.uuid4_mock.stop()`` only invokes an auto-created
    # MagicMock attribute and leaves ``uuid.uuid4`` patched for later tests.
    # ``patch.stopall()`` stops every patch that was activated via ``start()``.
    patch.stopall()
    super().tearDown()
71+
6472
@contextmanager
6573
def get_runner(self) -> Generator[Runner, None, None]:
6674
with Runner(
@@ -104,6 +112,33 @@ def test_validate_invalid_replicas(self, _) -> None:
104112
with self.assertRaises(ValueError):
105113
runner.run(app, scheduler="local_dir")
106114

115+
def test_session_id(self, record_mock: MagicMock) -> None:
    """Check that two consecutive runs report the same session id.

    Runs the same app twice through the ``local_dir`` scheduler and then
    asserts that ``get_session_id()`` returns the mocked uuid value, that
    ``uuid.uuid4`` was invoked exactly once, and that every event handed to
    the patched ``record`` carries that session id.
    """
    test_file = self.tmpdir / "test_file"

    with self.get_runner() as runner:
        self.assertEqual(1, len(runner.scheduler_backends()))
        role = Role(
            name="touch",
            image=str(self.tmpdir),
            resource=resource.SMALL,
            entrypoint="touch.sh",
            args=[str(test_file)],
        )
        app = AppDef("name", roles=[role])

        app_handle_1 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
        none_throws(runner.wait(app_handle_1, wait_interval=0.1))

        app_handle_2 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
        none_throws(runner.wait(app_handle_2, wait_interval=0.1))

        self.assertEqual(get_session_id(), "test_session_id")
        # A single uuid4 call across both runs means the id was reused.
        self.uuid4_mock.assert_called_once()
        record_mock.assert_called()
        # Iterate the recorded calls directly instead of indexing by position.
        for recorded_call in record_mock.call_args_list:
            event = recorded_call.args[0]
            self.assertEqual(event.session, "test_session_id")
141+
107142
def test_run(self, _) -> None:
108143
test_file = self.tmpdir / "test_file"
109144

@@ -121,6 +156,7 @@ def test_run(self, _) -> None:
121156
app_handle = runner.run(app, scheduler="local_dir", cfg=self.cfg)
122157
app_status = none_throws(runner.wait(app_handle, wait_interval=0.1))
123158
self.assertEqual(AppState.SUCCEEDED, app_status.state)
159+
self.assertEqual(get_session_id(), "test_session_id")
124160

125161
def test_dryrun(self, _) -> None:
126162
scheduler_mock = MagicMock()
@@ -145,6 +181,7 @@ def test_dryrun(self, _) -> None:
145181
app, {**self.cfg, "foo": "bar"}
146182
)
147183
scheduler_mock._validate.assert_called_once()
184+
self.assertEqual(get_session_id(), "test_session_id")
148185

149186
def test_dryrun_env_variables(self, _) -> None:
150187
scheduler_mock = MagicMock()

torchx/util/session.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import uuid
11+
from typing import Optional
12+
13+
_CURRENT_SESSION_ID: Optional[str] = None
14+
15+
16+
def get_session_id_or_create_new() -> str:
    """Return the module-level session id, generating one on first use.

    If ``_CURRENT_SESSION_ID`` already holds a non-empty value it is returned
    unchanged. Otherwise a new id is built from ``uuid.uuid4()``, stored in
    ``_CURRENT_SESSION_ID`` and returned.
    """
    global _CURRENT_SESSION_ID
    current = _CURRENT_SESSION_ID
    if not current:
        # Nothing stored yet: mint an id and remember it for later calls.
        current = str(uuid.uuid4())
        _CURRENT_SESSION_ID = current
    return current
23+
24+
25+
def get_session_id() -> Optional[str]:
    """Return the stored session id without creating one.

    Reads the module-level ``_CURRENT_SESSION_ID``; the result is ``None`` if
    no id has been stored yet.
    """
    current: Optional[str] = _CURRENT_SESSION_ID
    return current

0 commit comments

Comments
 (0)