Skip to content

Commit 710f654

Browse files
authored
Adding Override Option for TorchX Role
Differential Revision: D62591176 Pull Request resolved: #956
1 parent b7fd00b commit 710f654

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

torchx/specs/api.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
# pyre-strict
99

10+
import asyncio
1011
import copy
12+
import inspect
1113
import json
1214
import re
1315
import typing
@@ -370,6 +372,24 @@ class Role:
370372
mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field(
371373
default_factory=list
372374
)
375+
overrides: Dict[str, Any] = field(default_factory=dict)
376+
377+
# pyre-ignore
378+
def __getattribute__(self, attrname: str) -> Any:
379+
if attrname == "overrides":
380+
return super().__getattribute__(attrname)
381+
try:
382+
ov = super().__getattribute__("overrides")
383+
except AttributeError:
384+
ov = {}
385+
if attrname in ov:
386+
if inspect.isawaitable(ov[attrname]):
387+
result = asyncio.get_event_loop().run_until_complete(ov[attrname])
388+
else:
389+
result = ov[attrname]()
390+
setattr(self, attrname, result)
391+
del ov[attrname]
392+
return super().__getattribute__(attrname)
373393

374394
def pre_proc(
375395
self,

torchx/specs/test/api_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-strict
99

10+
import asyncio
1011
import os
1112
import time
1213
import unittest
@@ -276,6 +277,28 @@ def test_retry_policies(self) -> None:
276277
},
277278
)
278279

280+
def test_override_role(self) -> None:
281+
default = Role(
282+
"foobar",
283+
"torch",
284+
overrides={"image": lambda: "base", "entrypoint": lambda: "nentry"},
285+
)
286+
self.assertEqual("base", default.image)
287+
self.assertEqual("nentry", default.entrypoint)
288+
289+
def test_async_override_role(self) -> None:
290+
async def update(value: str, time_seconds: int) -> str:
291+
await asyncio.sleep(time_seconds)
292+
return value
293+
294+
default = Role(
295+
"foobar",
296+
"torch",
297+
overrides={"image": update("base", 1), "entrypoint": update("nentry", 2)},
298+
)
299+
self.assertEqual("base", default.image)
300+
self.assertEqual("nentry", default.entrypoint)
301+
279302

280303
class AppHandleTest(unittest.TestCase):
281304
def test_parse_malformed_app_handles(self) -> None:

0 commit comments

Comments
 (0)