diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 8b3b947a2..5d9c15159 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -241,11 +241,14 @@ class RetryPolicy(str, Enum): is not violated using extra hosts as spares. It does not really support elasticity and just uses the delta between num_replicas and min_replicas as spares (EXPERIMENTAL). + 4. ROLE: Restarts the role when any error occurs in that role. This does not + restart the whole job. """ REPLICA = "REPLICA" APPLICATION = "APPLICATION" HOT_SPARE = "HOT_SPARE" + ROLE = "ROLE" class MountType(str, Enum): diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 3ff5caea1..54d6c7876 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -266,6 +266,17 @@ def test_build_role(self) -> None: self.assertEqual(5, trainer.max_retries) self.assertEqual(RetryPolicy.REPLICA, trainer.retry_policy) + def test_retry_policies(self) -> None: + self.assertCountEqual( + set(RetryPolicy), # pyre-ignore[6]: Enum isn't iterable + { + RetryPolicy.APPLICATION, + RetryPolicy.REPLICA, + RetryPolicy.ROLE, + RetryPolicy.HOT_SPARE, + }, + ) + class AppHandleTest(unittest.TestCase): def test_parse_malformed_app_handles(self) -> None: