|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 | 10 | import asyncio
|
| 11 | +import concurrent |
11 | 12 | import os
|
12 | 13 | import time
|
13 | 14 | import unittest
|
14 | 15 | from dataclasses import asdict
|
15 |
| -from typing import Dict, List, Mapping, Union |
| 16 | +from typing import Dict, List, Mapping, Tuple, Union |
16 | 17 | from unittest.mock import MagicMock
|
17 | 18 |
|
18 | 19 | import torchx.specs.named_resources_aws as named_resources_aws
|
@@ -299,6 +300,33 @@ async def update(value: str, time_seconds: int) -> str:
|
299 | 300 | self.assertEqual("base", default.image)
|
300 | 301 | self.assertEqual("nentry", default.entrypoint)
|
301 | 302 |
|
| 303 | + def test_concurrent_override_role(self) -> None: |
| 304 | + |
| 305 | + def delay(value: Tuple[str, str], time_seconds: int) -> Tuple[str, str]: |
| 306 | + time.sleep(time_seconds) |
| 307 | + return value |
| 308 | + |
| 309 | + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: |
| 310 | + launcher_fbpkg_future: concurrent.futures.Future = executor.submit( |
| 311 | + delay, ("value1", "value2"), 2 |
| 312 | + ) |
| 313 | + |
| 314 | + def get_image() -> str: |
| 315 | + concurrent.futures.wait([launcher_fbpkg_future], 3) |
| 316 | + return launcher_fbpkg_future.result()[0] |
| 317 | + |
| 318 | + def get_entrypoint() -> str: |
| 319 | + concurrent.futures.wait([launcher_fbpkg_future], 3) |
| 320 | + return launcher_fbpkg_future.result()[1] |
| 321 | + |
| 322 | + default = Role( |
| 323 | + "foobar", |
| 324 | + "torch", |
| 325 | + overrides={"image": get_image, "entrypoint": get_entrypoint}, |
| 326 | + ) |
| 327 | + self.assertEqual("value1", default.image) |
| 328 | + self.assertEqual("value2", default.entrypoint) |
| 329 | + |
302 | 330 |
|
303 | 331 | class AppHandleTest(unittest.TestCase):
|
304 | 332 | def test_parse_malformed_app_handles(self) -> None:
|
|
0 commit comments