33 changes: 30 additions & 3 deletions vllm/v1/attention/backends/mla/common.py
@@ -1645,6 +1645,33 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)

+    def _concat_k_nope_k_pe(
+        self, k_nope: torch.Tensor, k_pe: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Efficiently concatenate k_nope and k_pe along the last dimension.
+
+        This function avoids the performance penalty of torch.cat with
+        expanded non-contiguous tensors by pre-allocating the output and
+        using direct copies.
+
+        Args:
+            k_nope: Tensor of shape [..., nope_dim]
+            k_pe: Tensor to broadcast and concatenate, typically of shape
+                [..., 1, pe_dim] or [..., pe_dim]
+
+        Returns:
+            Tensor of shape [..., nope_dim + pe_dim]
+        """
+        k = torch.empty(
+            (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
+            dtype=k_nope.dtype,
+            device=k_nope.device,
+        )
+        # Direct copies with efficient broadcasting
+        k[..., : k_nope.shape[-1]] = k_nope
+        k[..., k_nope.shape[-1] :] = k_pe
+        return k
+
    def _compute_prefill_context(
        self,
        q: torch.Tensor,
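The change replaces torch.cat over an expand()ed, hence non-contiguous, k_pe: that path forces PyTorch to materialize the broadcast copy inside the concatenation. Below is a minimal standalone sketch of the pre-allocate-and-copy pattern the new helper uses, with hypothetical shapes (8 tokens, 16 heads, nope_dim=128, pe_dim=64; chosen for illustration, not taken from the PR), plus an equivalence check against the old torch.cat path:

import torch

def concat_k_nope_k_pe(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    # Pre-allocate the output once, then write both halves in place.
    # The second assignment broadcasts k_pe across the head dimension,
    # so no expanded intermediate tensor is ever materialized.
    k = torch.empty(
        (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
        dtype=k_nope.dtype,
        device=k_nope.device,
    )
    k[..., : k_nope.shape[-1]] = k_nope
    k[..., k_nope.shape[-1] :] = k_pe
    return k

# Hypothetical shapes: 8 tokens, 16 heads, nope_dim=128, pe_dim=64.
k_nope = torch.randn(8, 16, 128)
k_pe = torch.randn(8, 1, 64)  # shared across heads, broadcast on copy

k_ref = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
assert torch.equal(concat_k_nope_k_pe(k_nope, k_pe), k_ref)

The result is bitwise identical to the old path; only the number of copies and intermediate allocations differs.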
@@ -1681,7 +1708,7 @@ def _compute_prefill_context(
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)

        attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
            prefill=prefill_metadata,
@@ -1785,7 +1812,7 @@ def _context_parallel_compute_prefill_context(
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)

        attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
            prefill=prefill_metadata,
@@ -1834,7 +1861,7 @@ def _forward_prefill(
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)

        output_prefill = self._run_prefill_new_tokens(
            prefill=attn_metadata.prefill,
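To gauge the claimed win, here is a hypothetical micro-benchmark of the two strategies. The shapes, iteration count, and helper names are illustrative assumptions, not measurements from this PR, and CPU timings will differ from GPU behavior:

import time

import torch

# Illustrative shapes only; real MLA prefill shapes differ.
k_nope = torch.randn(4096, 16, 128)
k_pe = torch.randn(4096, 1, 64)
nope_dim = k_nope.shape[-1]

def via_cat() -> torch.Tensor:
    # expand() makes k_pe non-contiguous; torch.cat must materialize it.
    return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

def via_prealloc() -> torch.Tensor:
    # Pre-allocate once; the copy kernel handles the broadcast directly.
    k = torch.empty(
        (*k_nope.shape[:-1], nope_dim + k_pe.shape[-1]),
        dtype=k_nope.dtype,
        device=k_nope.device,
    )
    k[..., :nope_dim] = k_nope
    k[..., nope_dim:] = k_pe
    return k

def bench(fn, iters: int = 50) -> float:
    fn()  # warm-up (on GPU, also call torch.cuda.synchronize here)
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

print(f"cat + expand    : {bench(via_cat):.3f} ms")
print(f"prealloc + copy : {bench(via_prealloc):.3f} ms")

The pre-allocated variant saves one full materialization of the expanded k_pe and the extra read/write traffic that torch.cat performs on it.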