|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +# pyre-strict |
| 9 | + |
| 10 | +""" |
| 11 | +Kubernetes integration tests. |
| 12 | +""" |
| 13 | + |
| 14 | +from component_integration_tests import build_and_push_image |
| 15 | + |
| 16 | +from integ_test_utils import getenv_asserts, MissingEnvError |
| 17 | +from torchx.components.dist import ddp as dist_ddp |
| 18 | +from torchx.runner import get_runner |
| 19 | +from torchx.specs import _named_resource_factories, AppState, Resource |
| 20 | +from torchx.util.types import none_throws |
| 21 | + |
| 22 | +GiB: int = 1024 |
| 23 | + |
| 24 | + |
| 25 | +def register_gpu_resource() -> None: |
| 26 | + res = Resource( |
| 27 | + cpu=2, |
| 28 | + gpu=1, |
| 29 | + memMB=8 * GiB, |
| 30 | + ) |
| 31 | + print(f"Registering resource: {res}") |
| 32 | + _named_resource_factories["GPU_X1"] = lambda: res |
| 33 | + |
| 34 | + |
| 35 | +def run_job() -> None: |
| 36 | + register_gpu_resource() |
| 37 | + build = build_and_push_image(container_repo=getenv_asserts("CONTAINER_REPO")) |
| 38 | + image = build.torchx_image |
| 39 | + runner = get_runner() |
| 40 | + train_app = dist_ddp( |
| 41 | + m="torchx.examples.apps.compute_world_size.main", |
| 42 | + name="ddp-trainer", |
| 43 | + image=image, |
| 44 | + cpu=1, |
| 45 | + j="2x2", |
| 46 | + max_retries=3, |
| 47 | + env={ |
| 48 | + "LOGLEVEL": "INFO", |
| 49 | + }, |
| 50 | + ) |
| 51 | + cfg = { |
| 52 | + "namespace": "torchx-dev", |
| 53 | + "queue": "default", |
| 54 | + } |
| 55 | + app_handle = runner.run(train_app, "kubernetes", cfg) |
| 56 | + print("Start waiting for app to finish") |
| 57 | + runner.wait(app_handle) |
| 58 | + final_status = runner.status(app_handle) |
| 59 | + print(f"Final status: {final_status}") |
| 60 | + if none_throws(final_status).state != AppState.SUCCEEDED: |
| 61 | + raise Exception(f"Dist app failed with status: {final_status}") |
| 62 | + |
| 63 | + |
| 64 | +def main() -> None: |
| 65 | + try: |
| 66 | + run_job() |
| 67 | + except MissingEnvError: |
| 68 | + print("Skip running tests, executed only docker buid step") |
| 69 | + |
| 70 | + |
| 71 | +if __name__ == "__main__": |
| 72 | + main() |
0 commit comments