@@ -291,6 +291,38 @@ def get_extensions():
     use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
     extension = CUDAExtension if use_cuda else CppExtension
 
+    # =====================================================================================
+    # CUDA Architecture Settings
+    # =====================================================================================
+    # If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
+    # detect architectures from the available GPUs. This can fail when:
+    #   1. No GPU is visible to PyTorch
+    #   2. CUDA is available but no device is detected
+    #
+    # To resolve this, you can set the CUDA architecture targets manually, e.g.:
+    #   export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
+    #
+    # Appending "+PTX" to the last architecture enables JIT compilation for future GPUs.
+    # =====================================================================================
+    if use_cuda and "TORCH_CUDA_ARCH_LIST" not in os.environ:
+        # Default to the common architectures supported by CUDA 12.x
+        cuda_arch_list = "7.0;7.5;8.0;8.6;8.9;9.0"
+
+        # Only add SM10.0 (Blackwell) when building with CUDA 12.8 or newer; compare
+        # parsed (major, minor) tuples, since startswith("12.8") would miss 12.9+.
+        cuda_version = torch.version.cuda
+        if cuda_version:
+            major, minor = (int(part) for part in cuda_version.split(".")[:2])
+            if (major, minor) >= (12, 8):
+                print("Detected CUDA 12.8+ - adding SM10.0 architecture to build list")
+                cuda_arch_list += ";10.0"
+
+        # Add PTX to the last architecture for forward compatibility
+        cuda_arch_list += "+PTX"
+
+        os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch_list
+        print(f"Setting default TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}")
+
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
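
For context on the auto-detection described in the comment block above, here is a minimal sketch of the same fallback idea using only public torch.cuda calls; the detect_cuda_arch_list helper name and its default fallback list are illustrative, not part of this commit:

import os

import torch


def detect_cuda_arch_list(fallback="7.0;7.5;8.0;8.6;8.9;9.0"):
    # Derive arch targets from the visible GPUs, the way PyTorch's build
    # tooling does; fall back to a fixed list when no device can be queried.
    if not torch.cuda.is_available():
        return fallback
    capabilities = set()
    for index in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(index)
        capabilities.add(f"{major}.{minor}")
    # Sort numerically so "+PTX" can be appended to the newest architecture.
    return ";".join(sorted(capabilities, key=lambda c: tuple(map(int, c.split(".")))))


if "TORCH_CUDA_ARCH_LIST" not in os.environ:
    os.environ["TORCH_CUDA_ARCH_LIST"] = detect_cuda_arch_list() + "+PTX"
    print(f"TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}")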
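
The version gate in the hunk compares parsed (major, minor) tuples rather than using startswith, because "12.9".startswith("12.8") is False even though a CUDA 12.9 toolchain can also target sm_100. A small illustration, assuming torch.version.cuda follows the usual "MAJOR.MINOR" form:

import torch

cuda_version = torch.version.cuda  # e.g. "12.9"; None on CPU-only or ROCm builds
if cuda_version:
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    if (major, minor) >= (12, 8):
        # nvcc understands sm_100 (Blackwell) from CUDA 12.8 onward.
        print("CUDA 12.8+ toolchain detected - safe to target SM 10.0")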