Skip to content

Commit 5603572

Browse files
manav-a and facebook-github-bot
authored and committed
Allow using torchx_ env vars to set scheduler params (#915)
Summary: Scheduler params can currently only be set through the programmatic API and not through environment variables, which is not useful for cases like scheduling on mast rc cluster. This diff now lets you set them through torchx_-prefixed env vars. Differential Revision: D57640022
1 parent 05ddf23 commit 5603572

File tree

3 files changed

+30
-23
lines changed

3 files changed

+30
-23
lines changed

torchx/runner/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,12 @@ def __init__(
9898
"""
9999
self._name: str = name
100100
self._scheduler_factories = scheduler_factories
101-
self._scheduler_params: Dict[str, object] = scheduler_params or {}
101+
self._scheduler_params: Dict[str, Any] = scheduler_params or {}
102+
for key, value in os.environ.items():
103+
lower_case_key = key.lower()
104+
if lower_case_key.startswith("torchx_"):
105+
self._scheduler_params[lower_case_key] = value
106+
102107
# pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type
103108
self._scheduler_instances: Dict[str, Scheduler] = {}
104109
self._apps: Dict[AppHandle, AppDef] = {}

torchx/runner/test/api_test.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
import datetime
1111
import os
1212
from contextlib import contextmanager
13-
from typing import Generator, List, Mapping, Optional
13+
from typing import cast, Generator, List, Mapping, Optional
1414
from unittest.mock import MagicMock, patch
1515

1616
from torchx.runner import get_runner, Runner
17+
from torchx.schedulers import SchedulerFactory
1718
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler
1819
from torchx.schedulers.local_scheduler import (
20+
create_scheduler,
1921
LocalDirectoryImageProvider,
20-
LocalScheduler,
2122
)
2223
from torchx.specs import AppDryRunInfo, CfgVal
2324
from torchx.specs.api import (
@@ -64,7 +65,7 @@ def setUp(self) -> None:
6465
def get_runner(self) -> Generator[Runner, None, None]:
6566
with Runner(
6667
SESSION_NAME,
67-
scheduler_factories={"local_dir": LocalScheduler},
68+
scheduler_factories={"local_dir": cast(SchedulerFactory, create_scheduler)},
6869
scheduler_params={
6970
"image_provider_class": LocalDirectoryImageProvider,
7071
},
@@ -79,14 +80,14 @@ def test_validate_no_roles(self, _) -> None:
7980

8081
def test_validate_no_resource(self, _) -> None:
8182
with self.get_runner() as runner:
83+
role = Role(
84+
"no resource",
85+
image="no_image",
86+
entrypoint="echo",
87+
args=["hello_world"],
88+
)
89+
app = AppDef("no resource", roles=[role])
8290
with self.assertRaises(ValueError):
83-
role = Role(
84-
"no resource",
85-
image="no_image",
86-
entrypoint="echo",
87-
args=["hello_world"],
88-
)
89-
app = AppDef("no resource", roles=[role])
9091
runner.run(app, scheduler="local_dir")
9192

9293
def test_validate_invalid_replicas(self, _) -> None:
@@ -129,7 +130,7 @@ def test_dryrun(self, _) -> None:
129130
}
130131
with Runner(
131132
name=SESSION_NAME,
132-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
133+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
133134
) as runner:
134135
role = Role(
135136
name="touch",
@@ -149,7 +150,7 @@ def test_dryrun_env_variables(self, _) -> None:
149150
scheduler_mock = MagicMock()
150151
with Runner(
151152
name=SESSION_NAME,
152-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
153+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
153154
) as runner:
154155
role1 = Role(
155156
name="echo1",
@@ -178,7 +179,7 @@ def test_dryrun_trackers_parent_run_id_as_paramenter(self, _) -> None:
178179
expected_parent_run_id = "123"
179180
with Runner(
180181
name=SESSION_NAME,
181-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
182+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
182183
) as runner:
183184
role1 = Role(
184185
name="echo1",
@@ -217,7 +218,7 @@ def test_dryrun_setup_trackers(self, config_trackers_mock: MagicMock, _) -> None
217218

218219
with Runner(
219220
name=SESSION_NAME,
220-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
221+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
221222
) as runner:
222223
role1 = Role(
223224
name="echo1",
@@ -265,7 +266,7 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:
265266

266267
with Runner(
267268
name=SESSION_NAME,
268-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
269+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
269270
) as runner:
270271
role1 = Role(
271272
name="echo1",
@@ -371,7 +372,7 @@ def test_describe(self, _) -> None:
371372
name="sleep",
372373
image=str(self.tmpdir),
373374
resource=resource.SMALL,
374-
entrypoint="sleep.sh",
375+
entrypoint="sleep",
375376
args=["60"],
376377
)
377378
app = AppDef("sleeper", roles=[role])
@@ -387,7 +388,7 @@ def test_status(self, _) -> None:
387388
name="sleep",
388389
image=str(self.tmpdir),
389390
resource=resource.SMALL,
390-
entrypoint="sleep.sh",
391+
entrypoint="sleep",
391392
args=["60"],
392393
)
393394
app = AppDef("sleeper", roles=[role])
@@ -414,7 +415,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:
414415

415416
with Runner(
416417
name="test_ui_url_session",
417-
scheduler_factories={"local_dir": lambda name: mock_scheduler},
418+
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
418419
) as runner:
419420
role = Role(
420421
"ignored",
@@ -438,7 +439,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:
438439

439440
with Runner(
440441
name="test_structured_msg",
441-
scheduler_factories={"local_dir": lambda name: mock_scheduler},
442+
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
442443
) as runner:
443444
role = Role(
444445
"ignored",
@@ -485,7 +486,7 @@ def test_log_lines(self, _) -> None:
485486

486487
with Runner(
487488
name=SESSION_NAME,
488-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
489+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
489490
) as runner:
490491
role_name = "trainer"
491492
replica_id = 2
@@ -605,7 +606,7 @@ def test_runner_context_manager_with_error(self, _) -> None:
605606
return_value={"local_dir": lambda name: mock_scheduler},
606607
):
607608
with self.assertRaisesRegex(RuntimeError, "foobar"):
608-
with get_runner() as runner:
609+
with get_runner():
609610
raise RuntimeError("foobar")
610611

611612
def test_runner_manual_close(self, _) -> None:

torchx/schedulers/local_scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,11 +1184,12 @@ def create_scheduler(
11841184
session_name: str,
11851185
cache_size: int = 100,
11861186
extra_paths: Optional[List[str]] = None,
1187+
image_provider_class: Callable[[LocalOpts], ImageProvider] = CWDImageProvider,
11871188
**kwargs: Any,
11881189
) -> LocalScheduler:
11891190
return LocalScheduler(
11901191
session_name=session_name,
1191-
image_provider_class=CWDImageProvider,
1192+
image_provider_class=image_provider_class,
11921193
cache_size=cache_size,
11931194
extra_paths=extra_paths,
11941195
)

0 commit comments

Comments
 (0)