diff --git a/CMakeLists.txt b/CMakeLists.txt
index ffb801d6261..067ff7d5322 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -230,6 +230,8 @@ endif()
 
 #
 set(VLLM_EXT_SRC
+  "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
+  "csrc/mamba/causal_conv1d/causal_conv1d.cu"
   "csrc/cache_kernels.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
@@ -285,8 +287,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_MakeAvailable(cutlass)
 
   list(APPEND VLLM_EXT_SRC
-    "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
-    "csrc/mamba/causal_conv1d/causal_conv1d.cu"
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/permute_cols.cu"
diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
index 98daf1a1b8e..f62d08c17c6 100644
--- a/csrc/mamba/causal_conv1d/causal_conv1d.cu
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -13,6 +13,10 @@
 #include <cub/block/block_load.cuh>
 #include <cub/block/block_store.cuh>
 
+#ifdef USE_ROCM
+  namespace cub = hipcub;
+#endif
+
 #include "static_switch.h"
 
 
@@ -501,15 +505,9 @@ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
         auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
 
         if (kSmemSize >= 48 * 1024) {
-            #ifndef USE_ROCM
-            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-            #else
-            // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
             C10_CUDA_CHECK(cudaFuncSetAttribute(
                 (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
             std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
-            #endif
         }
 
         kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index bd0a34119c8..0c9df925bdb 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
         auto kernel = &selective_scan_fwd_kernel<Ktraits>;
         if (kSmemSize >= 48 * 1024) {
             C10_CUDA_CHECK(cudaFuncSetAttribute(
-                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
         }
         kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 4eda1aaccc6..371894c56a7 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -482,41 +482,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor page_table, float scale) -> ()");
   ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
 
-  // Mamba selective scan kernel
-  ops.def(
-      "selective_scan_fwd(Tensor! u, Tensor! delta,"
-      "Tensor! A, Tensor! B, Tensor! C,"
-      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
-      "bool delta_softplus,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "Tensor! ssm_states,"
-      "int pad_slot_id) -> ()");
-  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
-
-  ops.def(
-      "causal_conv1d_update(Tensor! x,"
-      "Tensor! conv_state,"
-      "Tensor! weight,"
-      "Tensor? bias_,"
-      "bool silu_activation,"
-      "Tensor? cache_seqlens_,"
-      "Tensor? conv_state_indices,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
-
-  ops.def(
-      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
-      "Tensor? bias_,"
-      "Tensor!? conv_states,"
-      "Tensor? query_start_loc,"
-      "Tensor? cache_indices,"
-      "Tensor? has_initial_state,"
-      "bool silu_activation,"
-      "int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-
   // Compute NVFP4 block quantized tensor.
   ops.def(
       "scaled_fp4_quant(Tensor! output, Tensor input,"
@@ -584,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+  // Mamba selective scan kernel
+  ops.def(
+      "selective_scan_fwd(Tensor! u, Tensor! delta,"
+      "Tensor! A, Tensor! B, Tensor! C,"
+      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
+      "bool delta_softplus,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "Tensor! ssm_states,"
+      "int pad_slot_id) -> ()");
+  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
+
+  ops.def(
+      "causal_conv1d_update(Tensor! x,"
+      "Tensor! conv_state,"
+      "Tensor! weight,"
+      "Tensor? bias_,"
+      "bool silu_activation,"
+      "Tensor? cache_seqlens_,"
+      "Tensor? conv_state_indices,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
+  ops.def(
+      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
+      "Tensor? bias_,"
+      "Tensor!? conv_states,"
+      "Tensor? query_start_loc,"
+      "Tensor? cache_indices,"
+      "Tensor? has_initial_state,"
+      "bool silu_activation,"
+      "int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
 #ifndef USE_ROCM
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index e5b88de2fcc..019f634a9ef 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -5,10 +5,9 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.attention.backends.placeholder_attn import (
     PlaceholderAttentionMetadata)
-from vllm.attention.backends.xformers import XFormersMetadata
+from vllm.platforms import current_platform
 
 
 @dataclass
@@ -23,6 +22,21 @@ class Mamba2Metadata:
     chunk_offsets: torch.Tensor
 
 
+def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
+    """Returns the appropriate metadata classes for the current platform."""
+    if current_platform.is_rocm():
+        from vllm.attention.backends.rocm_flash_attn import (
+            ROCmFlashAttentionMetadata)
+        return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
+    elif current_platform.is_cuda():
+        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+        from vllm.attention.backends.xformers import XFormersMetadata
+        return (FlashAttentionMetadata, XFormersMetadata,
+                PlaceholderAttentionMetadata)
+    raise ValueError(
+        f"Unsupported platform for Mamba2: {current_platform.device_type}")
+
+
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
                                               total_seqlens: int):
@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(
 
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
     if num_prefills > 0:
-        if (isinstance(attn_metadata,
-                       (FlashAttentionMetadata, XFormersMetadata,
-                        PlaceholderAttentionMetadata))
+        attn_metadata_instances = get_platform_metadata_classes()
+        if (isinstance(attn_metadata, attn_metadata_instances)
                 and attn_metadata.context_lens_tensor is not None):
             has_initial_states = \
                 attn_metadata.context_lens_tensor[:num_prefills] > 0  #[batch,]
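
The platform dispatch added in mamba2_metadata.py and the has_initial_states computation it guards can be exercised in isolation. The snippet below is a minimal sketch, not part of the change itself; it assumes a working vLLM installation (so that vllm.platforms and the relevant attention backends import cleanly) and uses made-up context lengths purely for illustration.

import torch

from vllm.model_executor.layers.mamba.mamba2_metadata import (
    get_platform_metadata_classes)

# The helper imports the ROCm- or CUDA-specific attention metadata classes
# lazily, so CUDA-only backends (flash_attn, xformers) are never imported on
# ROCm builds.
metadata_classes = get_platform_metadata_classes()
print(metadata_classes)

# prepare_mamba2_metadata() uses the returned tuple in an isinstance() check
# before reading context_lens_tensor. The guarded computation, distilled:
# prefill requests with a non-zero context length already hold cached state.
context_lens_tensor = torch.tensor([0, 5, 3])  # illustrative values only
num_prefills = 3
has_initial_states = context_lens_tensor[:num_prefills] > 0
print(has_initial_states)  # tensor([False,  True,  True])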
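Because torch_bindings.cpp now registers the Mamba ops outside the CUDA-only source list, the same schemas should be visible through torch.ops on both CUDA and ROCm builds. A rough smoke test is sketched below; it assumes the extension is exposed as torch.ops._C (the usual expansion of TORCH_EXTENSION_NAME in vLLM) and that importing vllm._custom_ops loads it.

import torch
import vllm._custom_ops  # noqa: F401  (loads the compiled _C extension)

# These op names come straight from the ops.def(...) schemas in the diff.
for op_name in ("selective_scan_fwd", "causal_conv1d_fwd",
                "causal_conv1d_update"):
    assert hasattr(torch.ops._C, op_name), f"{op_name} is not registered"
    print(op_name, "registered")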