diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 010817e79a93..8d96a92b4c2b 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools +import importlib from collections.abc import Callable import torch import vllm.envs as envs +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.forward_context import get_forward_context from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata _FP8_DTYPE = current_platform.fp8_dtype() @@ -54,6 +59,188 @@ def wrapper(*args, **kwargs): return wrapper +@triton.jit +def _indexer_k_quant_and_cache_kernel( + k_ptr, # [num_tokens, head_dim] + kv_cache_ptr, # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + slot_mapping_ptr, # [num_tokens] + kv_cache_scale_stride, + kv_cache_value_stride, + block_size, + num_tokens, + head_dim: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, + IS_FNUZ: tl.constexpr, + USE_UE8M0: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, head_dim) + tile_offset = ( + offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + tile_store_offset = tile_offset + # for idx in tl.range(tid, num_tokens, n_program): + src_ptr = k_ptr + tid * head_dim + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + tile_block_id = block_offset // BLOCK_TILE_SIZE + tile_block_offset = block_offset % BLOCK_TILE_SIZE + val = tl.load(src_ptr + offset) + amax = tl.max(val.abs(), axis=-1).to(tl.float32) + if IS_FNUZ: + scale = tl.maximum(1e-4, amax) / 224.0 + else: + scale = tl.maximum(1e-4, amax) / 448.0 + + if USE_UE8M0: + scale = tl.exp2(tl.ceil(tl.log2(scale))) + + fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) + dst_ptr = ( + kv_cache_ptr + + block_id * kv_cache_value_stride + + tile_block_id * BLOCK_TILE_SIZE * head_dim + + tile_block_offset * HEAD_TILE_SIZE + ) + tl.store(dst_ptr + tile_store_offset, fp8_val) + dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset + tl.store(dst_scale_ptr, scale) + + +def indexer_k_quant_and_cache_triton( + k: torch.Tensor, + kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + slot_mapping: torch.Tensor, + quant_block_size, + scale_fmt, + block_tile_size=16, + head_tile_size=16, +): + num_blocks = kv_cache.shape[0] + head_dim = k.shape[-1] + num_tokens = slot_mapping.shape[0] + block_size = kv_cache.shape[1] + # In real layout, we store the first portion as kv cache value + # and second portion as kv cache scale + kv_cache = kv_cache.view(num_blocks, -1) + kv_cache_value = kv_cache[:, : block_size * head_dim] + kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) + head_tile_size = head_tile_size // kv_cache.element_size() + grid = (num_tokens,) + _indexer_k_quant_and_cache_kernel[grid]( + k, + kv_cache_value, + kv_cache_scale, + slot_mapping, + kv_cache_scale.stride(0), + kv_cache_value.stride(0), + block_size, + num_tokens, + head_dim, + block_tile_size, + head_tile_size, + IS_FNUZ=current_platform.fp8_dtype() == 
torch.float8_e4m3fnuz, + USE_UE8M0=scale_fmt == "ue8m0", + ) + + +@triton.jit +def _cp_gather_indexer_quant_cache_kernel( + kv_cache_ptr, # [n_blks,blk_size//tile_blk,head_dim//16B,tile_blk,16B] + kv_cache_scale_ptr, # [n_blks, blk_size] + k_fp8_ptr, # [num_tokens, head_dim] + k_scale_ptr, # [num_tokens] + block_table_ptr, # [batch_size, block_table_stride] + cu_seqlen_ptr, # [batch_size + 1] + token_to_seq_ptr, # [num_tokens] + block_size, + block_table_stride, + kv_cache_stride, + kv_cache_scale_stride, + HEAD_DIM: tl.constexpr, + BLOCK_TILE_SIZE: tl.constexpr, + HEAD_TILE_SIZE: tl.constexpr, +): + tid = tl.program_id(0) + offset = tl.arange(0, HEAD_DIM) + batch_id = tl.load(token_to_seq_ptr + tid) + batch_start = tl.load(cu_seqlen_ptr + batch_id) + batch_end = tl.load(cu_seqlen_ptr + batch_id + 1) + batch_offset = tid - batch_start + if tid >= batch_end: + return + block_table_id = batch_offset // block_size + block_offset = batch_offset % block_size + block_table_offset = batch_id * block_table_stride + block_table_id + block_id = tl.load(block_table_ptr + block_table_offset) + tiled_block_id = block_offset // BLOCK_TILE_SIZE + tiled_block_offset = block_offset % BLOCK_TILE_SIZE + src_cache_offset = ( + block_id * kv_cache_stride + + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE + + tiled_block_offset * HEAD_TILE_SIZE + ) + src_scale_offset = block_id * kv_cache_scale_stride + block_offset + dst_offset = tid * HEAD_DIM + src_scale_ptr = kv_cache_scale_ptr + src_scale_offset + src_cache_ptr = kv_cache_ptr + src_cache_offset + dst_k_ptr = k_fp8_ptr + dst_offset + scale_val = tl.load(src_scale_ptr) + tl.store(k_scale_ptr + tid, scale_val) + tiled_src_offset = ( + offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE + + offset % HEAD_TILE_SIZE + ) + val = tl.load(src_cache_ptr + tiled_src_offset) + tl.store(dst_k_ptr + offset, val) + + +def cp_gather_indexer_k_quant_cache_triton( + k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4] + k_fp8: torch.Tensor, + k_fp8_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seqlen: torch.Tensor, + token_to_seq: torch.Tensor, + block_tile_size: int = 16, + head_tile_size: int = 16, +): + num_tokens = k_fp8.size(0) + block_size = k_cache.size(1) + block_table_stride = block_table.stride(0) + head_dim = k_fp8.shape[-1] + num_blocks = k_cache.shape[0] + # we assume the kv cache already been split to 2 portion + k_cache = k_cache.view(num_blocks, -1) + fp8_dtype = current_platform.fp8_dtype() + k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype) + k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32) + grid = (num_tokens,) + k_fp8_scale = k_fp8_scale.view(torch.float32) + _cp_gather_indexer_quant_cache_kernel[grid]( + k_cache_value, + k_cache_scale, + k_fp8, + k_fp8_scale, + block_table, + cu_seqlen, + token_to_seq, + block_size, + block_table_stride, + k_cache_value.stride(0), + k_cache_scale.stride(0), + head_dim, + block_tile_size, + head_tile_size, + ) + + def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -637,6 +824,395 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake( return x_fp8, out_bs +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 +def fp8_paged_mqa_logits_torch( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +): + from vllm.utils.math_utils import cdiv + + fp8_dtype = current_platform.fp8_dtype() + 
batch_size, next_n, _, dim = q.size() + kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] + scale = scale.contiguous().view(torch.float) + q = q.float() + kv_cache = kv_cache.view(fp8_dtype).float() * scale + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") + weight_slice = ( + weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() + ) + for block_rk in range(cdiv(context_len, block_size)): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange( + block_rk * block_size, (block_rk + 1) * block_size, device="cuda" + ) + mask = (k_offsets[None, :] < context_len) & ( + k_offsets[None, :] <= q_offsets[:, None] + ) + s = torch.where( + mask[None, :, :], + (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( + logits.dtype + ), + float("-inf"), + ) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[ + i * next_n : (i + 1) * next_n, + block_rk * block_size : (block_rk + 1) * block_size, + ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) + return logits + + +def rocm_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 MQA logits using paged KV-cache. + + Args: + q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. + context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. + schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. 
+ """ + + if rocm_aiter_ops.is_enabled(): + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + + from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits + + out_logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device="cuda", + dtype=torch.float32, + ) + deepgemm_fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + out_logits, + context_lens, + block_tables, + max_model_len, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, + ) + return out_logits + else: + return fp8_paged_mqa_logits_torch( + q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len + ) + + +# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 +def fp8_mqa_logits_torch( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + kv, scale = kv + seq_len_kv = kv.shape[0] + k = kv.to(torch.bfloat16) + q = q.to(torch.bfloat16) + + mask_lo = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + ) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k).float() * scale + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + return logits + + +def rocm_fp8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. 
+ """ + + # TODO(ganyi): Temporarily workaround, will remove the module check and reference + # path after aiter merge this kernel into main + @functools.lru_cache + def has_mqa_logits_module(): + return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None + + if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): + from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits + + kv, scale = kv + return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) + else: + return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + + +def rocm_aiter_sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer + + +def rocm_aiter_sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # careful! 
this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + fp8_dtype = current_platform.fp8_dtype() + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return rocm_aiter_sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + indexer_k_quant_and_cache_triton( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=fp8_dtype, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 4], + device=k.device, + dtype=torch.uint8, + ) + cp_gather_indexer_k_quant_cache_triton( + kv_cache, + k_fp8, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + chunk.token_to_seq, + ) + + logits = rocm_fp8_mqa_logits( + q_fp8[chunk.token_start : chunk.token_end], + (k_fp8, k_scale.view(torch.float32)), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + + logits = rocm_fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if decode_metadata.requires_padding: + # if 
padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + + return topk_indices_buffer + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -862,6 +1438,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_sparse_attn_indexer", + op_func=rocm_aiter_sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=rocm_aiter_sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py deleted file mode 100644 index 080e92ecc940..000000000000 --- a/vllm/attention/ops/rocm_aiter_mla_sparse.py +++ /dev/null @@ -1,210 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib -from functools import lru_cache - -import torch - -from vllm._aiter_ops import rocm_aiter_ops -from vllm.logger import init_logger -from vllm.platforms import current_platform - -logger = init_logger(__name__) - - -# Take from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84 -def fp8_mqa_logits_torch( - q: torch.Tensor, - kv: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -) -> torch.Tensor: - """Compute FP8 MQA logits for a single sequence without KV paging. - - Args: - q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or - [N, 1]) with dtype `torch.float32`. - weights: weights of shape [M, H], dtype `torch.float32`. - cu_seqlen_ks: Start indices (inclusive) for valid K per query position, - shape [M], dtype int32. - cu_seqlen_ke: End indices (exclusive) for valid K per query position, - shape [M], dtype int32. - - Returns: - Logits tensor of shape [M, N], dtype `torch.float32`. - """ - kv, scale = kv - seq_len_kv = kv.shape[0] - k = kv.to(torch.bfloat16) - q = q.to(torch.bfloat16) - - mask_lo = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] - ) - mask_hi = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] - ) - mask = mask_lo & mask_hi - - score = torch.einsum("mhd,nd->hmn", q, k).float() * scale - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) - - return logits - - -def rocm_fp8_mqa_logits( - q: torch.Tensor, - kv: tuple[torch.Tensor, torch.Tensor], - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, -) -> torch.Tensor: - """Compute FP8 MQA logits for a single sequence without KV paging. - - Args: - q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or - [N, 1]) with dtype `torch.float32`. - weights: weights of shape [M, H], dtype `torch.float32`. - cu_seqlen_ks: Start indices (inclusive) for valid K per query position, - shape [M], dtype int32. 
- cu_seqlen_ke: End indices (exclusive) for valid K per query position, - shape [M], dtype int32. - - Returns: - Logits tensor of shape [M, N], dtype `torch.float32`. - """ - - # TODO(ganyi): Temporarily workaround, will remove the module check and reference - # path after aiter merge this kernel into main - @lru_cache - def has_mqa_logits_module(): - return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None - - if rocm_aiter_ops.is_enabled() and has_mqa_logits_module(): - from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits - - kv, scale = kv - return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke) - else: - return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) - - -# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156 -def fp8_paged_mqa_logits_torch( - q: torch.Tensor, - kv_cache: torch.Tensor, - weights: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - max_model_len: int, -): - from vllm.utils.math_utils import cdiv - - fp8_dtype = current_platform.fp8_dtype() - batch_size, next_n, _, dim = q.size() - kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] - scale = scale.contiguous().view(torch.float) - q = q.float() - kv_cache = kv_cache.view(fp8_dtype).float() * scale - num_block, block_size, _, dim = kv_cache.size() - logits = torch.full( - [batch_size * next_n, max_model_len], - float("-inf"), - device=q.device, - dtype=torch.float32, - ) - context_lens = context_lens.tolist() - for i in range(batch_size): - context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") - weight_slice = ( - weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() - ) - for block_rk in range(cdiv(context_len, block_size)): - block_idx = block_tables[i][block_rk] - qx, kx = q[i], kv_cache[block_idx] - k_offsets = torch.arange( - block_rk * block_size, (block_rk + 1) * block_size, device="cuda" - ) - mask = (k_offsets[None, :] < context_len) & ( - k_offsets[None, :] <= q_offsets[:, None] - ) - s = torch.where( - mask[None, :, :], - (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( - logits.dtype - ), - float("-inf"), - ) - s = torch.relu(s) * weight_slice[..., None] - s = s.sum(dim=0) - logits[ - i * next_n : (i + 1) * next_n, - block_rk * block_size : (block_rk + 1) * block_size, - ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) - return logits - - -def rocm_fp8_paged_mqa_logits( - q_fp8: torch.Tensor, - kv_cache_fp8: torch.Tensor, - weights: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - schedule_metadata: torch.Tensor, - max_model_len: int, -) -> torch.Tensor: - """Compute FP8 MQA logits using paged KV-cache. - - Args: - q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape - [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last - 4 bytes per (block,pos) store the `float` dequant scale. - weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. - context_lens: Tensor of shape [B], dtype int32; effective context length - for each batch element. - block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical - block indices to physical blocks in the paged cache. - schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; - used to distribute work across SMs. 
- max_model_len: Maximum sequence length used to size the logits output. - - Returns: - Logits tensor of shape [B * next_n, max_model_len], dtype - `torch.float32`. - """ - - if rocm_aiter_ops.is_enabled(): - from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1 - - batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), - float("-inf"), - device="cuda", - dtype=torch.float32, - ) - deepgemm_fp8_paged_mqa_logits_stage1( - q_fp8, - kv_cache_fp8, - weights, - out_qk, - context_lens, - block_tables, - max_model_len, - ) - return out_qk.sum(dim=0) - else: - return fp8_paged_mqa_logits_torch( - q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len - ) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 3b6cb8a34360..2fa470978be7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -611,6 +611,7 @@ class CompilationConfig: "vllm::gdn_attention_core", "vllm::kda_attention", "vllm::sparse_attn_indexer", + "vllm::rocm_aiter_sparse_attn_indexer", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py new file mode 100644 index 000000000000..b05da89460af --- /dev/null +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -0,0 +1,314 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Custom Sparse Attention Indexer layers.""" + +import torch + +from vllm._aiter_ops import rocm_aiter_ops +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerMetadata, +) + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + +logger = init_logger(__name__) + + +def sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, +) -> torch.Tensor: + # careful! 
this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + fp8_dtype = current_platform.fp8_dtype() + print("in cuda path, which is wrong!", flush=True) + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=fp8_dtype, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 4], + device=k.device, + dtype=torch.uint8, + ) + ops.cp_gather_indexer_k_quant_cache( + kv_cache, + k_fp8, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + ) + + logits = fp8_mqa_logits( + q_fp8[chunk.token_start : chunk.token_end], + (k_fp8, k_scale.view(torch.float32)), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold + # (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + + logits = fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if decode_metadata.requires_padding: + # if 
padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + + return topk_indices_buffer + + +def sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer + + +direct_register_custom_op( + op_name="sparse_attn_indexer", + op_func=sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, +) + + +@CustomOp.register("sparse_attn_indexer") +class SparseAttnIndexer(CustomOp): + """Sparse Attention Indexer Custom Op Layer. This layer is extracted as a + separate custom op since it involves heavy custom kernels like `mqa_logits`, + `paged_mqa_logits` and `top_k_per_row`, etc. Those kernels maybe requires + specific memory layout or implementation for different hardware backends to + achieve optimal performance. + + For now, the default native path will use CUDA backend path. Other platform + may requires add the corresponding Custom Op name `sparse_attn_indexer` to + `custom_ops` in `CompilationConfig` to enable the platform specific path. 
+    """
+
+    def __init__(
+        self,
+        k_cache,
+        quant_block_size: int,
+        scale_fmt: str,
+        topk_tokens: int,
+        head_dim: int,
+        max_model_len: int,
+        max_total_seq_len: int,
+        topk_indices_buffer: torch.Tensor,
+    ):
+        super().__init__()
+        self.k_cache = k_cache
+        self.quant_block_size = quant_block_size
+        self.scale_fmt = scale_fmt
+        self.topk_tokens = topk_tokens
+        self.head_dim = head_dim
+        self.max_model_len = max_model_len
+        self.max_total_seq_len = max_total_seq_len
+        self.topk_indices_buffer = topk_indices_buffer
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        q_fp8: torch.Tensor,
+        k: torch.Tensor,
+        weights: torch.Tensor,
+    ):
+        return self.forward_cuda(hidden_states, q_fp8, k, weights)
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        q_fp8: torch.Tensor,
+        k: torch.Tensor,
+        weights: torch.Tensor,
+    ):
+        return torch.ops.vllm.sparse_attn_indexer(
+            hidden_states,
+            self.k_cache.layer_prefix,
+            self.k_cache.kv_cache[0],
+            q_fp8,
+            k,
+            weights,
+            self.quant_block_size,
+            self.scale_fmt,
+            self.topk_tokens,
+            self.head_dim,
+            self.max_model_len,
+            self.max_total_seq_len,
+            self.topk_indices_buffer,
+        )
+
+    def forward_hip(
+        self,
+        hidden_states: torch.Tensor,
+        q_fp8: torch.Tensor,
+        k: torch.Tensor,
+        weights: torch.Tensor,
+    ):
+        if rocm_aiter_ops.is_enabled():
+            return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
+                hidden_states,
+                self.k_cache.prefix,
+                self.k_cache.kv_cache[0],
+                q_fp8,
+                k,
+                weights,
+                self.quant_block_size,
+                self.scale_fmt,
+                self.topk_tokens,
+                self.head_dim,
+                self.max_model_len,
+                self.max_total_seq_len,
+                self.topk_indices_buffer,
+            )
+        else:
+            raise RuntimeError(
+                "Sparse attention indexer ROCm custom op requires ROCm "
+                "Aiter ops to be enabled."
+            )
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 0b6513789aea..0f0c8dcd906c 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -35,7 +35,6 @@
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
-from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed import (
@@ -45,7 +44,6 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
-from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -65,6 +63,7 @@
     per_token_group_quant_fp8,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -76,11 +75,8 @@
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
-from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerBackend,
-    DeepseekV32IndexerMetadata,
 )
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec

@@ -93,10 +89,8 @@
 maybe_prefix,
 )
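For reference, the `SparseAttnIndexer` docstring above notes that other platforms opt into their specialized path by adding the op name to `CompilationConfig.custom_ops`; the ROCm hook added to `vllm/platforms/rocm.py` later in this diff does exactly that. A minimal sketch of the equivalent explicit configuration, assuming only the `CompilationConfig.custom_ops` list already used elsewhere in this patch:

```python
# Minimal sketch: opt a platform into the specialized sparse_attn_indexer
# path by force-enabling the custom op, mirroring what the ROCm hook in
# vllm/platforms/rocm.py does when the model config exposes `index_topk`.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig()
if "+sparse_attn_indexer" not in compilation_config.custom_ops:
    compilation_config.custom_ops.append("+sparse_attn_indexer")
```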
-if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops -elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops as ops +if current_platform.is_cuda_alike() or current_platform.is_xpu(): + pass logger = init_logger(__name__) @@ -600,206 +594,6 @@ def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend -def sparse_attn_indexer( - hidden_states: torch.Tensor, - k_cache_prefix: str, - kv_cache: torch.Tensor, - q_fp8: torch.Tensor, - k: torch.Tensor, - weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, - topk_tokens: int, - head_dim: int, - max_model_len: int, - total_seq_lens: int, - topk_indices_buffer: torch.Tensor | None, -) -> torch.Tensor: - # careful! this will be None in dummy run - attn_metadata = get_forward_context().attn_metadata - fp8_dtype = current_platform.fp8_dtype() - # assert isinstance(attn_metadata, dict) - if not isinstance(attn_metadata, dict): - return sparse_attn_indexer_fake( - hidden_states, - k_cache_prefix, - kv_cache, - q_fp8, - k, - weights, - quant_block_size, - scale_fmt, - topk_tokens, - head_dim, - max_model_len, - total_seq_lens, - topk_indices_buffer, - ) - attn_metadata = attn_metadata[k_cache_prefix] - assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) - slot_mapping = attn_metadata.slot_mapping - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens - - ops.indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) - - topk_indices_buffer[: hidden_states.shape[0]] = -1 - if has_prefill: - prefill_metadata = attn_metadata.prefill - for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], - device=k.device, - dtype=fp8_dtype, - ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], - device=k.device, - dtype=torch.uint8, - ) - ops.cp_gather_indexer_k_quant_cache( - kv_cache, - k_fp8, - k_scale, - chunk.block_table, - chunk.cu_seq_lens, - ) - fp8_mqa_logits_func = fp8_mqa_logits - if current_platform.is_rocm(): - from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits - - fp8_mqa_logits_func = rocm_fp8_mqa_logits - logits = fp8_mqa_logits_func( - q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale.view(torch.float32)), - weights[chunk.token_start : chunk.token_end], - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - ) - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - - if has_decode: - decode_metadata = attn_metadata.decode - # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], - kv_cache = kv_cache.unsqueeze(-2) - decode_lens = decode_metadata.decode_lens - if decode_metadata.requires_padding: - # pad in edge case where we have short chunked prefill length < - # decode_threshold since we unstrictly split - # prefill and decode by decode_threshold - # (currently set to 1 + speculative tokens) - padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens - ) - else: - padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:] - ) - # TODO: move and optimize below logic with triton kernels 
- batch_size = padded_q_fp8_decode_tokens.shape[0] - next_n = padded_q_fp8_decode_tokens.shape[1] - assert batch_size == decode_metadata.seq_lens.shape[0] - num_padded_tokens = batch_size * next_n - fp8_paged_mqa_logits_func = fp8_paged_mqa_logits - if current_platform.is_rocm(): - from vllm.attention.ops.rocm_aiter_mla_sparse import ( - rocm_fp8_paged_mqa_logits, - ) - - fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits - logits = fp8_paged_mqa_logits_func( - padded_q_fp8_decode_tokens, - kv_cache, - weights[:num_padded_tokens], - decode_metadata.seq_lens, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=max_model_len, - ) - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - if decode_metadata.requires_padding: - # if padded, we need to unpack - # the topk indices removing padded tokens - topk_indices = unpack_seq_triton( - topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), - decode_lens, - ) - topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( - topk_indices - ) - - return topk_indices_buffer - - -def sparse_attn_indexer_fake( - hidden_states: torch.Tensor, - k_cache_prefix: str, - kv_cache: torch.Tensor, - q_fp8: torch.Tensor, - k: torch.Tensor, - weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, - topk_tokens: int, - head_dim: int, - max_model_len: int, - total_seq_lens: int, - topk_indices_buffer: torch.Tensor | None, -) -> torch.Tensor: - # profile run - # NOTE(Chen): create the max possible flattened_kv. So that - # profile_run can get correct memory usage. 
- _flattened_kv = torch.empty( - [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 - ) - fp8_dtype = current_platform.fp8_dtype() - _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() - _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() - return topk_indices_buffer - - -direct_register_custom_op( - op_name="sparse_attn_indexer", - op_func=sparse_attn_indexer, - mutates_args=["topk_indices_buffer"], - fake_impl=sparse_attn_indexer_fake, - dispatch_key=current_platform.dispatch_key, -) - - class Indexer(nn.Module): def __init__( self, @@ -860,6 +654,16 @@ def __init__( from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) + self.indexer_op = SparseAttnIndexer( + self.k_cache, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) def forward( self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb @@ -897,21 +701,7 @@ def forward( ) weights = weights.squeeze(-1) - return torch.ops.vllm.sparse_attn_indexer( - hidden_states, - self.k_cache.prefix, - self.k_cache.kv_cache[0], - q_fp8, - k, - weights, - self.quant_block_size, - self.scale_fmt, - self.topk_tokens, - self.head_dim, - self.max_model_len, - self.max_total_seq_len, - self.topk_indices_buffer, - ) + return self.indexer_op(hidden_states, q_fp8, k, weights) class DeepseekV2MLAAttention(nn.Module): diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 876114c2d33a..1455f625f118 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -226,9 +226,6 @@ def get_attn_backend_cls( raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." ) - assert block_size == 1, ( - "Sparse MLA backend on ROCm only supports block size 1 for now." 
- ) logger.info_once("Using Sparse MLA backend on V1 engine.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() @@ -380,6 +377,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config parallel_config = vllm_config.parallel_config + model_config = vllm_config.model_config + hf_config = model_config.hf_config is_eager_execution = compilation_config == CUDAGraphMode.NONE use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled() @@ -432,6 +431,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops: compilation_config.custom_ops.append("+quant_fp8") + # Default dispatch to rocm's sparse_attn_indexer implementation + if hf_config is not None and hasattr(hf_config, "index_topk"): + print("add sparse attn indexer to rocm custom ops", flush=True) + compilation_config.custom_ops.append("+sparse_attn_indexer") + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 77f1ba00d5b0..23ea33351068 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -11,7 +11,6 @@ ) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -24,9 +23,7 @@ class DeepseekV32IndexerBackend(AttentionBackend): - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1 if current_platform.is_rocm() else 64] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -62,6 +59,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: cu_seqlen_ks: torch.Tensor cu_seqlen_ke: torch.Tensor cu_seq_lens: torch.Tensor + token_to_seq: torch.Tensor total_seq_lens: int token_start: int token_end: int @@ -258,6 +256,10 @@ def build_one_prefill_chunk( token_start = query_start_loc_cpu[reqs_start].item() token_end = query_start_loc_cpu[reqs_end].item() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() + seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32) + token_to_seq = torch.repeat_interleave( + seq_idx, seq_lens_cpu[reqs_start:reqs_end] + ).to(self.device) assert total_seq_lens <= self.max_prefill_buffer_size cu_seq_lens = ( torch.cat( @@ -273,6 +275,7 @@ def build_one_prefill_chunk( cu_seqlen_ks=cu_seqlen_ks, cu_seqlen_ke=cu_seqlen_ke, cu_seq_lens=cu_seq_lens, + token_to_seq=token_to_seq, total_seq_lens=total_seq_lens, block_table=block_table[reqs_start:reqs_end], token_start=token_start, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index c0e7f0e380b9..066f5d5089a9 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -17,6 +17,7 @@ from vllm.attention.backends.utils import get_mla_dims from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.mla.common import ( 
MLACommonBaseImpl, ) @@ -35,6 +36,48 @@ logger = init_logger(__name__) +@triton.jit +def fetch_id_to_ragged_kernel( + in_tensor_ptr, # [num_seq, topk] + cumsum_ptr, # [num_seq + 1] + out_tensor_ptr, # [max_num_seq * topk] + in_tensor_ptr_stride, + TOPK: tl.constexpr, + TOKEN_NUM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + seq_id = tl.program_id(0) + block_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + token_start = tl.load(cumsum_ptr + seq_id) + token_end = tl.load(cumsum_ptr + seq_id + 1) + token_num = token_end - token_start + row_offset = block_id * BLOCK_SIZE + if row_offset >= token_num: + return + in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset + in_tensor_mask = (row_offset + offset) < TOPK + in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask) + out_tensor_offset = token_start + row_offset + offset + out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask + tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask) + + +def fetch_id_to_ragged_triton( + in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk +): + num_tokens = in_tensor.size(0) + block_size = 64 + num_block_per_row = triton.cdiv(topk, block_size) + grid = ( + num_tokens, + num_block_per_row, + ) + fetch_id_to_ragged_kernel[grid]( + in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size + ) + + class ROCMAiterMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True @@ -85,6 +128,13 @@ class ROCMAiterMLASparseMetadata: block_table: torch.Tensor req_id_per_token: torch.Tensor + + qo_indptr: torch.Tensor + paged_kv_last_page_len: torch.Tensor + paged_kv_indices: torch.Tensor + paged_kv_indptr: torch.Tensor + paged_kv_indptr_rest: torch.Tensor + block_size: int = 1 topk_tokens: int = 2048 @@ -93,7 +143,9 @@ class ROCMAiterMLASparseMetadata: class ROCMAiterMLASparseMetadataBuilder( AttentionMetadataBuilder[ROCMAiterMLASparseMetadata] ): - cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + _cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) def __init__( self, @@ -106,6 +158,8 @@ def __init__( self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + max_num_seqs = vllm_config.scheduler_config.max_num_seqs self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) @@ -126,6 +180,23 @@ def __init__( dtype=torch.int32, device=device, ) + self.qo_indptr = torch.arange( + 0, max_num_batched_tokens + 1, dtype=torch.int32, device=device + ) + self.paged_kv_last_page_len = torch.ones( + max_num_seqs, dtype=torch.int32, device=device + ) + + # These two needs to be calculated in runtime, + # but we still needs to prepare the buffer + self.paged_kv_indices = torch.zeros( + [max_num_batched_tokens * self.topk_tokens], + dtype=torch.int32, + device=device, + ) + self.paged_kv_indptr = torch.zeros( + [max_num_seqs + 1], dtype=torch.int32, device=device + ) def build( self, @@ -144,7 +215,15 @@ def build( self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( torch.from_numpy(req_id_per_token), non_blocking=True ) + self.paged_kv_indices.fill_(0) + self.paged_kv_indptr.fill_(0) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + qo_indptr = self.qo_indptr[: num_tokens + 1] + 
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens] + paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] + paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] + paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :] metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -157,6 +236,11 @@ def build( req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, + qo_indptr=qo_indptr, + paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indices=paged_kv_indices, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indptr_rest=paged_kv_indptr_rest, ) return metadata @@ -228,20 +312,39 @@ def __init__( def _forward_bf16_kv( self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, + q: torch.Tensor, # [sq, heads, d_qk] + kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] + topk_indices: torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: num_tokens = q.shape[0] - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1] + output = torch.empty( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=q.dtype, + device=q.device, + ) + seq_len = (topk_indices != -1).sum(dim=-1) + torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) + attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) + fetch_id_to_ragged_triton( + topk_indices, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.topk_tokens, + ) + + rocm_aiter_ops.mla_decode_fwd( + q, + kv_c_and_k_pe_cache, + output, + self.scale, + attn_metadata.qo_indptr, + 1, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len, ) - topk_indices = topk_indices.view(num_tokens, 1, -1) - output = reference_mla_sparse_prefill( - q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512 - )[0] return output[:, : self.num_heads, :] def forward(
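The reworked `_forward_bf16_kv` above feeds AITER's `mla_decode_fwd` a ragged (CSR-style) page list instead of the fixed-width `topk_indices` matrix: it counts the valid (non `-1`) entries per token, builds `paged_kv_indptr` via a cumulative sum, and `fetch_id_to_ragged_triton` copies the leading valid indices of each row into `paged_kv_indices`. A plain-PyTorch sketch of that conversion, for reference only (names are illustrative; it assumes, as the Triton kernel does, that valid indices are left-packed within each row, as produced by `top_k_per_row_*`):

```python
import torch


def topk_to_ragged_reference(topk_indices: torch.Tensor):
    """topk_indices: [num_tokens, topk], padded with -1 for unused slots."""
    valid = topk_indices != -1
    lengths = valid.sum(dim=-1)  # number of pages kept per token
    indptr = torch.zeros(
        topk_indices.shape[0] + 1, dtype=torch.int32, device=topk_indices.device
    )
    indptr[1:] = torch.cumsum(lengths, dim=0)  # ragged row offsets (CSR indptr)
    # The Triton kernel copies the first `lengths[i]` entries of row i, so this
    # reference assumes the valid indices sit at the front of each row.
    indices = torch.cat(
        [row[:n] for row, n in zip(topk_indices, lengths.tolist())]
    ).to(torch.int32)
    return indices, indptr
```

Here `indices` and `indptr` play the roles of `paged_kv_indices` and `paged_kv_indptr` in the metadata built above.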