
[0.9.1][Bugfix] fix oom issue in mla and enable mla_pa for deepseek mla decode #1311


Merged: 6 commits (Jun 22, 2025)
14 changes: 5 additions & 9 deletions examples/disaggregate_prefill_v1/README.md
@@ -30,15 +30,14 @@ Execution Sequence

* Run prefill server P1 on first node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export GLOO_SOCKET_IFNAME="eth0"
export HCCL_IF_IP=172.19.32.175 # node ip
export GLOO_SOCKET_IFNAME="eth0" # network card name
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
@@ -71,15 +70,14 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \

* Run prefill server P2 on second node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export HCCL_IF_IP=172.19.241.49
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
@@ -113,15 +111,14 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \

* Run decode server d1 on third node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export HCCL_IF_IP=172.19.123.51
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
@@ -154,15 +151,14 @@ vllm serve /data01/deepseek_r1_w8a8_zhw \

* Run decode server d2 on last node
```shell
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'`
export HCCL_IF_IP=172.19.190.36
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_VERSION=0.9.1
vllm serve /data01/deepseek_r1_w8a8_zhw \
--host 0.0.0.0 \
--port 20002 \
1 change: 0 additions & 1 deletion examples/disaggregate_prefill_v1/gen_ranktable.sh
@@ -8,7 +8,6 @@ while [[ $# -gt 0 ]]; do
case "$1" in
--ips)
shift
# Collect all subsequent arguments until the next option or the end
while [[ $# -gt 0 && ! "$1" == --* ]]; do
IPs+=("$1")
shift
59 changes: 37 additions & 22 deletions vllm_ascend/attention/mla_v1.py
@@ -13,6 +13,7 @@
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down

from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -933,18 +934,12 @@ def _forward_decode(
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None

q = torch.cat([q_nope, q_pe], dim=-1)
num_tokens = q.size(0)
attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device)
num_tokens = q_nope.size(0)
if self.running_in_graph:
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
@@ -1003,16 +998,35 @@ def _forward_decode(
actual_seq_lengths_kv=decode_meta.seq_lens_list,
)
else:
torch_npu._npu_paged_attention_mla(
query=q,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.decode.block_table, # type:ignore
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
# The MLA_PA path will become the default in the future; `_npu_paged_attention_mla` will
# be removed once the torch_npu release that contains `torch_npu.atb.npu_multi_head_latent_attention`
# is publicly available.
assert len(kv_c_and_k_pe_cache) > 1
if envs.VLLM_ASCEND_MLA_PA:
attn_output = torch_npu.atb.npu_multi_head_latent_attention(
q_nope, q_pe, kv_c_and_k_pe_cache[0],
kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, self.num_heads, self.scale,
self.num_kv_heads)
else:
q = torch.cat([q_nope, q_pe], dim=-1)
attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device)
k_cache = torch.cat(
[kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
torch_npu._npu_paged_attention_mla(
query=q,
key_cache=k_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.decode.
block_table, # type:ignore
context_lens=attn_metadata.decode.seq_lens, # type:ignore
mla_vheadsize=self.kv_lora_rank,
out=attn_output)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output)
@@ -1193,10 +1207,11 @@ def forward(
decode_k_nope, decode_k_pe,
kv_cache, attn_metadata)
else:
combined_cache = torch.cat([kv_cache[0], kv_cache[1]], dim=-1)
MengqingCao (Collaborator) commented on Jun 20, 2025:

It seems the OOM issue is mainly caused by this combined_cache. Is removing it enough for this PR? If so, maybe we could add npu_multi_head_latent_attention once the supporting torch-npu is officially available.

The PR author (Collaborator) replied:

Removing it is not enough in this case because, as you can see, paged_attention_mla can only receive the concatenated cache as its input parameter; without npu_multi_head_latent_attention the cat seems inevitable.

output_decode = self._forward_decode(
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
combined_cache, attn_metadata)
output_decode = self._forward_decode(decode_ql_nope,
decode_q_pe,
decode_k_nope,
decode_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
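The review thread above pins the OOM on the decode-path `torch.cat` over the paged MLA caches, which the new MLA_PA branch sidesteps by handing the nope and rope caches to the kernel separately. Below is a minimal sketch of the memory effect; the shapes are hypothetical (the kv_lora_rank/rope dims follow DeepSeek-style MLA, the block counts are made up) and are not taken from this PR.

```python
import torch

# Hypothetical paged-cache shapes for illustration only.
num_blocks, block_size, num_kv_heads = 2048, 128, 1
kv_lora_rank, rope_dim = 512, 64

kv_c_cache = torch.empty(num_blocks, block_size, num_kv_heads, kv_lora_rank,
                         dtype=torch.bfloat16)
k_pe_cache = torch.empty(num_blocks, block_size, num_kv_heads, rope_dim,
                         dtype=torch.bfloat16)

def gib(t: torch.Tensor) -> float:
    return t.numel() * t.element_size() / 2**30

resident = gib(kv_c_cache) + gib(k_pe_cache)
# Old decode path: concatenate before every paged-attention call, which materialises
# a transient copy of the whole cache on top of the resident allocation.
combined = torch.cat([kv_c_cache, k_pe_cache], dim=-1)
print(f"resident cache: {resident:.2f} GiB, extra for the cat: {gib(combined):.2f} GiB")
```

Since the kv cache is normally sized to fill most of the free device memory, that transient copy can be enough to tip decode over the limit, while `npu_multi_head_latent_attention` consumes the two caches as-is.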
6 changes: 5 additions & 1 deletion vllm_ascend/envs.py
@@ -132,7 +132,11 @@
# rpc communication listening port, which will be used to receive the agent metadata from the
# remote worker.
"VLLM_LLMDD_RPC_PORT":
lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557))
lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)),
# Whether to enable mla_pa for deepseek mla decode. This flag will be removed once the supporting
# torch_npu is publicly available, at which point mla_pa will become the default deepseek decode path.
"VLLM_ASCEND_MLA_PA":
lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0))
}

# end-env-vars-definition
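A minimal sketch of how the new flag is toggled, assuming `vllm_ascend.envs` resolves these entries lazily on attribute access (the decode branch in `mla_v1.py` above reads it as `envs.VLLM_ASCEND_MLA_PA`); the printed strings are illustrative only.

```python
import os

# Opt in before the attribute is first read; the registry entry parses it with int().
os.environ["VLLM_ASCEND_MLA_PA"] = "1"

from vllm_ascend import envs  # requires a vllm-ascend build that includes this PR

if envs.VLLM_ASCEND_MLA_PA:
    print("decode uses torch_npu.atb.npu_multi_head_latent_attention")
else:
    print("decode falls back to torch_npu._npu_paged_attention_mla")
```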
2 changes: 1 addition & 1 deletion vllm_ascend/models/deepseek_v2.py
@@ -314,7 +314,7 @@ def forward(
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
# If this node is a kv_consumer, we force the MoE to always run in the decode path to make sure
# the behaviour is aligned between dummy_run and the normal model_execute.
if self.kv_consumer is not None:
if self.kv_consumer:
is_prefill = False
enable_force_load_balance = False

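The truthiness check above only matters if `self.kv_consumer` can be set to a falsy value when disaggregated prefill is configured but this node is not the consumer; that assignment sits outside this diff, so the sketch below is an assumption about its shape rather than code from the repository.

```python
from typing import Optional

def should_force_decode_old(kv_consumer: Optional[bool]) -> bool:
    # Old check: fires whenever the attribute was assigned at all,
    # including on nodes where it is False.
    return kv_consumer is not None

def should_force_decode_new(kv_consumer: Optional[bool]) -> bool:
    # New check from this PR: only force the decode path when this node
    # really is the kv consumer.
    return bool(kv_consumer)

for value in (None, False, True):
    print(value, should_force_decode_old(value), should_force_decode_new(value))
```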
5 changes: 4 additions & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
@@ -121,7 +121,10 @@ def fused_experts_with_mc2(
if log2phy:
topk_ids = log2phy[topk_ids]
global_bs = 0
moe_expert_num = len(expert_map) + global_redundant_expert_num
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
moe_expert_num = global_redundant_expert_num
# hidden_states = hidden_states.bfloat16()
kwargs_mc2 = {
"x": hidden_states,
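The guard above matters because `expert_map` can be `None`, in which case the previous `len(expert_map)` raised a `TypeError`. A self-contained restatement of the fixed computation, with made-up expert counts:

```python
from typing import Optional
import torch

def moe_expert_count(expert_map: Optional[torch.Tensor],
                     global_redundant_expert_num: int) -> int:
    # Mirrors the guarded computation in fused_experts_with_mc2.
    if expert_map is not None:
        return len(expert_map) + global_redundant_expert_num
    return global_redundant_expert_num

print(moe_expert_count(torch.arange(64), 2))  # 66
print(moe_expert_count(None, 2))              # 2, instead of a TypeError from len(None)
```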