diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 4641e3899..ea7968db2 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -98,7 +98,10 @@ def __init__( """ self._name: str = name self._scheduler_factories = scheduler_factories - self._scheduler_params: Dict[str, object] = scheduler_params or {} + self._scheduler_params: Dict[str, Any] = { + **(self._get_scheduler_params_from_env()), + **(scheduler_params or {}), + } # pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type self._scheduler_instances: Dict[str, Scheduler] = {} self._apps: Dict[AppHandle, AppDef] = {} @@ -106,6 +109,14 @@ def __init__( # component_name -> map of component_fn_param_name -> user-specified default val encoded as str self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {} + def _get_scheduler_params_from_env(self) -> Dict[str, str]: + scheduler_params = {} + for key, value in os.environ.items(): + lower_case_key = key.lower() + if lower_case_key.startswith("torchx_"): + scheduler_params[lower_case_key.strip("torchx_")] = value + return scheduler_params + def __enter__(self) -> "Runner": return self diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index d7099321a..155555afa 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -10,14 +10,15 @@ import datetime import os from contextlib import contextmanager -from typing import Generator, List, Mapping, Optional +from typing import cast, Generator, List, Mapping, Optional from unittest.mock import MagicMock, patch from torchx.runner import get_runner, Runner +from torchx.schedulers import SchedulerFactory from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler from torchx.schedulers.local_scheduler import ( + create_scheduler, LocalDirectoryImageProvider, - LocalScheduler, ) from torchx.specs import AppDryRunInfo, CfgVal from torchx.specs.api import ( @@ -64,7 +65,7 @@ def setUp(self) -> None: def get_runner(self) -> Generator[Runner, None, None]: with Runner( SESSION_NAME, - scheduler_factories={"local_dir": LocalScheduler}, + scheduler_factories={"local_dir": cast(SchedulerFactory, create_scheduler)}, scheduler_params={ "image_provider_class": LocalDirectoryImageProvider, }, @@ -79,14 +80,14 @@ def test_validate_no_roles(self, _) -> None: def test_validate_no_resource(self, _) -> None: with self.get_runner() as runner: + role = Role( + "no resource", + image="no_image", + entrypoint="echo", + args=["hello_world"], + ) + app = AppDef("no resource", roles=[role]) with self.assertRaises(ValueError): - role = Role( - "no resource", - image="no_image", - entrypoint="echo", - args=["hello_world"], - ) - app = AppDef("no resource", roles=[role]) runner.run(app, scheduler="local_dir") def test_validate_invalid_replicas(self, _) -> None: @@ -129,7 +130,7 @@ def test_dryrun(self, _) -> None: } with Runner( name=SESSION_NAME, - scheduler_factories={"local_dir": lambda name: scheduler_mock}, + scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock}, ) as runner: role = Role( name="touch", @@ -149,7 +150,7 @@ def test_dryrun_env_variables(self, _) -> None: scheduler_mock = MagicMock() with Runner( name=SESSION_NAME, - scheduler_factories={"local_dir": lambda name: scheduler_mock}, + scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock}, ) as runner: role1 = Role( name="echo1", @@ -178,7 +179,7 @@ def test_dryrun_trackers_parent_run_id_as_paramenter(self, _) -> None: expected_parent_run_id = "123" with Runner( name=SESSION_NAME, - scheduler_factories={"local_dir": lambda name: scheduler_mock}, + scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock}, ) as runner: role1 = Role( name="echo1", @@ -217,7 +218,7 @@ def test_dryrun_setup_trackers(self, config_trackers_mock: MagicMock, _) -> None with Runner( name=SESSION_NAME, - scheduler_factories={"local_dir": lambda name: scheduler_mock}, + scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock}, ) as runner: role1 = Role( name="echo1", @@ -265,7 +266,7 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None: with Runner( name=SESSION_NAME, - scheduler_factories={"local_dir": lambda name: scheduler_mock}, + scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock}, ) as runner: role1 = Role( name="echo1", @@ -333,8 +334,10 @@ def build_workspace_and_update_role( name=SESSION_NAME, # pyre-fixme[6]: scheduler factory type scheduler_factories={ - "no-build-img": lambda name: TestScheduler(build_new_img=False), - "builds-img": lambda name: TestScheduler(build_new_img=True), + "no-build-img": lambda name, **kwargs: TestScheduler( + build_new_img=False + ), + "builds-img": lambda name, **kwargs: TestScheduler(build_new_img=True), }, ) as runner: app = AppDef( @@ -371,7 +374,7 @@ def test_describe(self, _) -> None: name="sleep", image=str(self.tmpdir), resource=resource.SMALL, - entrypoint="sleep.sh", + entrypoint="sleep", args=["60"], ) app = AppDef("sleeper", roles=[role]) @@ -387,7 +390,7 @@ def test_status(self, _) -> None: name="sleep", image=str(self.tmpdir), resource=resource.SMALL, - entrypoint="sleep.sh", + entrypoint="sleep", args=["60"], ) app = AppDef("sleeper", roles=[role]) @@ -414,7 +417,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None: with Runner( name="test_ui_url_session", - scheduler_factories={"local_dir": lambda name: mock_scheduler}, + scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler}, ) as runner: role = Role( "ignored", @@ -438,7 +441,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None: with Runner( name="test_structured_msg", - scheduler_factories={"local_dir": lambda name: mock_scheduler}, + scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler}, ) as runner: role = Role( "ignored", @@ -485,7 +488,7 @@ def test_log_lines(self, _) -> None: with Runner( name=SESSION_NAME, - scheduler_factories={"local_dir": lambda name: scheduler_mock}, + scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock}, ) as runner: role_name = "trainer" replica_id = 2 @@ -529,7 +532,7 @@ def test_list(self, _) -> None: ] with Runner( name=SESSION_NAME, - scheduler_factories={"kubernetes": lambda name: scheduler_mock}, + scheduler_factories={"kubernetes": lambda name, **kwargs: scheduler_mock}, ) as runner: apps = runner.list("kubernetes") self.assertEqual(apps, apps_expected) @@ -541,8 +544,8 @@ def test_get_schedulers(self, json_dumps_mock: MagicMock, _) -> None: json_dumps_mock.return_value = "{}" local_sched_mock = MagicMock() scheduler_factories = { - "local_dir": lambda name: local_dir_sched_mock, - "local": lambda name: local_sched_mock, + "local_dir": lambda name, **kwargs: local_dir_sched_mock, + "local": lambda name, **kwargs: local_sched_mock, } with Runner( name="test_session", scheduler_factories=scheduler_factories @@ -576,8 +579,8 @@ def test_run_from_module(self, _: str) -> None: def test_run_from_file_no_function_found(self, _) -> None: local_sched_mock = MagicMock() schedulers = { - "local_dir": lambda name: local_sched_mock, - "local": lambda name: local_sched_mock, + "local_dir": lambda name, **kwargs: local_sched_mock, + "local": lambda name, **kwargs: local_sched_mock, } with Runner(name="test_session", scheduler_factories=schedulers) as runner: component_path = get_full_path("distributed.py") @@ -591,7 +594,7 @@ def test_runner_context_manager(self, _) -> None: mock_scheduler = MagicMock() with patch( GET_SCHEDULER_FACTORIES, - return_value={"local_dir": lambda name: mock_scheduler}, + return_value={"local_dir": lambda name, **kwargs: mock_scheduler}, ): with get_runner() as runner: # force schedulers to load @@ -602,17 +605,17 @@ def test_runner_context_manager_with_error(self, _) -> None: mock_scheduler = MagicMock() with patch( GET_SCHEDULER_FACTORIES, - return_value={"local_dir": lambda name: mock_scheduler}, + return_value={"local_dir": lambda name, **kwargs: mock_scheduler}, ): with self.assertRaisesRegex(RuntimeError, "foobar"): - with get_runner() as runner: + with get_runner(): raise RuntimeError("foobar") def test_runner_manual_close(self, _) -> None: mock_scheduler = MagicMock() with patch( GET_SCHEDULER_FACTORIES, - return_value={"local_dir": lambda name: mock_scheduler}, + return_value={"local_dir": lambda name, **kwargs: mock_scheduler}, ): runner = get_runner() # force schedulers to load diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index 2fa07d362..9390902c4 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -1184,11 +1184,12 @@ def create_scheduler( session_name: str, cache_size: int = 100, extra_paths: Optional[List[str]] = None, + image_provider_class: Callable[[LocalOpts], ImageProvider] = CWDImageProvider, **kwargs: Any, ) -> LocalScheduler: return LocalScheduler( session_name=session_name, - image_provider_class=CWDImageProvider, + image_provider_class=image_provider_class, cache_size=cache_size, extra_paths=extra_paths, )