Skip to content

Commit d3393fc

Browse files
authored
Allow using torchx_ env vars to set scheduler params (#915) (#915)
Summary: Scheduler params currently can only be set through the programmatic API and not through this is not useful for cases like scheduling on mast rc cluster. This diff now lets you do that. Reviewed By: andywag Differential Revision: D57640022
1 parent 05ddf23 commit d3393fc

File tree

3 files changed

+48
-33
lines changed

3 files changed

+48
-33
lines changed

torchx/runner/api.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,25 @@ def __init__(
9898
"""
9999
self._name: str = name
100100
self._scheduler_factories = scheduler_factories
101-
self._scheduler_params: Dict[str, object] = scheduler_params or {}
101+
self._scheduler_params: Dict[str, Any] = {
102+
**(self._get_scheduler_params_from_env()),
103+
**(scheduler_params or {}),
104+
}
102105
# pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type
103106
self._scheduler_instances: Dict[str, Scheduler] = {}
104107
self._apps: Dict[AppHandle, AppDef] = {}
105108

106109
# component_name -> map of component_fn_param_name -> user-specified default val encoded as str
107110
self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {}
108111

112+
def _get_scheduler_params_from_env(self) -> Dict[str, str]:
113+
scheduler_params = {}
114+
for key, value in os.environ.items():
115+
lower_case_key = key.lower()
116+
if lower_case_key.startswith("torchx_"):
117+
scheduler_params[lower_case_key.strip("torchx_")] = value
118+
return scheduler_params
119+
109120
def __enter__(self) -> "Runner":
110121
return self
111122

torchx/runner/test/api_test.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
import datetime
1111
import os
1212
from contextlib import contextmanager
13-
from typing import Generator, List, Mapping, Optional
13+
from typing import cast, Generator, List, Mapping, Optional
1414
from unittest.mock import MagicMock, patch
1515

1616
from torchx.runner import get_runner, Runner
17+
from torchx.schedulers import SchedulerFactory
1718
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler
1819
from torchx.schedulers.local_scheduler import (
20+
create_scheduler,
1921
LocalDirectoryImageProvider,
20-
LocalScheduler,
2122
)
2223
from torchx.specs import AppDryRunInfo, CfgVal
2324
from torchx.specs.api import (
@@ -64,7 +65,7 @@ def setUp(self) -> None:
6465
def get_runner(self) -> Generator[Runner, None, None]:
6566
with Runner(
6667
SESSION_NAME,
67-
scheduler_factories={"local_dir": LocalScheduler},
68+
scheduler_factories={"local_dir": cast(SchedulerFactory, create_scheduler)},
6869
scheduler_params={
6970
"image_provider_class": LocalDirectoryImageProvider,
7071
},
@@ -79,14 +80,14 @@ def test_validate_no_roles(self, _) -> None:
7980

8081
def test_validate_no_resource(self, _) -> None:
8182
with self.get_runner() as runner:
83+
role = Role(
84+
"no resource",
85+
image="no_image",
86+
entrypoint="echo",
87+
args=["hello_world"],
88+
)
89+
app = AppDef("no resource", roles=[role])
8290
with self.assertRaises(ValueError):
83-
role = Role(
84-
"no resource",
85-
image="no_image",
86-
entrypoint="echo",
87-
args=["hello_world"],
88-
)
89-
app = AppDef("no resource", roles=[role])
9091
runner.run(app, scheduler="local_dir")
9192

9293
def test_validate_invalid_replicas(self, _) -> None:
@@ -129,7 +130,7 @@ def test_dryrun(self, _) -> None:
129130
}
130131
with Runner(
131132
name=SESSION_NAME,
132-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
133+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
133134
) as runner:
134135
role = Role(
135136
name="touch",
@@ -149,7 +150,7 @@ def test_dryrun_env_variables(self, _) -> None:
149150
scheduler_mock = MagicMock()
150151
with Runner(
151152
name=SESSION_NAME,
152-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
153+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
153154
) as runner:
154155
role1 = Role(
155156
name="echo1",
@@ -178,7 +179,7 @@ def test_dryrun_trackers_parent_run_id_as_paramenter(self, _) -> None:
178179
expected_parent_run_id = "123"
179180
with Runner(
180181
name=SESSION_NAME,
181-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
182+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
182183
) as runner:
183184
role1 = Role(
184185
name="echo1",
@@ -217,7 +218,7 @@ def test_dryrun_setup_trackers(self, config_trackers_mock: MagicMock, _) -> None
217218

218219
with Runner(
219220
name=SESSION_NAME,
220-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
221+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
221222
) as runner:
222223
role1 = Role(
223224
name="echo1",
@@ -265,7 +266,7 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:
265266

266267
with Runner(
267268
name=SESSION_NAME,
268-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
269+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
269270
) as runner:
270271
role1 = Role(
271272
name="echo1",
@@ -333,8 +334,10 @@ def build_workspace_and_update_role(
333334
name=SESSION_NAME,
334335
# pyre-fixme[6]: scheduler factory type
335336
scheduler_factories={
336-
"no-build-img": lambda name: TestScheduler(build_new_img=False),
337-
"builds-img": lambda name: TestScheduler(build_new_img=True),
337+
"no-build-img": lambda name, **kwargs: TestScheduler(
338+
build_new_img=False
339+
),
340+
"builds-img": lambda name, **kwargs: TestScheduler(build_new_img=True),
338341
},
339342
) as runner:
340343
app = AppDef(
@@ -371,7 +374,7 @@ def test_describe(self, _) -> None:
371374
name="sleep",
372375
image=str(self.tmpdir),
373376
resource=resource.SMALL,
374-
entrypoint="sleep.sh",
377+
entrypoint="sleep",
375378
args=["60"],
376379
)
377380
app = AppDef("sleeper", roles=[role])
@@ -387,7 +390,7 @@ def test_status(self, _) -> None:
387390
name="sleep",
388391
image=str(self.tmpdir),
389392
resource=resource.SMALL,
390-
entrypoint="sleep.sh",
393+
entrypoint="sleep",
391394
args=["60"],
392395
)
393396
app = AppDef("sleeper", roles=[role])
@@ -414,7 +417,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:
414417

415418
with Runner(
416419
name="test_ui_url_session",
417-
scheduler_factories={"local_dir": lambda name: mock_scheduler},
420+
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
418421
) as runner:
419422
role = Role(
420423
"ignored",
@@ -438,7 +441,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:
438441

439442
with Runner(
440443
name="test_structured_msg",
441-
scheduler_factories={"local_dir": lambda name: mock_scheduler},
444+
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
442445
) as runner:
443446
role = Role(
444447
"ignored",
@@ -485,7 +488,7 @@ def test_log_lines(self, _) -> None:
485488

486489
with Runner(
487490
name=SESSION_NAME,
488-
scheduler_factories={"local_dir": lambda name: scheduler_mock},
491+
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
489492
) as runner:
490493
role_name = "trainer"
491494
replica_id = 2
@@ -529,7 +532,7 @@ def test_list(self, _) -> None:
529532
]
530533
with Runner(
531534
name=SESSION_NAME,
532-
scheduler_factories={"kubernetes": lambda name: scheduler_mock},
535+
scheduler_factories={"kubernetes": lambda name, **kwargs: scheduler_mock},
533536
) as runner:
534537
apps = runner.list("kubernetes")
535538
self.assertEqual(apps, apps_expected)
@@ -541,8 +544,8 @@ def test_get_schedulers(self, json_dumps_mock: MagicMock, _) -> None:
541544
json_dumps_mock.return_value = "{}"
542545
local_sched_mock = MagicMock()
543546
scheduler_factories = {
544-
"local_dir": lambda name: local_dir_sched_mock,
545-
"local": lambda name: local_sched_mock,
547+
"local_dir": lambda name, **kwargs: local_dir_sched_mock,
548+
"local": lambda name, **kwargs: local_sched_mock,
546549
}
547550
with Runner(
548551
name="test_session", scheduler_factories=scheduler_factories
@@ -576,8 +579,8 @@ def test_run_from_module(self, _: str) -> None:
576579
def test_run_from_file_no_function_found(self, _) -> None:
577580
local_sched_mock = MagicMock()
578581
schedulers = {
579-
"local_dir": lambda name: local_sched_mock,
580-
"local": lambda name: local_sched_mock,
582+
"local_dir": lambda name, **kwargs: local_sched_mock,
583+
"local": lambda name, **kwargs: local_sched_mock,
581584
}
582585
with Runner(name="test_session", scheduler_factories=schedulers) as runner:
583586
component_path = get_full_path("distributed.py")
@@ -591,7 +594,7 @@ def test_runner_context_manager(self, _) -> None:
591594
mock_scheduler = MagicMock()
592595
with patch(
593596
GET_SCHEDULER_FACTORIES,
594-
return_value={"local_dir": lambda name: mock_scheduler},
597+
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
595598
):
596599
with get_runner() as runner:
597600
# force schedulers to load
@@ -602,17 +605,17 @@ def test_runner_context_manager_with_error(self, _) -> None:
602605
mock_scheduler = MagicMock()
603606
with patch(
604607
GET_SCHEDULER_FACTORIES,
605-
return_value={"local_dir": lambda name: mock_scheduler},
608+
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
606609
):
607610
with self.assertRaisesRegex(RuntimeError, "foobar"):
608-
with get_runner() as runner:
611+
with get_runner():
609612
raise RuntimeError("foobar")
610613

611614
def test_runner_manual_close(self, _) -> None:
612615
mock_scheduler = MagicMock()
613616
with patch(
614617
GET_SCHEDULER_FACTORIES,
615-
return_value={"local_dir": lambda name: mock_scheduler},
618+
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
616619
):
617620
runner = get_runner()
618621
# force schedulers to load

torchx/schedulers/local_scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,11 +1184,12 @@ def create_scheduler(
11841184
session_name: str,
11851185
cache_size: int = 100,
11861186
extra_paths: Optional[List[str]] = None,
1187+
image_provider_class: Callable[[LocalOpts], ImageProvider] = CWDImageProvider,
11871188
**kwargs: Any,
11881189
) -> LocalScheduler:
11891190
return LocalScheduler(
11901191
session_name=session_name,
1191-
image_provider_class=CWDImageProvider,
1192+
image_provider_class=image_provider_class,
11921193
cache_size=cache_size,
11931194
extra_paths=extra_paths,
11941195
)

0 commit comments

Comments
 (0)