Skip to content

Commit 919480d

Browse files
committed
Fix Bug in MX Builds
stack-info: PR: #2284, branch: drisspg/stack/62
1 parent e51ffd9 commit 919480d

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

setup.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,55 @@ def not_exists_or_empty(folder):
223223
)
224224

225225

226+
def get_cuda_version_from_nvcc():
    """Return the CUDA release reported by ``nvcc --version``, or None.

    Runs ``nvcc --version`` and scans its output for the token that follows
    the word "release" (for example, "release 12.6," yields "12.6").

    Returns:
        The version string with any trailing comma stripped, or None when
        nvcc cannot be run, exits with a non-zero status, or prints no
        "release" token.
    """
    try:
        result = subprocess.check_output(
            ["nvcc", "--version"], stderr=subprocess.STDOUT
        )
    except (OSError, subprocess.CalledProcessError):
        # nvcc is missing, not executable, or exited with an error. Only these
        # failures are swallowed so that KeyboardInterrupt/SystemExit still
        # propagate instead of being hidden by a bare ``except``.
        return None

    # Undecodable bytes must not abort detection, so replace them instead of
    # raising UnicodeDecodeError.
    output = result.decode("utf-8", errors="replace")

    # Look for a version line like "... release 12.6, V12.6.20".
    for line in output.splitlines():
        parts = line.split()
        # Stop one token early so that parts[i + 1] always exists.
        for i, part in enumerate(parts[:-1]):
            if part.lower() == "release":
                return parts[i + 1].rstrip(",")
    return None
243+
244+
245+
def get_cutlass_build_flags():
    """Determine which CUTLASS kernels to build based on CUDA version.

    SM90a: CUDA 12.6+, SM100a: CUDA 12.8+

    The CUDA version is taken from ``get_cuda_version_from_nvcc()`` and, when
    that yields nothing, from ``torch.version.cuda``. If no version is
    available or it cannot be parsed as ``<major>.<minor>``, the decision
    falls back to checking the flags returned by ``_get_cuda_arch_flags()``.

    Returns:
        A ``(build_sm90a, build_sm100a)`` tuple of booleans.
    """
    # Try nvcc then torch version
    cuda_version = get_cuda_version_from_nvcc() or torch.version.cuda

    try:
        if not cuda_version:
            raise ValueError("No CUDA version found")
        # A version with fewer than two components fails the unpacking with
        # ValueError; non-numeric components fail int() with ValueError.
        major, minor = (int(piece) for piece in cuda_version.split(".")[:2])
    except (ValueError, TypeError, AttributeError):
        # Fallback to architecture flags. Only version-parsing failures are
        # caught here, so unrelated errors and KeyboardInterrupt are no
        # longer masked by a bare ``except``.
        cuda_arch_flags = _get_cuda_arch_flags()
        return (
            "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags,
            "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags,
        )

    build_sm90a = major > 12 or (major == 12 and minor >= 6)
    build_sm100a = major > 12 or (major == 12 and minor >= 8)

    if build_sm90a:
        print(f"CUDA {cuda_version}: Enabling SM90a CUTLASS kernels")
    if build_sm100a:
        print(f"CUDA {cuda_version}: Enabling SM100a CUTLASS kernels")

    return build_sm90a, build_sm100a
273+
274+
226275
# BuildExtension is a subclass of setuptools.command.build_ext.build_ext
227276
class TorchAOBuildExt(BuildExtension):
228277
def __init__(self, *args, **kwargs) -> None:
@@ -455,9 +504,7 @@ def get_extensions():
455504
]
456505
)
457506

458-
cuda_arch_flags = _get_cuda_arch_flags()
459-
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
460-
build_for_sm100a = "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags
507+
build_for_sm90a, build_for_sm100a = get_cutlass_build_flags()
461508
# Define sm90a sources
462509
cutlass_90a_sources = [
463510
os.path.join(

torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
// This source code is licensed under the BSD 3-Clause license found in the
55
// LICENSE file in the root directory of this source tree.
66

7-
// Ensure this file is only compiled with sm100a architecture
8-
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
9-
#error "This file must be compiled with compute capability 10.0a or higher (Blackwell architecture)"
10-
#endif
117
#include <torch/library.h>
128

139
#include <ATen/ATen.h>

0 commit comments

Comments
 (0)