Skip to content

Manually specify flags if no arch set #2219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added CUDA_ARCH_NOTES.md
Empty file.
31 changes: 31 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,37 @@ def get_extensions():
use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
extension = CUDAExtension if use_cuda else CppExtension

# =====================================================================================
# CUDA Architecture Settings
# =====================================================================================
# If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
# detect architectures from available GPUs. This can fail when:
# 1. No GPU is visible to PyTorch
# 2. CUDA is available but no device is detected
#
# To resolve this, you can manually set CUDA architecture targets:
# export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
#
# Adding "+PTX" to the last architecture enables JIT compilation for future GPUs.
# =====================================================================================
if use_cuda and "TORCH_CUDA_ARCH_LIST" not in os.environ and torch.version.cuda:
# Set to common architectures for CUDA 12.x compatibility
cuda_arch_list = "7.0;7.5;8.0;8.6;8.9;9.0"

# Only add SM10.0 (Blackwell) flags when using CUDA 12.8 or newer
cuda_version = torch.version.cuda
if cuda_version and cuda_version.startswith("12.8"):
print("Detected CUDA 12.8 - adding SM10.0 architectures to build list")
cuda_arch_list += ";10.0"

# Add PTX to the last architecture for future compatibility
cuda_arch_list += "+PTX"

os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch_list
print(
f"Setting default TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}"
)

extra_link_args = []
extra_compile_args = {
"cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
Expand Down
Loading