20
20
import tempfile
21
21
from dataclasses import dataclass
22
22
from datetime import datetime
23
+ from subprocess import CalledProcessError , PIPE
23
24
from typing import Any , Dict , Iterable , List , Mapping , Optional , Tuple
24
25
25
26
import torchx
66
67
"TIMEOUT" : AppState .FAILED ,
67
68
}
68
69
70
+
71
def appstate_from_slurm_state(slurm_state: str) -> AppState:
    """Translate a slurm job state string into the matching ``AppState``.

    Args:
        slurm_state: the state name as reported by slurm (e.g. ``"TIMEOUT"``).

    Returns:
        The ``AppState`` registered for ``slurm_state`` in ``SLURM_STATES``,
        or ``AppState.UNKNOWN`` when the state has no entry in that table.
    """
    try:
        return SLURM_STATES[slurm_state]
    except KeyError:
        # unrecognized slurm states are reported as UNKNOWN rather than failing
        return AppState.UNKNOWN
73
+
74
+
69
75
SBATCH_JOB_OPTIONS = {
70
76
"comment" ,
71
77
"mail-user" ,
@@ -482,10 +488,82 @@ def _cancel_existing(self, app_id: str) -> None:
482
488
subprocess .run (["scancel" , app_id ], check = True )
483
489
484
490
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
    """Describe the slurm job ``app_id`` using the first slurm command that can.

    The describers are tried in order: ``scontrol``, ``sacct``, ``squeue``.
    A describer that fails with ``CalledProcessError`` or that returns
    ``None`` is skipped and the next one is tried.

    Args:
        app_id: the slurm job id of the app to describe.

    Returns:
        The first non-``None`` response produced by a describer, or ``None``
        when none of the slurm commands could describe the job.
    """
    # fallback to using different slurm commands for describing the job
    describers = (
        self._describe_scontrol,  # NOTE: only scontrol fills hostnames
        self._describe_sacct,
        self._describe_squeue,
    )
    for describer in describers:
        try:
            response = describer(app_id)
        except CalledProcessError:
            continue
        # ``None`` means this command has no record of the job (e.g. scontrol
        # reports an empty job list for a finished job). Another command may
        # still know about it, so keep falling back instead of returning early.
        if response is not None:
            return response
    return None
501
+
502
def _describe_scontrol(self, app_id: str) -> Optional[DescribeAppResponse]:
    """Describe the job by parsing ``scontrol show --json job <app_id>``.

    Each job entry in the output is treated as one replica of a role: the
    role name and replica id are taken from the job name, the role's image
    from the job's working directory and its entrypoint from the job command.

    Args:
        app_id: the slurm job id of the app to describe.

    Returns:
        A ``DescribeAppResponse`` with the roles and per-replica statuses, or
        ``None`` when scontrol reports no jobs for ``app_id``. The response
        ``state`` is the state of the last job entry in the output.

    Raises:
        subprocess.CalledProcessError: if ``scontrol`` exits with a non-zero
            return code.
    """
    # NOTE: app_id for slurm_scheduler is the job_id (not the heterogeneous_job_id).
    #   For heterogeneous jobs, querying slurm by the base job id returns all the
    #   "sub-jobs" in it.
    #   We launch each role's replica on its own srun command where the job_name is set
    #   to `{role.name}-{replica_id}` (e.g. `worker-0`, `worker-1`, ...).
    #   So each sub-job maps to a replica in the role.

    output = subprocess.check_output(
        ["scontrol", "show", "--json", "job", app_id], stderr=PIPE, encoding="utf-8"
    )
    jobs = json.loads(output)["jobs"]
    if not jobs:
        # job either finished or does not exist
        return None

    roles: dict[str, Role] = {}
    roles_statuses: dict[str, RoleStatus] = {}
    state = AppState.UNKNOWN

    for job in jobs:
        state = appstate_from_slurm_state(job["job_state"][0])

        # job name is of the form "{role_name}-{replica_id}"
        role_name, _, replica_id = job["name"].rpartition("-")
        if not replica_id or not role_name:
            # name does not follow the "{role_name}-{replica_id}" convention;
            # skip it, mirroring the guard in _describe_sacct/_describe_squeue
            continue

        # nodes is a hostlist expression (e.g. slurm-compute-node[200-210,212])
        # but we schedule a job per replica so it will always be a single host
        # NOTE(review): assumes "job_resources"/"nodes" may be absent for jobs
        #   without an allocation; an empty hostname is reported then - confirm
        job_resources = job.get("job_resources") or {}
        hostname = job_resources.get("nodes", "")

        role = roles.get(role_name)
        if role is None:
            # first replica seen for this role: register the role itself
            role = roles[role_name] = Role(
                name=role_name,
                image=job["current_working_directory"],
                entrypoint=job["command"],
                num_replicas=0,
            )
        role.num_replicas += 1

        role_status = roles_statuses.get(role_name)
        if role_status is None:
            role_status = roles_statuses[role_name] = RoleStatus(
                role_name, replicas=[]
            )

        role_status.replicas.append(
            ReplicaStatus(
                id=int(replica_id),
                role=role_name,
                state=state,
                hostname=hostname,
            )
        )

    return DescribeAppResponse(
        app_id=app_id,
        roles=list(roles.values()),
        roles_statuses=list(roles_statuses.values()),
        state=state,
    )
489
567
490
568
def _describe_sacct (self , app_id : str ) -> Optional [DescribeAppResponse ]:
491
569
p = subprocess .run (
@@ -511,11 +589,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
511
589
512
590
state = row ["State" ]
513
591
msg = state
514
- state_enum = SLURM_STATES .get (state )
515
- assert (
516
- state_enum
517
- ), f"failed to translate slurm state { state } to torchx state"
518
- app_state = state_enum
592
+ app_state = appstate_from_slurm_state (state )
519
593
520
594
role , _ , replica_id = row ["JobName" ].rpartition ("-" )
521
595
if not replica_id or not role :
@@ -553,11 +627,7 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
553
627
for job in output_json ["jobs" ]:
554
628
state = job ["job_state" ][0 ]
555
629
msg = state
556
- state_enum = SLURM_STATES .get (state )
557
- assert (
558
- state_enum
559
- ), f"failed to translate slurm state { state } to torchx state"
560
- app_state = state_enum
630
+ app_state = appstate_from_slurm_state (state )
561
631
562
632
role , _ , replica_id = job ["name" ].rpartition ("-" )
563
633
if not replica_id or not role :
0 commit comments