[0.9.1][Bugfix] fix oom issue in mla and enable mla_pa for deepseek mla decode #1311
Conversation
I'll cherry-pick this PR back to the main branch after it's tested and merged.
@@ -1193,10 +1207,11 @@ def forward(
                decode_k_nope, decode_k_pe,
                kv_cache, attn_metadata)
        else:
            combined_cache = torch.cat([kv_cache[0], kv_cache[1]], dim=-1)
It seems the oom issue is mainly caused by this combined_cache. Is removing it enough for this PR? If so, maybe we could add npu_multi_head_latent_attention once the torch-npu is officially available.
Removing it is not enough in this case because, as you can see, paged_attention_mla can only receive the concatenated cache as its input parameter; without npu_multi_head_latent_attention, the cat seems inevitable.
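For context, here is a minimal plain-PyTorch sketch of why the concat creates the peak (shapes and helper names are assumed for illustration, not the vllm-ascend code): torch.cat materializes a brand-new buffer holding both cache pieces, so cache memory is roughly doubled while combined_cache is alive.

```python
import torch

# Assumed shapes, for illustration only.
num_blocks, block_size, num_heads = 64, 128, 1
kv_lora_rank, rope_dim = 512, 64

# The two independent kv cache pieces left after the disaggregated PD change.
nope_cache = torch.zeros(num_blocks, block_size, num_heads, kv_lora_rank)
rope_cache = torch.zeros(num_blocks, block_size, num_heads, rope_dim)

# paged_attention_mla path: the kernel wants a single k_cache, so the two
# pieces must be concatenated. torch.cat copies both into a freshly allocated
# buffer, so peak cache memory roughly doubles while `combined_cache` exists.
combined_cache = torch.cat([nope_cache, rope_cache], dim=-1)

# A kernel that accepts both pieces separately (what this PR enables via
# npu_multi_head_latent_attention) avoids that copy; this stand-in only shows
# that no extra allocation is needed to hand the pieces over.
def decode_with_split_cache(nope: torch.Tensor, rope: torch.Tensor) -> None:
    # Hypothetical stand-in for the split-cache kernel call.
    assert nope.shape[-1] == kv_lora_rank and rope.shape[-1] == rope_dim

decode_with_split_cache(nope_cache, rope_cache)
```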
The dbo test seems to OOM in CI:

Processed prompts: 0%| | 0/41 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
=================================== FAILURES ===================================
What this PR does / why we need it?
After the disaggregated PD was merged, the kv cache for DeepSeek becomes two independent buffers for kv transfer and computation. However, the current kernel, paged_attention_mla, can only accept k_cache as a single parameter, which forces us to concatenate the two kv cache pieces before the attention call and therefore incurs a memory peak inside the attention in eager mode. In this PR we introduce torch_npu.atb.npu_multi_head_latent_attention for the mla decode path, which will become the default path for both eager mode and aclgraph once the related torch_npu is publicly available. Since it is still a restricted package, we add VLLM_ASCEND_MLA_PA to control its usage. This flag will be removed in the future.
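Conceptually, the gate works like the sketch below; this is a hedged illustration, not the actual vllm-ascend code, and select_decode_cache plus the cache tensor names are hypothetical.

```python
import os

import torch

# Hypothetical illustration of how a VLLM_ASCEND_MLA_PA-style switch could
# select the decode kernel; not the actual vllm-ascend implementation.
USE_MLA_PA = os.environ.get("VLLM_ASCEND_MLA_PA", "0") == "1"

def select_decode_cache(nope_cache: torch.Tensor, rope_cache: torch.Tensor):
    """Return the kernel name and the cache tensor(s) it would consume."""
    if USE_MLA_PA:
        # npu_multi_head_latent_attention takes the two cache pieces as-is,
        # so no extra allocation happens.
        return "npu_multi_head_latent_attention", (nope_cache, rope_cache)
    # Fallback: paged_attention_mla needs one combined cache, so concatenate
    # and accept the temporary memory peak in eager mode.
    return "paged_attention_mla", (torch.cat([nope_cache, rope_cache], dim=-1),)
```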
Does this PR introduce any user-facing change?
Yes, it adds a new flag named VLLM_ASCEND_MLA_PA, but it will be removed eventually after the newest torch_npu is released.

How was this patch tested?