Skip to content

Commit 707637e

Browse files
kiukchung and facebook-github-bot
authored and committed
(torchx/scheduler) Fill hostnames for each replica in slurm scheduler's describe API
Summary: Use `scontrol` to implement the describe API that fills out the hostnames for each replica. Differential Revision: D76485112
1 parent 2124818 commit 707637e

File tree

3 files changed

+1746
-19
lines changed

3 files changed

+1746
-19
lines changed

torchx/schedulers/slurm_scheduler.py

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tempfile
2121
from dataclasses import dataclass
2222
from datetime import datetime
23+
from subprocess import CalledProcessError, PIPE
2324
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
2425

2526
import torchx
@@ -66,6 +67,11 @@
6667
"TIMEOUT": AppState.FAILED,
6768
}
6869

70+
71+
def appstate_from_slurm_state(slurm_state: str) -> AppState:
    """Map a raw slurm job-state string onto a torchx ``AppState``.

    Any state not listed in ``SLURM_STATES`` maps to ``AppState.UNKNOWN``.
    """
    try:
        return SLURM_STATES[slurm_state]
    except KeyError:
        return AppState.UNKNOWN
73+
74+
6975
SBATCH_JOB_OPTIONS = {
7076
"comment",
7177
"mail-user",
@@ -482,10 +488,82 @@ def _cancel_existing(self, app_id: str) -> None:
482488
subprocess.run(["scancel", app_id], check=True)
483489

484490
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
    """Describe the slurm job identified by ``app_id``.

    Tries each slurm query command in order, falling back to the next when
    one fails (e.g. the command is unavailable on this cluster).
    ``scontrol`` is tried first because it is the only backend that fills
    per-replica hostnames.

    Returns ``None`` when every backend fails or the job is unknown.
    """
    describers = [
        self._describe_scontrol,  # NOTE: only scontrol fills hostnames
        self._describe_sacct,
        self._describe_squeue,
    ]
    for describer in describers:  # renamed: don't shadow this method's name
        try:
            return describer(app_id)
        except CalledProcessError:
            # this backend failed; fall back to the next one
            continue
    return None  # all backends failed
501+
502+
def _describe_scontrol(self, app_id: str) -> Optional[DescribeAppResponse]:
    """Describe the job via ``scontrol show --json job`` (fills hostnames).

    NOTE: app_id for slurm_scheduler is the job_id (not the heterogenous_job_id).
    For heterogeneous jobs, querying slurm by the base job id returns all the
    "sub-jobs" in it.
    We launch each role's replica on its own srun command where the job_name is
    set to ``{role.name}-{replica_id}`` (e.g. ``worker-0``, ``worker-1``, ...),
    so each sub-job maps to a replica in the role.

    Raises:
        CalledProcessError: if ``scontrol`` exits non-zero (the caller in
            ``describe()`` catches this and falls back to other backends).
    """
    output = subprocess.check_output(
        ["scontrol", "show", "--json", "job", app_id], stderr=PIPE, encoding="utf-8"
    )
    output_json = json.loads(output)
    jobs = output_json["jobs"]
    if not jobs:
        # job either finished or does not exist
        return None

    # use typing.Dict for consistency with the rest of this module
    roles: Dict[str, Role] = {}
    roles_statuses: Dict[str, RoleStatus] = {}
    state = AppState.UNKNOWN

    for job in jobs:
        # job name is of the form "{role_name}-{replica_id}"
        role_name, _, replica_id = job["name"].rpartition("-")
        if not role_name or not replica_id:
            # name doesn't follow the "{role}-{replica}" convention; skip it
            # (consistent with _describe_sacct/_describe_squeue guards)
            continue

        image = job["current_working_directory"]
        entrypoint = job["command"]
        # the last sub-job's state wins as the overall app state
        state = appstate_from_slurm_state(job["job_state"][0])

        # `nodes` is a hostlist expression (e.g. slurm-compute-node[200-210,212])
        # but we schedule a job per replica so it will always be a single host.
        # NOTE(review): job_resources/nodes may be absent before the job is
        # allocated (e.g. pending) — use .get() so a KeyError doesn't escape
        # past describe()'s CalledProcessError fallback; verify against scontrol
        job_resources = job.get("job_resources") or {}
        hostname = job_resources.get("nodes", "")

        role = roles.setdefault(
            role_name,
            Role(
                name=role_name,
                image=image,
                entrypoint=entrypoint,
                num_replicas=0,
            ),
        )
        role.num_replicas += 1

        role_status = roles_statuses.setdefault(
            role_name,
            RoleStatus(role_name, replicas=[]),
        )
        role_status.replicas.append(
            ReplicaStatus(
                id=int(replica_id),
                role=role_name,
                state=state,
                hostname=hostname,
            )
        )

    return DescribeAppResponse(
        app_id=app_id,
        roles=list(roles.values()),
        roles_statuses=list(roles_statuses.values()),
        state=state,
    )
489567

490568
def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
491569
p = subprocess.run(
@@ -511,11 +589,7 @@ def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
511589

512590
state = row["State"]
513591
msg = state
514-
state_enum = SLURM_STATES.get(state)
515-
assert (
516-
state_enum
517-
), f"failed to translate slurm state {state} to torchx state"
518-
app_state = state_enum
592+
app_state = appstate_from_slurm_state(state)
519593

520594
role, _, replica_id = row["JobName"].rpartition("-")
521595
if not replica_id or not role:
@@ -553,11 +627,7 @@ def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
553627
for job in output_json["jobs"]:
554628
state = job["job_state"][0]
555629
msg = state
556-
state_enum = SLURM_STATES.get(state)
557-
assert (
558-
state_enum
559-
), f"failed to translate slurm state {state} to torchx state"
560-
app_state = state_enum
630+
app_state = appstate_from_slurm_state(state)
561631

562632
role, _, replica_id = job["name"].rpartition("-")
563633
if not replica_id or not role:

0 commit comments

Comments
 (0)