@@ -13,7 +13,6 @@
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
 
-from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -998,35 +997,12 @@ def _forward_decode(
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
             )
         else:
-            # The MLA_PA path will become the default path in the future;
-            # `_npu_paged_attention_mla` will be removed once
-            # `torch_npu.atb.npu_multi_head_latent_attention` becomes publicly available in torch_npu.
             assert len(kv_c_and_k_pe_cache) > 1
-            if envs.VLLM_ASCEND_MLA_PA:
-                attn_output = torch_npu.atb.npu_multi_head_latent_attention(
-                    q_nope, q_pe, kv_c_and_k_pe_cache[0],
-                    kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
-                    attn_metadata.decode.seq_lens, self.num_heads, self.scale,
-                    self.num_kv_heads)
-            else:
-                q = torch.cat([q_nope, q_pe], dim=-1)
-                attn_output = torch.empty(
-                    [num_tokens, self.num_heads, self.kv_lora_rank],
-                    dtype=q.dtype,
-                    device=q.device)
-                k_cache = torch.cat(
-                    [kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
-                torch_npu._npu_paged_attention_mla(
-                    query=q,
-                    key_cache=k_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.decode.
-                    block_table,  # type:ignore
-                    context_lens=attn_metadata.decode.seq_lens,  # type:ignore
-                    mla_vheadsize=self.kv_lora_rank,
-                    out=attn_output)
+            attn_output = torch_npu.atb.npu_multi_head_latent_attention(
+                q_nope, q_pe, kv_c_and_k_pe_cache[0],
+                kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
+                attn_metadata.decode.seq_lens, self.num_heads, self.scale,
+                self.num_kv_heads)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj_and_o_proj(attn_output)