
Commit 811a65f

kiukchung authored and facebook-github-bot committed
(torchx/scheduler) Fill hostnames for each replica in slurm scheduler's describe API (#1080)
Summary: Additionally fill hostname, resource (cpu, memMB), image, and entrypoint in `describe_squeue` for each role/replica.

Reviewed By: d4l3k

Differential Revision: D76485112
1 parent 50b8c02 commit 811a65f
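
The fields this commit starts filling in surface through the Slurm scheduler's `describe` API. Below is a minimal sketch of reading them; the session name "demo" and job id "1234" are made up, and it assumes a host where the Slurm CLI tools (squeue/sacct) are on PATH.

# Sketch only: hypothetical job id and session name; requires squeue/sacct on PATH.
from torchx.schedulers.slurm_scheduler import create_scheduler

scheduler = create_scheduler("demo")
resp = scheduler.describe("1234")
if resp is not None:
    print(resp.state)
    for role in resp.roles:
        # image, entrypoint, and resource (cpu, memMB) now come from squeue
        print(role.name, role.image, role.entrypoint, role.resource)
    for role_status in resp.roles_statuses:
        for replica in role_status.replicas:
            # hostname per replica is the field this commit fills in
            print(replica.id, replica.state, replica.hostname)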


5 files changed: +1826, -107 lines changed


.github/workflows/slurm-local-integration-tests.yaml

Lines changed: 4 additions & 1 deletion
@@ -6,8 +6,11 @@ on:
       - main
   pull_request:
 
+
 env:
-  SLURM_VERSION: 21.08.6
+  # slurm tag should be one of https://github.com/SchedMD/slurm/tags
+  SLURM_TAG: slurm-23-11-11-1
+  SLURM_VERSION: 23.11.11
 
 jobs:
   slurm:

torchx/cli/cmd_log.py

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,9 @@ def get_logs(
     display_waiting = True
     while True:
         status = runner.status(app_handle)
+        print(f"*************** app status: {status}")
+        if status:
+            print(f"*************** app state: {state}")
         if status and is_started(status.state):
             break
         elif display_waiting:

torchx/schedulers/slurm_scheduler.py

Lines changed: 126 additions & 43 deletions
@@ -20,6 +20,7 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime
+from subprocess import CalledProcessError, PIPE
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import torchx
@@ -39,6 +40,7 @@
     macros,
     NONE,
     ReplicaStatus,
+    Resource,
     Role,
     RoleStatus,
     runopts,
@@ -66,6 +68,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+
+def appstate_from_slurm_state(slurm_state: str) -> AppState:
+    return SLURM_STATES.get(slurm_state, AppState.UNKNOWN)
+
+
 SBATCH_JOB_OPTIONS = {
     "comment",
     "mail-user",
@@ -483,15 +490,34 @@ def _cancel_existing(self, app_id: str) -> None:
 
     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         try:
-            return self._describe_sacct(app_id)
-        except subprocess.CalledProcessError:
             return self._describe_squeue(app_id)
+        except CalledProcessError as e:
+            # NOTE: squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+            # if the job does not exist or has finished (e.g. not in PENDING or RUNNING states)
+            # in this case, fall back to the less descriptive but more persistent sacct
+            # (slurm cluster must have accounting storage enabled for sacct to work)
+            log.info(
+                "unable to get job info for `{}` with `squeue` ({}), trying `sacct`".format(
+                    app_id, e.stderr
+                )
+            )
+            return self._describe_sacct(app_id)
 
     def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
-        )
-        output = p.stdout.decode("utf-8").split("\n")
+        try:
+            output = subprocess.check_output(
+                ["sacct", "--parsable2", "-j", app_id],
+                stderr=PIPE,
+                encoding="utf-8",
+            ).split("\n")
+        except CalledProcessError as e:
+            log.info(
+                "unable to get job info for `{}` with `sacct` ({})".format(
+                    app_id, e.stderr
+                )
+            )
+            return None
+
         if len(output) <= 1:
             return None
 
@@ -511,11 +537,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
 
             state = row["State"]
             msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
+            app_state = appstate_from_slurm_state(state)
 
             role, _, replica_id = row["JobName"].rpartition("-")
             if not replica_id or not role:
@@ -540,46 +562,107 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
                     msg=msg,
                 )
 
-    def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
-        p = subprocess.run(
-            ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
+    def _describe_squeue(self, app_id: str) -> DescribeAppResponse:
+        # squeue errors out with 'slurm_load_jobs error: Invalid job id specified'
+        # if the job does not exist or is finished (e.g. not in PENDING or RUNNING state)
+        output = subprocess.check_output(
+            ["squeue", "--json", "-j", app_id], stderr=PIPE, encoding="utf-8"
         )
-        output_json = json.loads(p.stdout.decode("utf-8"))
-
-        roles = {}
-        roles_statuses = {}
-        msg = ""
-        app_state = AppState.UNKNOWN
-        for job in output_json["jobs"]:
-            state = job["job_state"][0]
-            msg = state
-            state_enum = SLURM_STATES.get(state)
-            assert (
-                state_enum
-            ), f"failed to translate slurm state {state} to torchx state"
-            app_state = state_enum
-
-            role, _, replica_id = job["name"].rpartition("-")
-            if not replica_id or not role:
-                # name should always have at least 3 parts but sometimes sacct
-                # is slow to update
-                continue
-            if role not in roles:
-                roles[role] = Role(name=role, num_replicas=0, image="")
-                roles_statuses[role] = RoleStatus(role, [])
-            roles[role].num_replicas += 1
-            roles_statuses[role].replicas.append(
-                ReplicaStatus(
-                    id=int(replica_id), role=role, state=app_state, hostname=""
+        output_json = json.loads(output)
+        jobs = output_json["jobs"]
+
+        roles: dict[str, Role] = {}
+        roles_statuses: dict[str, RoleStatus] = {}
+        state = AppState.UNKNOWN
+
+        for job in jobs:
+            # job name is of the form "{role_name}-{replica_id}"
+            role_name, _, replica_id = job["name"].rpartition("-")
+
+            entrypoint = job["command"]
+            image = job["current_working_directory"]
+            state = appstate_from_slurm_state(job["job_state"][0])
+
+            job_resources = job["job_resources"]
+
+            role = roles.setdefault(
+                role_name,
+                Role(
+                    name=role_name,
+                    image=image,
+                    entrypoint=entrypoint,
+                    num_replicas=0,
                 ),
             )
+            role_status = roles_statuses.setdefault(
+                role_name,
+                RoleStatus(role_name, replicas=[]),
+            )
+
+            if state == AppState.PENDING:
+                # NOTE: torchx launched jobs points to exactly one host
+                # otherwise, scheduled_nodes could be a node list expression (eg. 'slurm-compute-node[0-20,21,45-47]')
+                hostname = job_resources["scheduled_nodes"]
+                role.num_replicas += 1
+                role_status.replicas.append(
+                    ReplicaStatus(
+                        id=int(replica_id),
+                        role=role_name,
+                        state=state,
+                        hostname=hostname,
+                    )
+                )
+            else:  # state == AppState.RUNNING
+                # NOTE: torchx schedules on slurm with sbatch + heterogenous job
+                # where each replica is a "sub-job" so `allocated_nodes` will always be 1
+                # but we deal with jobs that have not been launched with torchx
+                # which can have multiple hosts per sub-job (count them as replicas)
+                node_infos = job_resources.get("allocated_nodes", [])
+
+                if not isinstance(node_infos, list):
+                    # NOTE: in some versions of slurm jobs[].job_resources.allocated_nodes
+                    # is not a list of individual nodes, but a map of the nodelist specs
+                    # in this case just use jobs[].job_resources.nodes
+                    hostname = job_resources.get("nodes")
+                    role.num_replicas += 1
+                    role_status.replicas.append(
+                        ReplicaStatus(
+                            id=int(replica_id),
+                            role=role_name,
+                            state=state,
+                            hostname=hostname,
+                        )
+                    )
+                else:
+                    for node_info in node_infos:
+                        # NOTE: we expect resource specs for all the nodes to be the same
+                        # NOTE: use allocated (not used/requested) memory since
+                        # users may only specify --cpu, in which case slurm
+                        # uses the (system) configured {mem-per-cpu} * {cpus}
+                        # to allocate memory.
+                        # NOTE: getting gpus is tricky because it modeled as a trackable-resource
+                        # or not configured at all (use total-cpu-on-host as proxy for gpus)
+                        cpu = int(node_info["cpus_used"])
+                        memMB = int(node_info["memory_allocated"])
+
+                        hostname = node_info["nodename"]
+
+                        role.resource = Resource(cpu=cpu, memMB=memMB, gpu=-1)
+                        role.num_replicas += 1
+                        role_status.replicas.append(
+                            ReplicaStatus(
+                                id=int(replica_id),
+                                role=role_name,
+                                state=state,
+                                hostname=hostname,
+                            )
+                        )
 
         return DescribeAppResponse(
             app_id=app_id,
             roles=list(roles.values()),
             roles_statuses=list(roles_statuses.values()),
-            state=app_state,
-            msg=msg,
+            state=state,
         )
 
     def log_iter(
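
For reference, the `_describe_squeue` code above reads only a handful of fields from the `squeue --json` payload. An illustrative, heavily trimmed sketch of that shape follows; the values are fabricated, real output carries many more fields, and the exact schema varies across Slurm versions (which is why the code falls back from `allocated_nodes` to `nodes`).

# Illustrative only: fabricated values; keys shown are the ones _describe_squeue reads.
squeue_json = {
    "jobs": [
        {
            "name": "trainer-0",                        # parsed as "{role_name}-{replica_id}"
            "command": "/home/user/train.py",           # -> Role.entrypoint
            "current_working_directory": "/home/user",  # -> Role.image
            "job_state": ["RUNNING"],                   # -> appstate_from_slurm_state(...)
            "job_resources": {
                "scheduled_nodes": "node-0",            # read for PENDING jobs
                "nodes": "node[0-1]",                   # fallback when allocated_nodes is not a list
                "allocated_nodes": [
                    {
                        "nodename": "node-0",           # -> ReplicaStatus.hostname
                        "cpus_used": 8,                 # -> Resource.cpu
                        "memory_allocated": 16384,      # -> Resource.memMB
                    }
                ],
            },
        }
    ]
}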

0 commit comments