Skip to content

Commit 94ac896

Browse files
authored
Add new test for concurrent futures for TorchX Role
Differential Revision: D63046717 Pull Request resolved: #957
1 parent 710f654 commit 94ac896

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

torchx/specs/test/api_test.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
# pyre-strict
99

1010
import asyncio
11+
import concurrent
1112
import os
1213
import time
1314
import unittest
1415
from dataclasses import asdict
15-
from typing import Dict, List, Mapping, Union
16+
from typing import Dict, List, Mapping, Tuple, Union
1617
from unittest.mock import MagicMock
1718

1819
import torchx.specs.named_resources_aws as named_resources_aws
@@ -299,6 +300,33 @@ async def update(value: str, time_seconds: int) -> str:
299300
self.assertEqual("base", default.image)
300301
self.assertEqual("nentry", default.entrypoint)
301302

303+
def test_concurrent_override_role(self) -> None:
304+
305+
def delay(value: Tuple[str, str], time_seconds: int) -> Tuple[str, str]:
306+
time.sleep(time_seconds)
307+
return value
308+
309+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
310+
launcher_fbpkg_future: concurrent.futures.Future = executor.submit(
311+
delay, ("value1", "value2"), 2
312+
)
313+
314+
def get_image() -> str:
315+
concurrent.futures.wait([launcher_fbpkg_future], 3)
316+
return launcher_fbpkg_future.result()[0]
317+
318+
def get_entrypoint() -> str:
319+
concurrent.futures.wait([launcher_fbpkg_future], 3)
320+
return launcher_fbpkg_future.result()[1]
321+
322+
default = Role(
323+
"foobar",
324+
"torch",
325+
overrides={"image": get_image, "entrypoint": get_entrypoint},
326+
)
327+
self.assertEqual("value1", default.image)
328+
self.assertEqual("value2", default.entrypoint)
329+
302330

303331
class AppHandleTest(unittest.TestCase):
304332
def test_parse_malformed_app_handles(self) -> None:

0 commit comments

Comments
 (0)