diff --git a/setup.py b/setup.py index f59917162f..0eb85f8362 100644 --- a/setup.py +++ b/setup.py @@ -364,6 +364,7 @@ def get_extensions(): use_cutlass = False cutlass_90a_sources = None + cutlass_100a_sources = None if use_cuda and not IS_ROCM and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") @@ -395,6 +396,8 @@ def get_extensions(): cuda_arch_flags = _get_cuda_arch_flags() build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags + build_for_sm100 = "-gencode=arch=compute_100,code=sm_100" in cuda_arch_flags + build_for_sm100a = "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags if build_for_sm90 and not build_for_sm90a: cutlass_90a_sources = [ os.path.join( @@ -418,6 +421,17 @@ def get_extensions(): ) ) sources = [s for s in sources if s not in cutlass_90a_sources] + + if build_for_sm100 and not build_for_sm100a: + cutlass_100a_sources = [ + os.path.join( + extensions_cuda_dir, + "mx_kernels", + "mx_fp_cutlass_kernels.cu", + ), + ] + sources = [s for s in sources if s not in cutlass_100a_sources] + else: # Remove CUTLASS-based kernels from the sources list. An # assumption is that these files will have "cutlass" in its @@ -448,7 +462,7 @@ def get_extensions(): ) ext_modules.append( extension( - "torchao._C", + "torchao._C_cutlass_90a", cutlass_90a_sources, py_limited_api=True, extra_compile_args=cutlass_90a_extra_compile_args, @@ -456,6 +470,21 @@ def get_extensions(): ) ) + if cutlass_100a_sources is not None and len(cutlass_100a_sources) > 0: + cutlass_100a_extra_compile_args = copy.deepcopy(extra_compile_args) + cutlass_100a_extra_compile_args["nvcc"].extend( + cuda_arch_flags + ["-gencode=arch=compute_100a,code=sm_100a"] + ) + ext_modules.append( + extension( + "torchao._C_cutlass_100a", + cutlass_100a_sources, + py_limited_api=True, + extra_compile_args=cutlass_100a_extra_compile_args, + extra_link_args=extra_link_args, + ) + ) + # Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1": build_options = BuildOptions() diff --git a/torchao/__init__.py b/torchao/__init__.py index 7cc447d5a7..9c86f78441 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -25,8 +25,8 @@ so_files = list(Path(__file__).parent.glob("_C*.so")) if len(so_files) > 0: - assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" - torch.ops.load_library(str(so_files[0])) + for file in so_files: + torch.ops.load_library(str(file)) from . import ops # The following library contains CPU kernels from torchao/experimental