From 8221cdb8228150367f0cf7eaf38ef0df4c12f4aa Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Thu, 22 May 2025 11:54:33 -0700 Subject: [PATCH] slurm: support clusters without sacct --- torchx/cli/cmd_list.py | 7 +- torchx/schedulers/api.py | 1 + torchx/schedulers/slurm_scheduler.py | 86 +++++++++++++++++ .../schedulers/test/slurm_scheduler_test.py | 94 ++++++++++++++++++- 4 files changed, 185 insertions(+), 3 deletions(-) diff --git a/torchx/cli/cmd_list.py b/torchx/cli/cmd_list.py index 6918b4a7e..afa42b4b1 100644 --- a/torchx/cli/cmd_list.py +++ b/torchx/cli/cmd_list.py @@ -21,6 +21,7 @@ HANDLE_HEADER = "APP HANDLE" STATUS_HEADER = "APP STATUS" +NAME_HEADER = "APP NAME" class CmdList(SubCommand): @@ -39,5 +40,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None: def run(self, args: argparse.Namespace) -> None: with get_runner() as runner: apps = runner.list(args.scheduler) - apps_data = [[app.app_handle, str(app.state)] for app in apps] - print(tabulate(apps_data, headers=[HANDLE_HEADER, STATUS_HEADER])) + apps_data = [[app.app_handle, app.name, str(app.state)] for app in apps] + print( + tabulate(apps_data, headers=[HANDLE_HEADER, NAME_HEADER, STATUS_HEADER]) + ) diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index 80397c95a..359390a87 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -86,6 +86,7 @@ class ListAppResponse: app_id: str state: AppState app_handle: str = "" + name: str = "" # Implementing __hash__() makes ListAppResponse hashable which makes # it easier to check if a ListAppResponse object exists in a list of diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index fb9c76982..b0b066761 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -482,6 +482,12 @@ def _cancel_existing(self, app_id: str) -> None: subprocess.run(["scancel", app_id], check=True) def describe(self, app_id: str) -> 
Optional[DescribeAppResponse]: + try: + return self._describe_sacct(app_id) + except subprocess.CalledProcessError: + return self._describe_squeue(app_id) + + def _describe_sacct(self, app_id: str) -> Optional[DescribeAppResponse]: p = subprocess.run( ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True ) @@ -534,6 +540,48 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]: msg=msg, ) + def _describe_squeue(self, app_id: str) -> Optional[DescribeAppResponse]: + p = subprocess.run( + ["squeue", "--json", "-j", app_id], stdout=subprocess.PIPE, check=True + ) + output_json = json.loads(p.stdout.decode("utf-8")) + + roles = {} + roles_statuses = {} + msg = "" + app_state = AppState.UNKNOWN + for job in output_json["jobs"]: + state = job["job_state"][0] + msg = state + state_enum = SLURM_STATES.get(state) + assert ( + state_enum + ), f"failed to translate slurm state {state} to torchx state" + app_state = state_enum + + role, _, replica_id = job["name"].rpartition("-") + if not replica_id or not role: + # name is expected to be <role>-<replica_id>; skip squeue entries + # whose name does not split into both a role and a replica id + continue + if role not in roles: + roles[role] = Role(name=role, num_replicas=0, image="") + roles_statuses[role] = RoleStatus(role, []) + roles[role].num_replicas += 1 + roles_statuses[role].replicas.append( + ReplicaStatus( + id=int(replica_id), role=role, state=app_state, hostname="" + ), + ) + + return DescribeAppResponse( + app_id=app_id, + roles=list(roles.values()), + roles_statuses=list(roles_statuses.values()), + state=app_state, + msg=msg, + ) + def log_iter( self, app_id: str, @@ -574,6 +622,12 @@ def log_iter( return iterator def list(self) -> List[ListAppResponse]: + try: + return self._list_sacct() + except subprocess.CalledProcessError: + return self._list_squeue() + + def _list_sacct(self) -> List[ListAppResponse]: # By default sacct only returns accounting information of jobs launched on the current day # To return all jobs 
launched, set starttime to one second past unix epoch time # Starttime will be modified when listing jobs by timeframe is supported @@ -590,6 +644,38 @@ def list(self) -> List[ListAppResponse]: for job in output_json["jobs"] ] + def _list_squeue(self) -> List[ListAppResponse]: + # if sacct isn't configured on the cluster, fall back to squeue, which + # only reports jobs still in the Slurm scheduling queue + p = subprocess.run( + ["squeue", "--json"], + stdout=subprocess.PIPE, + check=True, + ) + output_json = json.loads(p.stdout.decode("utf-8")) + + out = [] + for job in output_json["jobs"]: + job_id = job["job_id"] + + het_job_id = job.get("het_job_id") + if ( + het_job_id + and het_job_id["set"] + and het_job_id["number"] != job_id + and het_job_id["number"] > 0 + ): + continue + + out.append( + ListAppResponse( + app_id=str(job["job_id"]), + state=SLURM_STATES[job["job_state"][0]], + name=job["name"], + ) + ) + return out + def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler: return SlurmScheduler( diff --git a/torchx/schedulers/test/slurm_scheduler_test.py b/torchx/schedulers/test/slurm_scheduler_test.py index a2f47daed..971faa249 100644 --- a/torchx/schedulers/test/slurm_scheduler_test.py +++ b/torchx/schedulers/test/slurm_scheduler_test.py @@ -399,7 +399,48 @@ def test_describe_running(self, run: MagicMock) -> None: self.assertEqual(out.state, specs.AppState.RUNNING) @patch("subprocess.run") - def test_list(self, run: MagicMock) -> None: + def test_describe_squeue(self, run: MagicMock) -> None: + run.return_value.stdout = b"""{ + "jobs": [ + { + "job_id": 1236, + "name": "foo-0", + "job_state": ["RUNNING"], + "het_job_id": { + "set": true, + "infinite": false, + "number": 1236 + } + }, + { + "job_id": 1237, + "name": "foo-1", + "job_state": ["RUNNING"], + "het_job_id": { + "set": true, + "infinite": false, + "number": 1236 + } + } + ] +}""" + + scheduler = create_scheduler("foo") + out = scheduler._describe_squeue("54") + + self.assertEqual(run.call_count, 1) + 
self.assertEqual( + run.call_args, + call(["squeue", "--json", "-j", "54"], stdout=subprocess.PIPE, check=True), + ) + + self.assertIsNotNone(out) + self.assertEqual(out.app_id, "54") + self.assertEqual(out.msg, "RUNNING") + self.assertEqual(out.state, specs.AppState.RUNNING) + + @patch("subprocess.run") + def test_list_sacct(self, run: MagicMock) -> None: run.return_value.stdout = b"""{\n "meta": {\n },\n "errors": [\n ],\n "jobs": [ \n {\n "account": null,\n "job_id": 123,\n "name": "main-0", \n "state": {\n "current": "COMPLETED",\n "reason": "None"}, @@ -416,6 +457,57 @@ def test_list(self, run: MagicMock) -> None: self.assertIsNotNone(apps) self.assertEqual(apps, expected_apps) + @patch("subprocess.run") + def test_list_squeue(self, run: MagicMock) -> None: + run.return_value.stdout = b"""{ + "jobs": [ + { + "job_id": 1234, + "name": "foo", + "job_state": ["FAILED"] + }, + { + "job_id": 1235, + "name": "foo", + "job_state": ["FAILED"], + "het_job_id": { + "set": true, + "infinite": false, + "number": 0 + } + }, + { + "job_id": 1236, + "name": "foo-0", + "job_state": ["RUNNING"], + "het_job_id": { + "set": true, + "infinite": false, + "number": 1236 + } + }, + { + "job_id": 1237, + "name": "foo-1", + "job_state": ["RUNNING"], + "het_job_id": { + "set": true, + "infinite": false, + "number": 1236 + } + } + ] +}""" + scheduler = create_scheduler("foo") + expected_apps = [ + ListAppResponse(app_id="1234", state=AppState.FAILED, name="foo"), + ListAppResponse(app_id="1235", state=AppState.FAILED, name="foo"), + ListAppResponse(app_id="1236", state=AppState.RUNNING, name="foo-0"), + ] + apps = scheduler._list_squeue() + self.assertIsNotNone(apps) + self.assertEqual(apps, expected_apps) + @patch("subprocess.run") def test_log_iter(self, run: MagicMock) -> None: scheduler = create_scheduler("foo")