@@ -295,6 +295,39 @@ def get_extensions():
    use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
    extension = CUDAExtension if use_cuda else CppExtension

+    # =====================================================================================
+    # CUDA Architecture Settings
+    # =====================================================================================
+    # If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
+    # detect architectures from available GPUs. This can fail when:
+    # 1. No GPU is visible to PyTorch (e.g. a headless or CPU-only build machine)
+    # 2. CUDA is available but no device is detected at build time
+    #
+    # To resolve this, you can manually set CUDA architecture targets:
+    #   export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
+    #
+    # Adding "+PTX" to the last architecture enables JIT compilation for future GPUs.
+    # =====================================================================================
+    if use_cuda and "TORCH_CUDA_ARCH_LIST" not in os.environ and torch.version.cuda:
+        # Set to common architectures for CUDA 12.x compatibility
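+        # (7.0=Volta/V100, 7.5=Turing, 8.0=A100, 8.6=Ampere RTX 30xx, 8.9=Ada, 9.0=Hopper)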
+        cuda_arch_list = "7.0;7.5;8.0;8.6;8.9;9.0"
+
+        # Only add SM10.0 (Blackwell) flags when using CUDA 12.8 or newer
+        cuda_version = torch.version.cuda
+        if cuda_version and tuple(map(int, cuda_version.split(".")[:2])) >= (12, 8):
+            print("Detected CUDA 12.8+ - adding SM10.0 architectures to build list")
+            cuda_arch_list += ";10.0"
+
+        # Add PTX to the last architecture for future compatibility
+        cuda_arch_list += "+PTX"
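+        # e.g. "7.0;7.5;8.0;8.6;8.9;9.0+PTX", or "...;9.0;10.0+PTX" on CUDA 12.8+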
+
+        os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch_list
+        print(
+            f"Setting default TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}"
+        )
+
    extra_link_args = []
    extra_compile_args = {
        "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],