@@ -223,6 +223,55 @@ def not_exists_or_empty(folder):
223
223
)
224
224
225
225
226
+ def get_cuda_version_from_nvcc ():
227
+ """Get CUDA version from nvcc if available."""
228
+ try :
229
+ result = subprocess .check_output (
230
+ ["nvcc" , "--version" ], stderr = subprocess .STDOUT
231
+ )
232
+ output = result .decode ("utf-8" )
233
+ # Look for version line like "release 12.6"
234
+ for line in output .split ("\n " ):
235
+ if "release" in line .lower ():
236
+ parts = line .split ()
237
+ for i , part in enumerate (parts ):
238
+ if part .lower () == "release" and i + 1 < len (parts ):
239
+ return parts [i + 1 ].rstrip ("," )
240
+
241
+ except :
242
+ return None
243
+
244
+
245
+ def get_cutlass_build_flags ():
246
+ """Determine which CUTLASS kernels to build based on CUDA version.
247
+ SM90a: CUDA 12.6+, SM100a: CUDA 12.8+
248
+ """
249
+ # Try nvcc then torch version
250
+ cuda_version = get_cuda_version_from_nvcc () or torch .version .cuda
251
+
252
+ try :
253
+ if not cuda_version :
254
+ raise ValueError ("No CUDA version found" )
255
+
256
+ major , minor = map (int , cuda_version .split ("." )[:2 ])
257
+ build_sm90a = major > 12 or (major == 12 and minor >= 6 )
258
+ build_sm100a = major > 12 or (major == 12 and minor >= 8 )
259
+
260
+ if build_sm90a :
261
+ print (f"CUDA { cuda_version } : Enabling SM90a CUTLASS kernels" )
262
+ if build_sm100a :
263
+ print (f"CUDA { cuda_version } : Enabling SM100a CUTLASS kernels" )
264
+
265
+ return build_sm90a , build_sm100a
266
+ except :
267
+ # Fallback to architecture flags
268
+ cuda_arch_flags = _get_cuda_arch_flags ()
269
+ return (
270
+ "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags ,
271
+ "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags ,
272
+ )
273
+
274
+
226
275
# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
227
276
class TorchAOBuildExt (BuildExtension ):
228
277
def __init__ (self , * args , ** kwargs ) -> None :
@@ -455,9 +504,7 @@ def get_extensions():
455
504
]
456
505
)
457
506
458
- cuda_arch_flags = _get_cuda_arch_flags ()
459
- build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
460
- build_for_sm100a = "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags
507
+ build_for_sm90a , build_for_sm100a = get_cutlass_build_flags ()
461
508
# Define sm90a sources
462
509
cutlass_90a_sources = [
463
510
os .path .join (
0 commit comments