10
10
import datetime
11
11
import os
12
12
from contextlib import contextmanager
13
- from typing import Generator , List , Mapping , Optional
13
+ from typing import cast , Generator , List , Mapping , Optional
14
14
from unittest .mock import MagicMock , patch
15
15
16
16
from torchx .runner import get_runner , Runner
17
+ from torchx .schedulers import SchedulerFactory
17
18
from torchx .schedulers .api import DescribeAppResponse , ListAppResponse , Scheduler
18
19
from torchx .schedulers .local_scheduler import (
20
+ create_scheduler ,
19
21
LocalDirectoryImageProvider ,
20
- LocalScheduler ,
21
22
)
22
23
from torchx .specs import AppDryRunInfo , CfgVal
23
24
from torchx .specs .api import (
@@ -64,7 +65,7 @@ def setUp(self) -> None:
64
65
def get_runner (self ) -> Generator [Runner , None , None ]:
65
66
with Runner (
66
67
SESSION_NAME ,
67
- scheduler_factories = {"local_dir" : LocalScheduler },
68
+ scheduler_factories = {"local_dir" : cast ( SchedulerFactory , create_scheduler ) },
68
69
scheduler_params = {
69
70
"image_provider_class" : LocalDirectoryImageProvider ,
70
71
},
@@ -79,14 +80,14 @@ def test_validate_no_roles(self, _) -> None:
79
80
80
81
def test_validate_no_resource (self , _ ) -> None :
81
82
with self .get_runner () as runner :
83
+ role = Role (
84
+ "no resource" ,
85
+ image = "no_image" ,
86
+ entrypoint = "echo" ,
87
+ args = ["hello_world" ],
88
+ )
89
+ app = AppDef ("no resource" , roles = [role ])
82
90
with self .assertRaises (ValueError ):
83
- role = Role (
84
- "no resource" ,
85
- image = "no_image" ,
86
- entrypoint = "echo" ,
87
- args = ["hello_world" ],
88
- )
89
- app = AppDef ("no resource" , roles = [role ])
90
91
runner .run (app , scheduler = "local_dir" )
91
92
92
93
def test_validate_invalid_replicas (self , _ ) -> None :
@@ -129,7 +130,7 @@ def test_dryrun(self, _) -> None:
129
130
}
130
131
with Runner (
131
132
name = SESSION_NAME ,
132
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
133
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
133
134
) as runner :
134
135
role = Role (
135
136
name = "touch" ,
@@ -149,7 +150,7 @@ def test_dryrun_env_variables(self, _) -> None:
149
150
scheduler_mock = MagicMock ()
150
151
with Runner (
151
152
name = SESSION_NAME ,
152
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
153
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
153
154
) as runner :
154
155
role1 = Role (
155
156
name = "echo1" ,
@@ -178,7 +179,7 @@ def test_dryrun_trackers_parent_run_id_as_paramenter(self, _) -> None:
178
179
expected_parent_run_id = "123"
179
180
with Runner (
180
181
name = SESSION_NAME ,
181
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
182
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
182
183
) as runner :
183
184
role1 = Role (
184
185
name = "echo1" ,
@@ -217,7 +218,7 @@ def test_dryrun_setup_trackers(self, config_trackers_mock: MagicMock, _) -> None
217
218
218
219
with Runner (
219
220
name = SESSION_NAME ,
220
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
221
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
221
222
) as runner :
222
223
role1 = Role (
223
224
name = "echo1" ,
@@ -265,7 +266,7 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:
265
266
266
267
with Runner (
267
268
name = SESSION_NAME ,
268
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
269
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
269
270
) as runner :
270
271
role1 = Role (
271
272
name = "echo1" ,
@@ -371,7 +372,7 @@ def test_describe(self, _) -> None:
371
372
name = "sleep" ,
372
373
image = str (self .tmpdir ),
373
374
resource = resource .SMALL ,
374
- entrypoint = "sleep.sh " ,
375
+ entrypoint = "sleep" ,
375
376
args = ["60" ],
376
377
)
377
378
app = AppDef ("sleeper" , roles = [role ])
@@ -387,7 +388,7 @@ def test_status(self, _) -> None:
387
388
name = "sleep" ,
388
389
image = str (self .tmpdir ),
389
390
resource = resource .SMALL ,
390
- entrypoint = "sleep.sh " ,
391
+ entrypoint = "sleep" ,
391
392
args = ["60" ],
392
393
)
393
394
app = AppDef ("sleeper" , roles = [role ])
@@ -414,7 +415,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:
414
415
415
416
with Runner (
416
417
name = "test_ui_url_session" ,
417
- scheduler_factories = {"local_dir" : lambda name : mock_scheduler },
418
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : mock_scheduler },
418
419
) as runner :
419
420
role = Role (
420
421
"ignored" ,
@@ -438,7 +439,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:
438
439
439
440
with Runner (
440
441
name = "test_structured_msg" ,
441
- scheduler_factories = {"local_dir" : lambda name : mock_scheduler },
442
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : mock_scheduler },
442
443
) as runner :
443
444
role = Role (
444
445
"ignored" ,
@@ -485,7 +486,7 @@ def test_log_lines(self, _) -> None:
485
486
486
487
with Runner (
487
488
name = SESSION_NAME ,
488
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
489
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
489
490
) as runner :
490
491
role_name = "trainer"
491
492
replica_id = 2
@@ -605,7 +606,7 @@ def test_runner_context_manager_with_error(self, _) -> None:
605
606
return_value = {"local_dir" : lambda name : mock_scheduler },
606
607
):
607
608
with self .assertRaisesRegex (RuntimeError , "foobar" ):
608
- with get_runner () as runner :
609
+ with get_runner ():
609
610
raise RuntimeError ("foobar" )
610
611
611
612
def test_runner_manual_close (self , _ ) -> None :
0 commit comments