
Commit 4b2746d

Manually specify flags if no arch set
stack-info: PR: #2219, branch: drisspg/stack/55
1 parent adc78b7 commit 4b2746d

2 files changed: +31 -0 lines changed

CUDA_ARCH_NOTES.md

Whitespace-only changes.

setup.py

Lines changed: 31 additions & 0 deletions
@@ -295,6 +295,37 @@ def get_extensions():
     use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
     extension = CUDAExtension if use_cuda else CppExtension
 
+    # =====================================================================================
+    # CUDA Architecture Settings
+    # =====================================================================================
+    # If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
+    # detect architectures from available GPUs. This can fail when:
+    # 1. No GPU is visible to PyTorch
+    # 2. CUDA is available but no device is detected
+    #
+    # To resolve this, you can manually set CUDA architecture targets:
+    # export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
+    #
+    # Adding "+PTX" to the last architecture enables JIT compilation for future GPUs.
+    # =====================================================================================
+    if use_cuda and "TORCH_CUDA_ARCH_LIST" not in os.environ and torch.version.cuda:
+        # Set to common architectures for CUDA 12.x compatibility
+        cuda_arch_list = "7.0;7.5;8.0;8.6;8.9;9.0"
+
+        # Only add SM10.0 (Blackwell) flags when using CUDA 12.8 or newer
+        cuda_version = torch.version.cuda
+        if cuda_version and cuda_version.startswith("12.8"):
+            print("Detected CUDA 12.8 - adding SM10.0 architectures to build list")
+            cuda_arch_list += ";10.0"
+
+        # Add PTX to the last architecture for future compatibility
+        cuda_arch_list += "+PTX"
+
+        os.environ["TORCH_CUDA_ARCH_LIST"] = cuda_arch_list
+        print(
+            f"Setting default TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}"
+        )
+
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],

0 commit comments
