10
10
import datetime
11
11
import os
12
12
from contextlib import contextmanager
13
- from typing import Generator , List , Mapping , Optional
13
+ from typing import cast , Generator , List , Mapping , Optional
14
14
from unittest .mock import MagicMock , patch
15
15
16
16
from torchx .runner import get_runner , Runner
17
+ from torchx .schedulers import SchedulerFactory
17
18
from torchx .schedulers .api import DescribeAppResponse , ListAppResponse , Scheduler
18
19
from torchx .schedulers .local_scheduler import (
20
+ create_scheduler ,
19
21
LocalDirectoryImageProvider ,
20
- LocalScheduler ,
21
22
)
22
23
from torchx .specs import AppDryRunInfo , CfgVal
23
24
from torchx .specs .api import (
@@ -64,7 +65,7 @@ def setUp(self) -> None:
64
65
def get_runner (self ) -> Generator [Runner , None , None ]:
65
66
with Runner (
66
67
SESSION_NAME ,
67
- scheduler_factories = {"local_dir" : LocalScheduler },
68
+ scheduler_factories = {"local_dir" : cast ( SchedulerFactory , create_scheduler ) },
68
69
scheduler_params = {
69
70
"image_provider_class" : LocalDirectoryImageProvider ,
70
71
},
@@ -79,14 +80,14 @@ def test_validate_no_roles(self, _) -> None:
79
80
80
81
def test_validate_no_resource (self , _ ) -> None :
81
82
with self .get_runner () as runner :
83
+ role = Role (
84
+ "no resource" ,
85
+ image = "no_image" ,
86
+ entrypoint = "echo" ,
87
+ args = ["hello_world" ],
88
+ )
89
+ app = AppDef ("no resource" , roles = [role ])
82
90
with self .assertRaises (ValueError ):
83
- role = Role (
84
- "no resource" ,
85
- image = "no_image" ,
86
- entrypoint = "echo" ,
87
- args = ["hello_world" ],
88
- )
89
- app = AppDef ("no resource" , roles = [role ])
90
91
runner .run (app , scheduler = "local_dir" )
91
92
92
93
def test_validate_invalid_replicas (self , _ ) -> None :
@@ -129,7 +130,7 @@ def test_dryrun(self, _) -> None:
129
130
}
130
131
with Runner (
131
132
name = SESSION_NAME ,
132
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
133
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
133
134
) as runner :
134
135
role = Role (
135
136
name = "touch" ,
@@ -149,7 +150,7 @@ def test_dryrun_env_variables(self, _) -> None:
149
150
scheduler_mock = MagicMock ()
150
151
with Runner (
151
152
name = SESSION_NAME ,
152
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
153
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
153
154
) as runner :
154
155
role1 = Role (
155
156
name = "echo1" ,
@@ -178,7 +179,7 @@ def test_dryrun_trackers_parent_run_id_as_paramenter(self, _) -> None:
178
179
expected_parent_run_id = "123"
179
180
with Runner (
180
181
name = SESSION_NAME ,
181
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
182
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
182
183
) as runner :
183
184
role1 = Role (
184
185
name = "echo1" ,
@@ -217,7 +218,7 @@ def test_dryrun_setup_trackers(self, config_trackers_mock: MagicMock, _) -> None
217
218
218
219
with Runner (
219
220
name = SESSION_NAME ,
220
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
221
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
221
222
) as runner :
222
223
role1 = Role (
223
224
name = "echo1" ,
@@ -265,7 +266,7 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:
265
266
266
267
with Runner (
267
268
name = SESSION_NAME ,
268
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
269
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
269
270
) as runner :
270
271
role1 = Role (
271
272
name = "echo1" ,
@@ -333,8 +334,10 @@ def build_workspace_and_update_role(
333
334
name = SESSION_NAME ,
334
335
# pyre-fixme[6]: scheduler factory type
335
336
scheduler_factories = {
336
- "no-build-img" : lambda name : TestScheduler (build_new_img = False ),
337
- "builds-img" : lambda name : TestScheduler (build_new_img = True ),
337
+ "no-build-img" : lambda name , ** kwargs : TestScheduler (
338
+ build_new_img = False
339
+ ),
340
+ "builds-img" : lambda name , ** kwargs : TestScheduler (build_new_img = True ),
338
341
},
339
342
) as runner :
340
343
app = AppDef (
@@ -371,7 +374,7 @@ def test_describe(self, _) -> None:
371
374
name = "sleep" ,
372
375
image = str (self .tmpdir ),
373
376
resource = resource .SMALL ,
374
- entrypoint = "sleep.sh " ,
377
+ entrypoint = "sleep" ,
375
378
args = ["60" ],
376
379
)
377
380
app = AppDef ("sleeper" , roles = [role ])
@@ -387,7 +390,7 @@ def test_status(self, _) -> None:
387
390
name = "sleep" ,
388
391
image = str (self .tmpdir ),
389
392
resource = resource .SMALL ,
390
- entrypoint = "sleep.sh " ,
393
+ entrypoint = "sleep" ,
391
394
args = ["60" ],
392
395
)
393
396
app = AppDef ("sleeper" , roles = [role ])
@@ -414,7 +417,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:
414
417
415
418
with Runner (
416
419
name = "test_ui_url_session" ,
417
- scheduler_factories = {"local_dir" : lambda name : mock_scheduler },
420
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : mock_scheduler },
418
421
) as runner :
419
422
role = Role (
420
423
"ignored" ,
@@ -438,7 +441,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:
438
441
439
442
with Runner (
440
443
name = "test_structured_msg" ,
441
- scheduler_factories = {"local_dir" : lambda name : mock_scheduler },
444
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : mock_scheduler },
442
445
) as runner :
443
446
role = Role (
444
447
"ignored" ,
@@ -485,7 +488,7 @@ def test_log_lines(self, _) -> None:
485
488
486
489
with Runner (
487
490
name = SESSION_NAME ,
488
- scheduler_factories = {"local_dir" : lambda name : scheduler_mock },
491
+ scheduler_factories = {"local_dir" : lambda name , ** kwargs : scheduler_mock },
489
492
) as runner :
490
493
role_name = "trainer"
491
494
replica_id = 2
@@ -529,7 +532,7 @@ def test_list(self, _) -> None:
529
532
]
530
533
with Runner (
531
534
name = SESSION_NAME ,
532
- scheduler_factories = {"kubernetes" : lambda name : scheduler_mock },
535
+ scheduler_factories = {"kubernetes" : lambda name , ** kwargs : scheduler_mock },
533
536
) as runner :
534
537
apps = runner .list ("kubernetes" )
535
538
self .assertEqual (apps , apps_expected )
@@ -541,8 +544,8 @@ def test_get_schedulers(self, json_dumps_mock: MagicMock, _) -> None:
541
544
json_dumps_mock .return_value = "{}"
542
545
local_sched_mock = MagicMock ()
543
546
scheduler_factories = {
544
- "local_dir" : lambda name : local_dir_sched_mock ,
545
- "local" : lambda name : local_sched_mock ,
547
+ "local_dir" : lambda name , ** kwargs : local_dir_sched_mock ,
548
+ "local" : lambda name , ** kwargs : local_sched_mock ,
546
549
}
547
550
with Runner (
548
551
name = "test_session" , scheduler_factories = scheduler_factories
@@ -576,8 +579,8 @@ def test_run_from_module(self, _: str) -> None:
576
579
def test_run_from_file_no_function_found (self , _ ) -> None :
577
580
local_sched_mock = MagicMock ()
578
581
schedulers = {
579
- "local_dir" : lambda name : local_sched_mock ,
580
- "local" : lambda name : local_sched_mock ,
582
+ "local_dir" : lambda name , ** kwargs : local_sched_mock ,
583
+ "local" : lambda name , ** kwargs : local_sched_mock ,
581
584
}
582
585
with Runner (name = "test_session" , scheduler_factories = schedulers ) as runner :
583
586
component_path = get_full_path ("distributed.py" )
@@ -591,7 +594,7 @@ def test_runner_context_manager(self, _) -> None:
591
594
mock_scheduler = MagicMock ()
592
595
with patch (
593
596
GET_SCHEDULER_FACTORIES ,
594
- return_value = {"local_dir" : lambda name : mock_scheduler },
597
+ return_value = {"local_dir" : lambda name , ** kwargs : mock_scheduler },
595
598
):
596
599
with get_runner () as runner :
597
600
# force schedulers to load
@@ -602,17 +605,17 @@ def test_runner_context_manager_with_error(self, _) -> None:
602
605
mock_scheduler = MagicMock ()
603
606
with patch (
604
607
GET_SCHEDULER_FACTORIES ,
605
- return_value = {"local_dir" : lambda name : mock_scheduler },
608
+ return_value = {"local_dir" : lambda name , ** kwargs : mock_scheduler },
606
609
):
607
610
with self .assertRaisesRegex (RuntimeError , "foobar" ):
608
- with get_runner () as runner :
611
+ with get_runner ():
609
612
raise RuntimeError ("foobar" )
610
613
611
614
def test_runner_manual_close (self , _ ) -> None :
612
615
mock_scheduler = MagicMock ()
613
616
with patch (
614
617
GET_SCHEDULER_FACTORIES ,
615
- return_value = {"local_dir" : lambda name : mock_scheduler },
618
+ return_value = {"local_dir" : lambda name , ** kwargs : mock_scheduler },
616
619
):
617
620
runner = get_runner ()
618
621
# force schedulers to load
0 commit comments