
Commit 3a3772c

kiukchung authored and facebook-github-bot committed
(torchx/scheduler) Fill hostnames for each replica in slurm scheduler's describe API (#1080)
Summary: Additionally fill hostname, resource (cpu, memMB), image, and entrypoint in `describe_squeue` for each role/replica.

Reviewed By: d4l3k

Differential Revision: D76485112
1 parent 50b8c02 commit 3a3772c
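With this change, `describe()` reports replica-level hostnames plus each role's image, entrypoint, and resource for jobs still visible to `squeue`. A minimal usage sketch, assuming the module's standard `create_scheduler` factory and a made-up job id:

    from torchx.schedulers.slurm_scheduler import create_scheduler

    scheduler = create_scheduler(session_name="demo")
    resp = scheduler.describe(app_id="1234")  # hypothetical slurm job id
    if resp is not None:
        for role in resp.roles:
            print(role.name, role.image, role.entrypoint, role.resource)
        for role_status in resp.roles_statuses:
            for replica in role_status.replicas:
                print(replica.id, replica.state, replica.hostname)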

File tree

3 files changed: +1803 −104 lines changed

torchx/schedulers/slurm_scheduler.py

Lines changed: 110 additions & 41 deletions
@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -483,15 +490,34 @@ def _cancel_existing(self, app_id: str) -> None:
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
             return self._describe_squeue(app_id)
+        except CalledProcessError as e:
+            # NOTE: squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+            # if the job does not exist or has finished (e.g. not in PENDING or RUNNING states)
+            # in this case, fall back to the less descriptive but more persistent sacct
+            # (slurm cluster must have accounting storage enabled for sacct to work)
+            log.info(
+                "unable to get job info for `{}` with `squeue` ({}), trying `sacct`".format(
+                    app_id, e.stderr
+                )
+            )
+            return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None
 
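Note the switch from `stdout=subprocess.PIPE` to `check_output(..., stderr=PIPE)`: without redirecting stderr, `e.stderr` on the raised `CalledProcessError` is `None`, and the log lines above would lose the slurm error text. A standalone sketch with a made-up job id:

    import subprocess
    from subprocess import PIPE, CalledProcessError

    try:
        # squeue only knows about pending/running jobs; a finished or
        # unknown job id makes it exit non-zero
        subprocess.check_output(
            ["squeue", "--json", "-j", "99999999"], stderr=PIPE, encoding="utf-8"
        )
    except CalledProcessError as e:
        # e.stderr would be None had stderr not been piped
        print(e.stderr)  # e.g. "slurm_load_jobs error: Invalid job id specified"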
@@ -511,11 +537,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
 
             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
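For context, `_describe_sacct` reads `sacct --parsable2` rows keyed by the header line, and job names follow the `{role}-{replica_id}` convention that `rpartition("-")` splits. A trimmed, illustrative parse (the sample output and column subset are fictional; real sacct output carries more columns, and torchx heterogeneous-job steps show ids like `1234+0`):

    import csv

    # illustrative only: fictional, trimmed sacct --parsable2 output
    sample = "JobID|JobName|State\n1234+0|trainer-0|RUNNING\n1234+1|trainer-1|RUNNING"

    for row in csv.DictReader(sample.splitlines(), delimiter="|"):
        role, _, replica_id = row["JobName"].rpartition("-")
        print(role, replica_id, row["State"])  # trainer 0 RUNNING / trainer 1 RUNNING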
@@ -540,46 +562,93 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
             msg=msg,
         )
 
-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
 
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
 
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx launched jobs points to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogenous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                for node_info in node_infos:
+                    # NOTE: we expect resource specs for all the nodes to be the same
+                    # NOTE: use allocated (not used/requested) memory since
+                    # users may only specify --cpu, in which case slurm
+                    # uses the (system) configured {mem-per-cpu} * {cpus}
+                    # to allocate memory.
+                    # NOTE: getting gpus is tricky because it modeled as a trackable-resource
+                    # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                    cpu = int(node_info["cpus_used"])
+                    memMB = int(node_info["memory_allocated"])
+
+                    hostname = node_info["nodename"]
+
+                    role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
 
         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )
 
     def log_iter(
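For reference, a trimmed, illustrative shape of the `squeue --json` payload the rewritten `_describe_squeue` consumes. The field names are exactly the ones the parser reads above; all values are made up, and real squeue output carries many more fields:

    # illustrative only: the subset of `squeue --json -j <app_id>` output
    # that _describe_squeue touches
    squeue_payload = {
        "jobs": [
            {
                "name": "trainer-0",  # "{role_name}-{replica_id}"
                "command": "/home/user/train.sh",  # -> Role.entrypoint
                "current_working_directory": "/home/user",  # -> Role.image
                "job_state": ["RUNNING"],
                "job_resources": {
                    # PENDING jobs carry "scheduled_nodes" instead, which may be
                    # a node-list expression such as "slurm-compute-node[0-20,45-47]"
                    "allocated_nodes": [
                        {
                            "nodename": "slurm-compute-node-0",
                            "cpus_used": 8,
                            "memory_allocated": 16384,  # MB -> Resource.memMB
                        }
                    ],
                },
            }
        ]
    }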
