4
4
import json
5
5
import os
6
6
import sys
7
- from typing import List
7
+ from typing import Any , Dict , List
8
8
9
9
# currently we don't support python 3.13t due to tensorrt does not support 3.13t
10
10
disabled_python_versions : List [str ] = ["3.13t" ]
17
17
sbsa_container_image : str = "quay.io/pypa/manylinux_2_34_aarch64"
18
18
19
19
20
+ def validate_matrix (matrix_dict : Dict [str , Any ]) -> None :
21
+ """Validate the structure of the input matrix."""
22
+ if not isinstance (matrix_dict , dict ):
23
+ raise ValueError ("Matrix must be a dictionary" )
24
+ if "include" not in matrix_dict :
25
+ raise ValueError ("Matrix must contain 'include' key" )
26
+ if not isinstance (matrix_dict ["include" ], list ):
27
+ raise ValueError ("Matrix 'include' must be a list" )
28
+
29
+
30
+ def filter_matrix_item (
31
+ item : Dict [str , Any ], is_jetpack : bool , limit_pr_builds : bool , is_nightly : bool
32
+ ) -> bool :
33
+ """Filter a single matrix item based on the build type and requirements."""
34
+ if item ["python_version" ] in disabled_python_versions :
35
+ # Skipping disabled Python version
36
+ return False
37
+
38
+ if is_jetpack :
39
+ if limit_pr_builds :
40
+ # pr build,matrix passed from test-infra is cu128, python 3.9, change to cu126, python 3.10
41
+ item ["desired_cuda" ] = "cu126"
42
+ item ["python_version" ] = "3.10"
43
+ item ["container_image" ] = jetpack_container_image
44
+ return True
45
+ elif is_nightly :
46
+ # nightly build, matrix passed from test-infra is cu128, all python versions, change to cu126, python 3.10
47
+ if item ["python_version" ] in jetpack_python_versions :
48
+ item ["desired_cuda" ] = "cu126"
49
+ item ["container_image" ] = jetpack_container_image
50
+ return True
51
+ return False
52
+ else :
53
+ if (
54
+ item ["python_version" ] in jetpack_python_versions
55
+ and item ["desired_cuda" ] in jetpack_cuda_versions
56
+ ):
57
+ item ["container_image" ] = jetpack_container_image
58
+ return True
59
+ return False
60
+ else :
61
+ if item ["gpu_arch_type" ] == "cuda-aarch64" :
62
+ # pytorch image:pytorch/manylinuxaarch64-builder:cuda12.8 comes with glibc2.28
63
+ # however, TensorRT requires glibc2.31 on aarch64 platform
64
+ # TODO: in future, if pytorch supports aarch64 with glibc2.31, we should switch to use the pytorch image
65
+ item ["container_image" ] = sbsa_container_image
66
+ return True
67
+ return True
68
+
69
+
20
70
def main (args : list [str ]) -> None :
21
71
parser = argparse .ArgumentParser ()
22
72
parser .add_argument (
@@ -42,41 +92,39 @@ def main(args: list[str]) -> None:
42
92
default = os .getenv ("LIMIT_PR_BUILDS" , "false" ),
43
93
)
44
94
95
+ parser .add_argument (
96
+ "--is-nightly" ,
97
+ help = "If it is a nightly build" ,
98
+ type = str ,
99
+ choices = ["true" , "false" ],
100
+ default = os .getenv ("LIMIT_PR_BUILDS" , "false" ),
101
+ )
102
+
45
103
options = parser .parse_args (args )
46
104
if options .matrix == "" :
47
- raise Exception ("--matrix needs to be provided" )
105
+ raise ValueError ("--matrix needs to be provided" )
106
+
107
+ try :
108
+ matrix_dict = json .loads (options .matrix )
109
+ validate_matrix (matrix_dict )
110
+ except json .JSONDecodeError as e :
111
+ raise ValueError (f"Invalid JSON in matrix: { e } " )
112
+ except ValueError as e :
113
+ raise ValueError (f"Invalid matrix structure: { e } " )
48
114
49
- matrix_dict = json .loads (options .matrix )
50
115
includes = matrix_dict ["include" ]
51
116
filtered_includes = []
117
+
52
118
for item in includes :
53
- if item ["python_version" ] in disabled_python_versions :
54
- continue
55
- if options .jetpack == "true" :
56
- if options .limit_pr_builds == "true" :
57
- # limit pr build, matrix passed in from test-infra is cu128, python 3.9, change to cu126, python 3.10
58
- item ["desired_cuda" ] = "cu126"
59
- item ["python_version" ] = "3.10"
60
- item ["container_image" ] = jetpack_container_image
61
- filtered_includes .append (item )
62
- else :
63
- if (
64
- item ["python_version" ] in jetpack_python_versions
65
- and item ["desired_cuda" ] in jetpack_cuda_versions
66
- ):
67
- item ["container_image" ] = jetpack_container_image
68
- filtered_includes .append (item )
69
- else :
70
- if item ["gpu_arch_type" ] == "cuda-aarch64" :
71
- # pytorch image:pytorch/manylinuxaarch64-builder:cuda12.8 comes with glibc2.28
72
- # however, TensorRT requires glibc2.31 on aarch64 platform
73
- # TODO: in future, if pytorch supports aarch64 with glibc2.31, we should switch to use the pytorch image
74
- item ["container_image" ] = sbsa_container_image
75
- filtered_includes .append (item )
76
- else :
77
- filtered_includes .append (item )
78
- filtered_matrix_dict = {}
79
- filtered_matrix_dict ["include" ] = filtered_includes
119
+ if filter_matrix_item (
120
+ item ,
121
+ options .jetpack == "true" ,
122
+ options .limit_pr_builds == "true" ,
123
+ options .is_nightly == "true" ,
124
+ ):
125
+ filtered_includes .append (item )
126
+
127
+ filtered_matrix_dict = {"include" : filtered_includes }
80
128
print (json .dumps (filtered_matrix_dict ))
81
129
82
130
0 commit comments