Skip to content

Allow using torchx_ env vars to set scheduler params #915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,25 @@ def __init__(
"""
self._name: str = name
self._scheduler_factories = scheduler_factories
self._scheduler_params: Dict[str, object] = scheduler_params or {}
self._scheduler_params: Dict[str, Any] = {
**(self._get_scheduler_params_from_env()),
**(scheduler_params or {}),
}
# pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type
self._scheduler_instances: Dict[str, Scheduler] = {}
self._apps: Dict[AppHandle, AppDef] = {}

# component_name -> map of component_fn_param_name -> user-specified default val encoded as str
self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {}

def _get_scheduler_params_from_env(self) -> Dict[str, str]:
scheduler_params = {}
for key, value in os.environ.items():
lower_case_key = key.lower()
if lower_case_key.startswith("torchx_"):
scheduler_params[lower_case_key.strip("torchx_")] = value
return scheduler_params

def __enter__(self) -> "Runner":
return self

Expand Down
65 changes: 34 additions & 31 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
import datetime
import os
from contextlib import contextmanager
from typing import Generator, List, Mapping, Optional
from typing import cast, Generator, List, Mapping, Optional
from unittest.mock import MagicMock, patch

from torchx.runner import get_runner, Runner
from torchx.schedulers import SchedulerFactory
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler
from torchx.schedulers.local_scheduler import (
create_scheduler,
LocalDirectoryImageProvider,
LocalScheduler,
)
from torchx.specs import AppDryRunInfo, CfgVal
from torchx.specs.api import (
Expand Down Expand Up @@ -64,7 +65,7 @@ def setUp(self) -> None:
def get_runner(self) -> Generator[Runner, None, None]:
with Runner(
SESSION_NAME,
scheduler_factories={"local_dir": LocalScheduler},
scheduler_factories={"local_dir": cast(SchedulerFactory, create_scheduler)},
scheduler_params={
"image_provider_class": LocalDirectoryImageProvider,
},
Expand All @@ -79,14 +80,14 @@ def test_validate_no_roles(self, _) -> None:

def test_validate_no_resource(self, _) -> None:
with self.get_runner() as runner:
role = Role(
"no resource",
image="no_image",
entrypoint="echo",
args=["hello_world"],
)
app = AppDef("no resource", roles=[role])
with self.assertRaises(ValueError):
role = Role(
"no resource",
image="no_image",
entrypoint="echo",
args=["hello_world"],
)
app = AppDef("no resource", roles=[role])
runner.run(app, scheduler="local_dir")

def test_validate_invalid_replicas(self, _) -> None:
Expand Down Expand Up @@ -129,7 +130,7 @@ def test_dryrun(self, _) -> None:
}
with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role = Role(
name="touch",
Expand All @@ -149,7 +150,7 @@ def test_dryrun_env_variables(self, _) -> None:
scheduler_mock = MagicMock()
with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -178,7 +179,7 @@ def test_dryrun_trackers_parent_run_id_as_paramenter(self, _) -> None:
expected_parent_run_id = "123"
with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -217,7 +218,7 @@ def test_dryrun_setup_trackers(self, config_trackers_mock: MagicMock, _) -> None

with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -265,7 +266,7 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:

with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -333,8 +334,10 @@ def build_workspace_and_update_role(
name=SESSION_NAME,
# pyre-fixme[6]: scheduler factory type
scheduler_factories={
"no-build-img": lambda name: TestScheduler(build_new_img=False),
"builds-img": lambda name: TestScheduler(build_new_img=True),
"no-build-img": lambda name, **kwargs: TestScheduler(
build_new_img=False
),
"builds-img": lambda name, **kwargs: TestScheduler(build_new_img=True),
},
) as runner:
app = AppDef(
Expand Down Expand Up @@ -371,7 +374,7 @@ def test_describe(self, _) -> None:
name="sleep",
image=str(self.tmpdir),
resource=resource.SMALL,
entrypoint="sleep.sh",
entrypoint="sleep",
args=["60"],
)
app = AppDef("sleeper", roles=[role])
Expand All @@ -387,7 +390,7 @@ def test_status(self, _) -> None:
name="sleep",
image=str(self.tmpdir),
resource=resource.SMALL,
entrypoint="sleep.sh",
entrypoint="sleep",
args=["60"],
)
app = AppDef("sleeper", roles=[role])
Expand All @@ -414,7 +417,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:

with Runner(
name="test_ui_url_session",
scheduler_factories={"local_dir": lambda name: mock_scheduler},
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
) as runner:
role = Role(
"ignored",
Expand All @@ -438,7 +441,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:

with Runner(
name="test_structured_msg",
scheduler_factories={"local_dir": lambda name: mock_scheduler},
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
) as runner:
role = Role(
"ignored",
Expand Down Expand Up @@ -485,7 +488,7 @@ def test_log_lines(self, _) -> None:

with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role_name = "trainer"
replica_id = 2
Expand Down Expand Up @@ -529,7 +532,7 @@ def test_list(self, _) -> None:
]
with Runner(
name=SESSION_NAME,
scheduler_factories={"kubernetes": lambda name: scheduler_mock},
scheduler_factories={"kubernetes": lambda name, **kwargs: scheduler_mock},
) as runner:
apps = runner.list("kubernetes")
self.assertEqual(apps, apps_expected)
Expand All @@ -541,8 +544,8 @@ def test_get_schedulers(self, json_dumps_mock: MagicMock, _) -> None:
json_dumps_mock.return_value = "{}"
local_sched_mock = MagicMock()
scheduler_factories = {
"local_dir": lambda name: local_dir_sched_mock,
"local": lambda name: local_sched_mock,
"local_dir": lambda name, **kwargs: local_dir_sched_mock,
"local": lambda name, **kwargs: local_sched_mock,
}
with Runner(
name="test_session", scheduler_factories=scheduler_factories
Expand Down Expand Up @@ -576,8 +579,8 @@ def test_run_from_module(self, _: str) -> None:
def test_run_from_file_no_function_found(self, _) -> None:
local_sched_mock = MagicMock()
schedulers = {
"local_dir": lambda name: local_sched_mock,
"local": lambda name: local_sched_mock,
"local_dir": lambda name, **kwargs: local_sched_mock,
"local": lambda name, **kwargs: local_sched_mock,
}
with Runner(name="test_session", scheduler_factories=schedulers) as runner:
component_path = get_full_path("distributed.py")
Expand All @@ -591,7 +594,7 @@ def test_runner_context_manager(self, _) -> None:
mock_scheduler = MagicMock()
with patch(
GET_SCHEDULER_FACTORIES,
return_value={"local_dir": lambda name: mock_scheduler},
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
):
with get_runner() as runner:
# force schedulers to load
Expand All @@ -602,17 +605,17 @@ def test_runner_context_manager_with_error(self, _) -> None:
mock_scheduler = MagicMock()
with patch(
GET_SCHEDULER_FACTORIES,
return_value={"local_dir": lambda name: mock_scheduler},
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
):
with self.assertRaisesRegex(RuntimeError, "foobar"):
with get_runner() as runner:
with get_runner():
raise RuntimeError("foobar")

def test_runner_manual_close(self, _) -> None:
mock_scheduler = MagicMock()
with patch(
GET_SCHEDULER_FACTORIES,
return_value={"local_dir": lambda name: mock_scheduler},
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
):
runner = get_runner()
# force schedulers to load
Expand Down
3 changes: 2 additions & 1 deletion torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,11 +1184,12 @@ def create_scheduler(
session_name: str,
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
image_provider_class: Callable[[LocalOpts], ImageProvider] = CWDImageProvider,
**kwargs: Any,
) -> LocalScheduler:
return LocalScheduler(
session_name=session_name,
image_provider_class=CWDImageProvider,
image_provider_class=image_provider_class,
cache_size=cache_size,
extra_paths=extra_paths,
)
Loading