Commit 8ebe98a

Fix issue and conflict
Signed-off-by: depeng1994 <[email protected]>
2 parents (3d68df3 + afc4c0c), commit 8ebe98a

18 files changed: +342 additions, -140 deletions

.github/workflows/vllm_ascend_test_long_term.yaml

Lines changed: 21 additions & 6 deletions
@@ -41,9 +41,19 @@ jobs:
     strategy:
       max-parallel: 2
       matrix:
+        os: [linux-arm64-npu-1, linux-arm64-npu-4]
         vllm_version: [main, v0.9.0]
+    concurrency:
+      group: >
+        ${{
+        matrix.os == 'linux-arm64-npu-4'
+        && github.event.pull_request.number
+        && format('pr-{0}-limit-npu-4-long-term', github.event.pull_request.number)
+        || format('job-{0}-{1}-{2}-long-term', matrix.os, matrix.vllm_version, github.event.pull_request.number)
+        }}
+      cancel-in-progress: false
     name: vLLM Ascend long term test
-    runs-on: linux-arm64-npu-1
+    runs-on: ${{ matrix.os }}
     container:
       # TODO(yikun): Remove m.daocloud.io prefix when infra proxy ready
       image: m.daocloud.io/quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10
@@ -92,8 +102,13 @@
 
       - name: Run vllm-project/vllm-ascend long term test
         run: |
-          # spec decode test
-          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
-          VLLM_USE_MODELSCOPE=true pytest -sv tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
-          VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
-          pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+          if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
+            # spec decode test
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/spec_decode/e2e/test_mtp_correctness.py # it needs a clean process
+            pytest -sv tests/long_term/spec_decode --ignore=tests/long_term/spec_decode/e2e/test_mtp_correctness.py --ignore=tests/long_term/spec_decode/e2e/test_v1_spec_decode.py --ignore=tests/long_term/spec_decode/e2e/test_v1_mtp_correctness.py
+            pytest -sv tests/long_term/test_accuracy.py
+          else
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/long_term/test_deepseek_v2_lite_tp2_accuracy.py
+          fi

docs/source/developer_guide/evaluation/profile_execute_duration.md

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ The execution duration of each stage (including pre/post-processing, model forwa
 **To reduce the performance overhead, we add this feature, using the NPU event timestamp mechanism to observe the device execution time asynchronously.**
 
 ## Usage
-* Use the environment variable `VLLM_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.
+* Use the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.
 * Use the non-blocking API `ProfileExecuteDuration().capture_async` to set observation points asynchronously when you need to observe the execution duration.
 * Use the blocking API `ProfileExecuteDuration().pop_captured_sync` at an appropriate time to get and print the execution durations of all observed stages.
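
A minimal usage sketch of the flow the docs above describe, under those assumptions: the context-manager form of `capture_async` matches the updated test further down in this commit, while printing the raw return value of `pop_captured_sync` is an assumption (the docs only say it gets and prints the durations of all observed stages).

import os

# Enable the feature via the renamed environment variable.
os.environ["VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE"] = "1"

import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend on Ascend

from vllm_ascend.utils import ProfileExecuteDuration

a = torch.randn(1024, 1024).npu()
b = torch.randn(1024, 1024).npu()

# Non-blocking: mark an observation point around the stage of interest.
with ProfileExecuteDuration().capture_async("forward"):
    torch.matmul(a, b)
torch.npu.synchronize()

# Blocking: collect the durations of all observed stages at a convenient time.
print(ProfileExecuteDuration().pop_captured_sync())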

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -354,4 +354,4 @@ def prompt_template(request):
 
 @pytest.fixture(scope="session")
 def ilama_lora_files():
-    return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
+    return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")

(The visible content is identical; this change most likely only adds the missing newline at the end of the file.)
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/blob/main/tests/entrypoints/llm/test_accuracy.py
+#
+
+import gc
+import multiprocessing
+from multiprocessing import Queue
+
+import lm_eval
+import pytest
+import torch
+
+# pre-trained model path on Hugging Face.
+MODELS = ["deepseek-ai/DeepSeek-V2-Lite"]
+# Math reasoning benchmark (Grade School Math 8K).
+TASK = "gsm8k"
+# Answer validation requiring format consistency.
+FILTER = "exact_match,strict-match"
+# 3% relative tolerance for numerical accuracy.
+RTOL = 0.03
+# Baseline accuracy after VLLM optimization.
+# FIXME: fix the accuracy issue
+EXPECTED_VALUE = 0.000758150113722517
+
+
+def run_test(model_name, queue, more_args=None):
+    model_args = f"pretrained={model_name},max_model_len=4096,trust_remote_code=True,tensor_parallel_size=4"
+    if more_args is not None:
+        model_args = f"{model_args},{more_args}"
+    results = lm_eval.simple_evaluate(
+        model="vllm",
+        model_args=model_args,
+        tasks=TASK,
+        batch_size="auto",
+    )
+    result = results["results"][TASK][FILTER]
+    print(100 * "*", "\nThe accuracy test result:", result)
+    queue.put(result)
+    del results
+    torch.npu.empty_cache()
+    gc.collect()
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_lm_eval_accuracy(model, monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context():
+        result_queue: Queue[float] = multiprocessing.Queue()
+        p = multiprocessing.Process(target=run_test,
+                                    args=(
+                                        model,
+                                        result_queue,
+                                    ))
+        p.start()
+        p.join()
+        result = result_queue.get()
+        assert (EXPECTED_VALUE - RTOL < result < EXPECTED_VALUE + RTOL), \
+            f"Expected: {EXPECTED_VALUE}±{RTOL} | Measured: {result}"

tests/multicard/test_offline_inference_distributed.py

Lines changed: 0 additions & 2 deletions
@@ -22,7 +22,6 @@
 """
 import os
 
-import pytest
 import vllm  # noqa: F401
 
 from tests.conftest import VllmRunner
@@ -47,7 +46,6 @@ def test_models_distributed_QwQ():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
-@pytest.mark.skipif(True, reason="wait for mla issue fixed on v1")
 def test_models_distributed_DeepSeek():
     example_prompts = [
         "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",

tests/singlecard/test_profile_execute_duration.py

Lines changed: 3 additions & 4 deletions
@@ -16,15 +16,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import time
+from unittest.mock import patch
 
 import torch
 import vllm  # noqa: F401
-
-import vllm_ascend.envs as envs
 from vllm_ascend.utils import ProfileExecuteDuration
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "1"})
 def test_execue_duration_enabled_discrepancy():
     a = torch.randn(10000, 10000).npu()
     b = torch.randn(10000, 10000).npu()
@@ -33,7 +34,6 @@ def test_execue_duration_enabled_discrepancy():
     torch.matmul(a, b)
     torch.npu.synchronize()
 
-    envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE = True
     cpu_start = time.perf_counter()
     with ProfileExecuteDuration().capture_async("forward"):
         torch.matmul(a, b)
@@ -54,7 +54,6 @@ def test_execue_duration_disabled():
     a = torch.randn(100, 100).npu()
     b = torch.randn(100, 100).npu()
 
-    envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE = False
    with ProfileExecuteDuration().capture_async("forward"):
        torch.matmul(a, b)
    torch.npu.synchronize()

vllm_ascend/attention/attention.py

Lines changed: 2 additions & 0 deletions
@@ -720,6 +720,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
@@ -961,6 +962,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         **extra_impl_args,
     ) -> None:
         self.num_heads = num_heads

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads

vllm_ascend/attention/mla_v1.py

Lines changed: 37 additions & 43 deletions
@@ -9,10 +9,8 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
@@ -117,6 +115,8 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
+    with_prefill_across_dp: bool = False
+
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
@@ -260,6 +260,10 @@ def build_dummy(self, num_reqs: int,
                                  PAD_SLOT_ID,
                                  dtype=torch.int32,
                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -278,15 +282,21 @@
             attn_state=AscendAttentionState.DecodeOnly,
             prefill=None,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            block_tables=block_table,
         )
 
-    def build(self,
-              num_reqs: int,
-              num_actual_tokens: int,
-              max_query_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              common_prefix_len: Optional[int] = None,
-              graph_pad_size: int = -1) -> AscendMLAMetadata:
+    def build(
+        self,
+        num_reqs: int,
+        num_actual_tokens: int,
+        max_query_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        common_prefix_len: Optional[int] = None,
+        graph_pad_size: int = -1,
+        with_prefill_across_dp: bool = False,
+    ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -388,6 +398,7 @@ def build(self,
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
+            with_prefill_across_dp=with_prefill_across_dp,
         )
 
 
@@ -409,20 +420,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]],
         logits_soft_cap: Optional[float],
         attn_type: str,
-        # MLA Specific Arguments
-        q_lora_rank: Optional[int],
-        kv_lora_rank: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        qk_head_dim: int,
-        v_head_dim: int,
-        rotary_emb: RotaryEmbedding,
-        # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-        # attention backend perspective we rely on the layer to pass in the
-        # correct matrix
-        q_proj: ColumnParallelLinear,
-        kv_b_proj: ColumnParallelLinear,
-        o_proj: RowParallelLinear,
+        kv_sharing_target_layer_name: Optional[str] = None,
         **kwargs,
     ) -> None:
         self.num_heads = num_heads
@@ -431,25 +429,20 @@
         self.num_kv_heads = num_kv_heads
         self.kv_cache_dtype = kv_cache_dtype
 
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_head_dim
-        self.v_head_dim = v_head_dim
-
-        # Hack for V1 for now to avoid torch library overhead (since we are
-        # already inside an attention custom op), pull out the forward
-        # method from the rotary embedding and call it directly
-        # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb
-
-        self.q_proj = q_proj
-        self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj
-
+        # MLA Args
+        self.q_lora_rank = kwargs['q_lora_rank']
+        self.kv_lora_rank = kwargs['kv_lora_rank']
+        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
+        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
+        self.qk_head_dim = kwargs['qk_head_dim']
+        self.v_head_dim = kwargs['v_head_dim']
+        self.rotary_emb = kwargs['rotary_emb']
+        self.q_proj = kwargs['q_proj']
+        self.kv_b_proj = kwargs['kv_b_proj']
+        self.o_proj = kwargs['o_proj']
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -621,7 +614,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
            kv,
            self.kv_a_layernorm.weight,
            cos,
@@ -643,7 +636,7 @@ def rope_single(
         B, N, D = x.shape
         S = 1
         x = x.view(B, N, S, D)
-        x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
+        x = torch_npu.npu_interleave_rope(x, cos, sin)
         return x.view(B, N, D)
 
     def _forward_decode(
@@ -766,6 +759,7 @@ def forward(
             sin = sin[attn_metadata.decode.input_positions]
             cos = cos[:, None, None, :]
             sin = sin[:, None, None, :]
+
             decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             decode_k_pe, decode_k_nope = self.exec_kv(
                 hidden_states_or_kv_c_normed, cos, sin, kv_cache,

vllm_ascend/envs.py

Lines changed: 3 additions & 2 deletions
@@ -36,8 +36,6 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
-    "VLLM_MODEL_EXECUTE_TIME_OBSERVE":
-    lambda: bool(int(os.getenv("VLLM_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":
@@ -70,6 +68,9 @@
     lambda: os.getenv("VLLM_VERSION", None),
     "VLLM_ASCEND_TRACE_RECOMPILES":
     lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
+    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
+                 ),
 }

 # end-env-vars-definition
