Skip to content

Commit 8ccf325

Browse files
Publish jetson wheel to pytorch nightly index (#3550)
1 parent de175e7 commit 8ccf325

File tree

3 files changed

+83
-33
lines changed

3 files changed

+83
-33
lines changed

.github/scripts/filter-matrix.py

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import os
66
import sys
7-
from typing import List
7+
from typing import Any, Dict, List
88

99
# currently we don't support python 3.13t due to tensorrt does not support 3.13t
1010
disabled_python_versions: List[str] = ["3.13t"]
@@ -17,6 +17,56 @@
1717
sbsa_container_image: str = "quay.io/pypa/manylinux_2_34_aarch64"
1818

1919

20+
def validate_matrix(matrix_dict: Dict[str, Any]) -> None:
21+
"""Validate the structure of the input matrix."""
22+
if not isinstance(matrix_dict, dict):
23+
raise ValueError("Matrix must be a dictionary")
24+
if "include" not in matrix_dict:
25+
raise ValueError("Matrix must contain 'include' key")
26+
if not isinstance(matrix_dict["include"], list):
27+
raise ValueError("Matrix 'include' must be a list")
28+
29+
30+
def filter_matrix_item(
31+
item: Dict[str, Any], is_jetpack: bool, limit_pr_builds: bool, is_nightly: bool
32+
) -> bool:
33+
"""Filter a single matrix item based on the build type and requirements."""
34+
if item["python_version"] in disabled_python_versions:
35+
# Skipping disabled Python version
36+
return False
37+
38+
if is_jetpack:
39+
if limit_pr_builds:
40+
# pr build,matrix passed from test-infra is cu128, python 3.9, change to cu126, python 3.10
41+
item["desired_cuda"] = "cu126"
42+
item["python_version"] = "3.10"
43+
item["container_image"] = jetpack_container_image
44+
return True
45+
elif is_nightly:
46+
# nightly build, matrix passed from test-infra is cu128, all python versions, change to cu126, python 3.10
47+
if item["python_version"] in jetpack_python_versions:
48+
item["desired_cuda"] = "cu126"
49+
item["container_image"] = jetpack_container_image
50+
return True
51+
return False
52+
else:
53+
if (
54+
item["python_version"] in jetpack_python_versions
55+
and item["desired_cuda"] in jetpack_cuda_versions
56+
):
57+
item["container_image"] = jetpack_container_image
58+
return True
59+
return False
60+
else:
61+
if item["gpu_arch_type"] == "cuda-aarch64":
62+
# pytorch image:pytorch/manylinuxaarch64-builder:cuda12.8 comes with glibc2.28
63+
# however, TensorRT requires glibc2.31 on aarch64 platform
64+
# TODO: in future, if pytorch supports aarch64 with glibc2.31, we should switch to use the pytorch image
65+
item["container_image"] = sbsa_container_image
66+
return True
67+
return True
68+
69+
2070
def main(args: list[str]) -> None:
2171
parser = argparse.ArgumentParser()
2272
parser.add_argument(
@@ -42,41 +92,39 @@ def main(args: list[str]) -> None:
4292
default=os.getenv("LIMIT_PR_BUILDS", "false"),
4393
)
4494

95+
parser.add_argument(
96+
"--is-nightly",
97+
help="If it is a nightly build",
98+
type=str,
99+
choices=["true", "false"],
100+
default=os.getenv("LIMIT_PR_BUILDS", "false"),
101+
)
102+
45103
options = parser.parse_args(args)
46104
if options.matrix == "":
47-
raise Exception("--matrix needs to be provided")
105+
raise ValueError("--matrix needs to be provided")
106+
107+
try:
108+
matrix_dict = json.loads(options.matrix)
109+
validate_matrix(matrix_dict)
110+
except json.JSONDecodeError as e:
111+
raise ValueError(f"Invalid JSON in matrix: {e}")
112+
except ValueError as e:
113+
raise ValueError(f"Invalid matrix structure: {e}")
48114

49-
matrix_dict = json.loads(options.matrix)
50115
includes = matrix_dict["include"]
51116
filtered_includes = []
117+
52118
for item in includes:
53-
if item["python_version"] in disabled_python_versions:
54-
continue
55-
if options.jetpack == "true":
56-
if options.limit_pr_builds == "true":
57-
# limit pr build, matrix passed in from test-infra is cu128, python 3.9, change to cu126, python 3.10
58-
item["desired_cuda"] = "cu126"
59-
item["python_version"] = "3.10"
60-
item["container_image"] = jetpack_container_image
61-
filtered_includes.append(item)
62-
else:
63-
if (
64-
item["python_version"] in jetpack_python_versions
65-
and item["desired_cuda"] in jetpack_cuda_versions
66-
):
67-
item["container_image"] = jetpack_container_image
68-
filtered_includes.append(item)
69-
else:
70-
if item["gpu_arch_type"] == "cuda-aarch64":
71-
# pytorch image:pytorch/manylinuxaarch64-builder:cuda12.8 comes with glibc2.28
72-
# however, TensorRT requires glibc2.31 on aarch64 platform
73-
# TODO: in future, if pytorch supports aarch64 with glibc2.31, we should switch to use the pytorch image
74-
item["container_image"] = sbsa_container_image
75-
filtered_includes.append(item)
76-
else:
77-
filtered_includes.append(item)
78-
filtered_matrix_dict = {}
79-
filtered_matrix_dict["include"] = filtered_includes
119+
if filter_matrix_item(
120+
item,
121+
options.jetpack == "true",
122+
options.limit_pr_builds == "true",
123+
options.is_nightly == "true",
124+
):
125+
filtered_includes.append(item)
126+
127+
filtered_matrix_dict = {"include": filtered_includes}
80128
print(json.dumps(filtered_matrix_dict))
81129

82130

.github/workflows/build-test-linux-aarch64-jetpack.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ jobs:
4040
id: filter
4141
env:
4242
LIMIT_PR_BUILDS: ${{ github.event_name == 'pull_request' && !contains( github.event.pull_request.labels.*.name, 'ciflow/binaries/all') }}
43+
NIGHTLY_BUILDS: ${{ inputs.trigger-event == 'push' && startsWith(github.event.ref, 'refs/heads/nightly') }}
4344
run: |
4445
set -eou pipefail
4546
echo "LIMIT_PR_BUILDS=${LIMIT_PR_BUILDS}"
47+
echo "NIGHTLY_BUILDS=${NIGHTLY_BUILDS}"
4648
MATRIX_BLOB=${{ toJSON(needs.generate-matrix.outputs.matrix) }}
47-
MATRIX_BLOB="$(python3 .github/scripts/filter-matrix.py --matrix "${MATRIX_BLOB}" --jetpack true)"
49+
MATRIX_BLOB="$(python3 .github/scripts/filter-matrix.py --matrix "${MATRIX_BLOB}" --jetpack true --limit-pr-builds "${LIMIT_PR_BUILDS}" --is-nightly "${NIGHTLY_BUILDS}")"
4850
echo "${MATRIX_BLOB}"
4951
echo "matrix=${MATRIX_BLOB}" >> "${GITHUB_OUTPUT}"
5052

.github/workflows/build_wheels_linux_aarch64.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ jobs:
332332
upload:
333333
needs: build
334334
uses: pytorch/test-infra/.github/workflows/_binary_upload.yml@main
335-
# only upload to pytorch index for non jetpack builds
336-
if: ${{ inputs.is-jetpack == false }}
335+
# for jetpack builds, only upload to pytorch index for nightly builds
336+
if: ${{ inputs.is-jetpack == false || (inputs.trigger-event == 'push' && startsWith(github.event.ref, 'refs/heads/nightly')) }}
337337
with:
338338
repository: ${{ inputs.repository }}
339339
ref: ${{ inputs.ref }}

0 commit comments

Comments
 (0)