Skip to content

feat(rocm-support): support mamba2 on rocm #18565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ endif()
#

set(VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/cache_kernels.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
Expand Down Expand Up @@ -285,8 +287,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_MakeAvailable(cutlass)

list(APPEND VLLM_EXT_SRC
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/permute_cols.cu"
Expand Down
10 changes: 4 additions & 6 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#ifdef USE_ROCM
namespace cub = hipcub;
#endif

#include "static_switch.h"


Expand Down Expand Up @@ -501,15 +505,9 @@ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

if (kSmemSize >= 48 * 1024) {
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);

Expand Down
2 changes: 1 addition & 1 deletion csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand Down
70 changes: 35 additions & 35 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,41 +482,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor page_table, float scale) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);

// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);

// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input,"
Expand Down Expand Up @@ -584,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);

// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states,"
"int pad_slot_id) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation,"
"Tensor? cache_seqlens_,"
"Tensor? conv_state_indices,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

ops.def(
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation,"
"int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);

#ifndef USE_ROCM
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
Expand Down
23 changes: 18 additions & 5 deletions vllm/model_executor/layers/mamba/mamba2_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.platforms import current_platform


@dataclass
Expand All @@ -23,6 +22,21 @@ class Mamba2Metadata:
chunk_offsets: torch.Tensor


def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
"""Returns the appropriate metadata classes for the current platform."""
if current_platform.is_rocm():
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata)
return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
elif current_platform.is_cuda():
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
return (FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata)
raise ValueError(
f"Unsupported platform for Mamba2: {current_platform.device_type}")


def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
Expand Down Expand Up @@ -78,9 +92,8 @@ def prepare_mamba2_metadata(

# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None):
has_initial_states = \
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
Expand Down
Loading