diff --git a/torchx/runner/events/__init__.py b/torchx/runner/events/__init__.py index 360cb3e7c..c8eb89d96 100644 --- a/torchx/runner/events/__init__.py +++ b/torchx/runner/events/__init__.py @@ -27,6 +27,7 @@ from typing import Dict, Optional, Type from torchx.runner.events.handlers import get_logging_handler +from torchx.util.session import get_session_id_or_create_new from .api import SourceType, TorchxEvent # noqa F401 @@ -136,7 +137,7 @@ def _generate_torchx_event( workspace: Optional[str] = None, ) -> TorchxEvent: return TorchxEvent( - session=app_id or "", + session=get_session_id_or_create_new(), scheduler=scheduler, api=api, app_id=app_id, diff --git a/torchx/runner/events/api.py b/torchx/runner/events/api.py index ce5bc8998..355c03f6c 100644 --- a/torchx/runner/events/api.py +++ b/torchx/runner/events/api.py @@ -25,7 +25,7 @@ class TorchxEvent: The class represents the event produced by ``torchx.runner`` api calls. Arguments: - session: Session id that was used to execute request. 
+ session: Session id of the current run scheduler: Scheduler that is used to execute request api: Api name app_id: Unique id that is set by the underlying scheduler diff --git a/torchx/runner/events/test/lib_test.py b/torchx/runner/events/test/lib_test.py index 92bb3c828..bbeed590e 100644 --- a/torchx/runner/events/test/lib_test.py +++ b/torchx/runner/events/test/lib_test.py @@ -19,6 +19,8 @@ TorchxEvent, ) +SESSION_ID = "123" + class TorchxEventLibTest(unittest.TestCase): def assert_event( @@ -44,14 +46,14 @@ def test_get_or_create_logger(self, logging_handler_mock: MagicMock) -> None: def test_event_created(self) -> None: test_metadata = {"test_key": "test_value"} event = TorchxEvent( - session="test_session", + session=SESSION_ID, scheduler="test_scheduler", api="test_api", app_image="test_app_image", app_metadata=test_metadata, workspace="test_workspace", ) - self.assertEqual("test_session", event.session) + self.assertEqual(SESSION_ID, event.session) self.assertEqual("test_scheduler", event.scheduler) self.assertEqual("test_api", event.api) self.assertEqual("test_app_image", event.app_image) @@ -76,6 +78,7 @@ def test_event_deser(self) -> None: @patch("torchx.runner.events.record") +@patch("torchx.runner.events.get_session_id_or_create_new") class LogEventTest(unittest.TestCase): def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> None: self.assertEqual(expected.session, actual.session) @@ -86,7 +89,10 @@ def assert_torchx_event(self, expected: TorchxEvent, actual: TorchxEvent) -> Non self.assertEqual(expected.workspace, actual.workspace) self.assertEqual(expected.app_metadata, actual.app_metadata) - def test_create_context(self, _) -> None: + def test_create_context( + self, get_session_id_or_create_new_mock: MagicMock, record_mock: MagicMock + ) -> None: + get_session_id_or_create_new_mock.return_value = SESSION_ID test_dict = {"test_key": "test_value"} cfg = json.dumps(test_dict) context = log_event( @@ -99,7 +105,7 @@ def 
def test_session_id(self, record_mock: MagicMock) -> None:
    """
    Runs the same app twice through one runner and verifies that every
    event handed to the patched ``torchx.runner.events.record`` carries the
    same session id, namely the one stored in ``torchx.util.session``.
    """
    test_file = self.tmpdir / "test_file"

    with self.get_runner() as runner:
        self.assertEqual(1, len(runner.scheduler_backends()))
        role = Role(
            name="touch",
            image=str(self.tmpdir),
            resource=resource.SMALL,
            entrypoint="touch.sh",
            args=[str(test_file)],
        )
        app = AppDef("name", roles=[role])

        # Two separate submissions from the same process; both are expected
        # to be tagged with a single shared session id.
        app_handle_1 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
        none_throws(runner.wait(app_handle_1, wait_interval=0.1))

        app_handle_2 = runner.run(app, scheduler="local_dir", cfg=self.cfg)
        none_throws(runner.wait(app_handle_2, wait_interval=0.1))

        # Imported only after the runs so the name is bound to the value
        # that was generated while the events were being produced.
        from torchx.util.session import CURRENT_SESSION_ID

        self.assertIsNotNone(CURRENT_SESSION_ID)
        record_mock.assert_called()
        # Collapse the sessions of all recorded events into a set: a single
        # comparison then reports every stray session id at once.
        recorded_sessions = {
            call.args[0].session for call in record_mock.call_args_list
        }
        self.assertEqual({CURRENT_SESSION_ID}, recorded_sessions)
# Process-wide cache for the session id; stays ``None`` until the first lookup.
CURRENT_SESSION_ID: Optional[str] = None


def get_session_id_or_create_new() -> str:
    """
    Fetches the session ID for this process, generating one on first use.

    When no ID has been cached yet, a random UUID4 string is generated and
    stored in ``CURRENT_SESSION_ID``, so later calls within the same process
    return that same value.
    """
    global CURRENT_SESSION_ID
    cached = CURRENT_SESSION_ID
    if not cached:
        # Nothing cached (or an empty value): mint a fresh id and remember it.
        cached = str(uuid.uuid4())
        CURRENT_SESSION_ID = cached
    return cached