@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }

+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -482,16 +489,36 @@ def _cancel_existing(self, app_id: str) -> None:
         subprocess.run(["scancel", app_id], check=True)

     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        # NOTE: depending on the version of slurm, querying for job info
+        # with `squeue` for finished (or non-existent) jobs either:
+        # 1. errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # 2. -- or -- squeue returns an empty jobs list
+        # in either case, fall back to the less descriptive but more persistent sacct
+        # (slurm cluster must have accounting storage enabled for sacct to work)
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
-            return self._describe_squeue(app_id)
+            if desc := self._describe_squeue(app_id):
+                return desc
+        except CalledProcessError as e:
+            log.info(
+                f"unable to get job info for `{app_id}` with `squeue` ({e.stderr}), trying `sacct`"
+            )
+        return self._describe_sacct(app_id)

     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None

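For reference, `sacct --parsable2` emits pipe-delimited rows preceded by a header line, which is why a split that yields one line or fewer is treated as "job not found". The rows are later consumed as dicts (`row["State"]`, `row["JobName"]` in the next hunk), so the split lines presumably feed a dict-based CSV reader. A minimal, self-contained sketch of that shape, with hypothetical column values:

import csv

# hypothetical sacct --parsable2 output; real columns depend on the cluster's sacct config
raw = "JobID|JobName|State|ExitCode\n1234|trainer-0|RUNNING|0:0\n"
rows = csv.DictReader(raw.split("\n"), delimiter="|")
for row in rows:
    print(row["JobName"], row["State"])  # -> trainer-0 RUNNING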
@@ -511,11 +538,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:

             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)

             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
@@ -540,46 +563,107 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
                 msg=msg,
             )

-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
-
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
-
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
+
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx launched jobs points to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogenous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                if not isinstance(node_infos, list):
+                    # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
+                    # is not a list of individual nodes, but a map of the nodelist specs
+                    # in this case just use jobs[].job_resources.nodes
+                    hostname = job_resources.get("nodes")
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+                else:
+                    for node_info in node_infos:
+                        # NOTE: we expect resource specs for all the nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since
+                        # users may only specify --cpu, in which case slurm
+                        # uses the (system) configured {mem-per-cpu} * {cpus}
+                        # to allocate memory.
+                        # NOTE: getting gpus is tricky because it modeled as a trackable-resource
+                        # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )

         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
        )

     def log_iter(
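To make the field accesses in the hunk above concrete, here is a hedged sketch of the kind of `squeue --json` payload the loop expects. Only keys actually read by the code are shown; the values and the exact shape of `job_resources` are illustrative, since (as the NOTE comments point out) they vary across Slurm versions:

# illustrative payload only; real `squeue --json` output differs by Slurm version
example = {
    "jobs": [
        {
            "name": "trainer-0",                             # "{role_name}-{replica_id}"
            "command": "/usr/bin/python train.py",           # -> Role.entrypoint
            "current_working_directory": "/home/user/proj",  # -> Role.image
            "job_state": ["RUNNING"],
            "job_resources": {
                # PENDING jobs only carry a node-list expression:
                #   "scheduled_nodes": "slurm-compute-node[0-2]"
                # RUNNING jobs carry per-node allocations (list form) ...
                "allocated_nodes": [
                    {"nodename": "node-0", "cpus_used": 8, "memory_allocated": 16384}
                ],
                # ... or, on other Slurm versions, just a node-list string:
                "nodes": "node-0",
            },
        }
    ]
}

Fed through the loop, a payload like this would yield one "trainer" role with a single RUNNING replica on "node-0" and a Resource of 8 CPUs and 16384 MB (gpu left as -1, per the NOTE on trackable resources).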