Skip to content

slurm: support clusters without sacct #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions torchx/cli/cmd_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

HANDLE_HEADER = "APP HANDLE"
STATUS_HEADER = "APP STATUS"
NAME_HEADER = "APP NAME"


class CmdList(SubCommand):
Expand All @@ -39,5 +40,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
def run(self, args: argparse.Namespace) -> None:
with get_runner() as runner:
apps = runner.list(args.scheduler)
apps_data = [[app.app_handle, str(app.state)] for app in apps]
print(tabulate(apps_data, headers=[HANDLE_HEADER, STATUS_HEADER]))
apps_data = [[app.app_handle, app.name, str(app.state)] for app in apps]
print(
tabulate(apps_data, headers=[HANDLE_HEADER, NAME_HEADER, STATUS_HEADER])
)
1 change: 1 addition & 0 deletions torchx/schedulers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class ListAppResponse:
app_id: str
state: AppState
app_handle: str = "<NOT_SET>"
name: str = ""

# Implementing __hash__() makes ListAppResponse hashable which makes
# it easier to check if a ListAppResponse object exists in a list of
Expand Down
86 changes: 86 additions & 0 deletions torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,12 @@ def _cancel_existing(self, app_id: str) -> None:
subprocess.run(["scancel", app_id], check=True)

def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
try:
return self._describe_sacct(app_id)
except subprocess.CalledProcessError:
return self._describe_squeue(app_id)

def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]:
p = subprocess.run(
["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
)
Expand Down Expand Up @@ -534,6 +540,48 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
msg=msg,
)

def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]:
p = subprocess.run(
["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True
)
output_json = json.loads(p.stdout.decode("utf-8"))

roles = {}
roles_statuses = {}
msg = ""
app_state = AppState.UNKNOWN
for job in output_json["jobs"]:
state = job["job_state"][0]
msg = state
state_enum = SLURM_STATES.get(state)
assert (
state_enum
), f"failed to translate slurm state {state} to torchx state"
app_state = state_enum

role, _, replica_id = job["name"].rpartition("-")
if not replica_id or not role:
# name should always have at least 3 parts but sometimes sacct
# is slow to update
continue
if role not in roles:
roles[role] = Role(name=role, num_replicas=0, image="")
roles_statuses[role] = RoleStatus(role, [])
roles[role].num_replicas += 1
roles_statuses[role].replicas.append(
ReplicaStatus(
id=int(replica_id), role=role, state=app_state, hostname=""
),
)

return DescribeAppResponse(
app_id=app_id,
roles=list(roles.values()),
roles_statuses=list(roles_statuses.values()),
state=app_state,
msg=msg,
)

def log_iter(
self,
app_id: str,
Expand Down Expand Up @@ -574,6 +622,12 @@ def log_iter(
return iterator

def list(self) -> List[ListAppResponse]:
try:
return self._list_sacct()
except subprocess.CalledProcessError:
return self._list_squeue()

def _list_sacct(self) -> List[ListAppResponse]:
# By default sacct only returns accounting information of jobs launched on the current day
# To return all jobs launched, set starttime to one second past unix epoch time
# Starttime will be modified when listing jobs by timeframe is supported
Expand All @@ -590,6 +644,38 @@ def list(self) -> List[ListAppResponse]:
for job in output_json["jobs"]
]

def _list_squeue(self) -> List[ListAppResponse]:
# if sacct isn't configured on the cluster, fallback to squeue which
# only has currently running jobs
p = subprocess.run(
["squeue", "--json"],
stdout=subprocess.PIPE,
check=True,
)
output_json = json.loads(p.stdout.decode("utf-8"))

out = []
for job in output_json["jobs"]:
job_id = job["job_id"]

het_job_id = job.get("het_job_id")
if (
het_job_id
and het_job_id["set"]
and het_job_id["number"] != job_id
and het_job_id["number"] > 0
):
continue

out.append(
ListAppResponse(
app_id=str(job["job_id"]),
state=SLURM_STATES[job["job_state"][0]],
name=job["name"],
)
)
return out


def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
return SlurmScheduler(
Expand Down
94 changes: 93 additions & 1 deletion torchx/schedulers/test/slurm_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,48 @@ def test_describe_running(self, run: MagicMock) -> None:
self.assertEqual(out.state, specs.AppState.RUNNING)

@patch("subprocess.run")
def test_list(self, run: MagicMock) -> None:
def test_describe_squeue(self, run: MagicMock) -> None:
run.return_value.stdout = b"""{
"jobs": [
{
"job_id": 1236,
"name": "foo-0",
"job_state": ["RUNNING"],
"het_job_id": {
"set": true,
"infinite": false,
"number": 1236
}
},
{
"job_id": 1237,
"name": "foo-1",
"job_state": ["RUNNING"],
"het_job_id": {
"set": true,
"infinite": false,
"number": 1236
}
}
]
}"""

scheduler = create_scheduler("foo")
out = scheduler._describe_squeue("54")

self.assertEqual(run.call_count, 1)
self.assertEqual(
run.call_args,
call(["squeue", "--json", "-j", "54"], stdout=subprocess.PIPE, check=True),
)

self.assertIsNotNone(out)
self.assertEqual(out.app_id, "54")
self.assertEqual(out.msg, "RUNNING")
self.assertEqual(out.state, specs.AppState.RUNNING)

@patch("subprocess.run")
def test_list_sacct(self, run: MagicMock) -> None:
run.return_value.stdout = b"""{\n "meta": {\n },\n "errors": [\n ],\n "jobs": [
\n {\n "account": null,\n "job_id": 123,\n "name": "main-0",
\n "state": {\n "current": "COMPLETED",\n "reason": "None"},
Expand All @@ -416,6 +457,57 @@ def test_list(self, run: MagicMock) -> None:
self.assertIsNotNone(apps)
self.assertEqual(apps, expected_apps)

@patch("subprocess.run")
def test_list_squeue(self, run: MagicMock) -> None:
run.return_value.stdout = b"""{
"jobs": [
{
"job_id": 1234,
"name": "foo",
"job_state": ["FAILED"]
},
{
"job_id": 1235,
"name": "foo",
"job_state": ["FAILED"],
"het_job_id": {
"set": true,
"infinite": false,
"number": 0
}
},
{
"job_id": 1236,
"name": "foo-0",
"job_state": ["RUNNING"],
"het_job_id": {
"set": true,
"infinite": false,
"number": 1236
}
},
{
"job_id": 1237,
"name": "foo-1",
"job_state": ["RUNNING"],
"het_job_id": {
"set": true,
"infinite": false,
"number": 1236
}
}
]
}"""
scheduler = create_scheduler("foo")
expected_apps = [
ListAppResponse(app_id="1234", state=AppState.FAILED, name="foo"),
ListAppResponse(app_id="1235", state=AppState.FAILED, name="foo"),
ListAppResponse(app_id="1236", state=AppState.RUNNING, name="foo-0"),
]
apps = scheduler._list_squeue()
self.assertIsNotNone(apps)
self.assertEqual(apps, expected_apps)

@patch("subprocess.run")
def test_log_iter(self, run: MagicMock) -> None:
scheduler = create_scheduler("foo")
Expand Down
Loading