
Commit 908a851

momo609, wangxiaoxin (A), and ZhengWG authored
Optimize the function of computing top-k and top-p in the sampler (#970)
### What this PR does / why we need it?
Optimize the performance of the calculation logic in the sampler and in deepseek_v2.

### Does this PR introduce _any_ user-facing change?
Adds the `VLLM_ASCEND_ENABLE_TOPK_OPTIMZE` config option for the sampler.

### How was this patch tested?
pytest test_sampler.py

Signed-off-by: wangxiaoxin (A) <[email protected]>
Co-authored-by: wangxiaoxin (A) <[email protected]>
Co-authored-by: ZhengWG <[email protected]>
1 parent e1ab6d3 commit 908a851
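
The optimization is opt-in through the environment variable added below. As a minimal sketch of how it can be exercised end to end (the model name and sampling values mirror the new single-card test; the script itself is illustrative and not part of this commit):

    import os

    # Opt in before vllm / vllm-ascend are imported; the flag defaults to "0" (off).
    os.environ["VLLM_ASCEND_ENABLE_TOPK_OPTIMZE"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enforce_eager=True)
    params = SamplingParams(max_tokens=5, temperature=0.0, top_k=50, top_p=0.9)
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)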

File tree

9 files changed: +330 −3 lines changed


tests/multicard/test_offline_inference_distributed.py

Lines changed: 24 additions & 0 deletions
@@ -21,8 +21,10 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
+from unittest.mock import patch

 import vllm  # noqa: F401
+from vllm import SamplingParams

 from tests.conftest import VllmRunner

@@ -57,3 +59,25 @@ def test_models_distributed_DeepSeek():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": "1"})
+def test_models_distributed_topk() -> None:
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            "deepseek-ai/DeepSeek-V2-Lite",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

tests/singlecard/test_offline_inference.py

Lines changed: 23 additions & 0 deletions
@@ -21,9 +21,11 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
+from unittest.mock import patch

 import pytest
 import vllm  # noqa: F401
+from vllm import SamplingParams
 from vllm.assets.image import ImageAsset

 import vllm_ascend  # noqa: F401
@@ -81,3 +83,24 @@ def test_multimodal(model, prompt_template, vllm_runner):
         vllm_model.generate_greedy(prompts=prompts,
                                    images=images,
                                    max_tokens=64)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_OPTIMZE": "1"})
+def test_models_topk() -> None:
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct",
+                    max_model_len=8192,
+                    dtype="float16",
+                    enforce_eager=True,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

tests/singlecard/test_sampler.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Optional
+
+import torch
+from vllm.v1.sample.sampler import Sampler  # noqa: F401
+
+# Set tolerance to 1 for quant ops
+DEFAULT_ATOL = 1e-3
+DEFAULT_RTOL = 1e-3
+
+
+def apply_min_p_new(
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    if min_p == 0:
+        return logits
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
+    """
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if k is not None:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_top_k_top_p_new(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
+    top_k_mask = logits_sort < boundary
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        cutoff = top_k_mask.sum(dim=-1).min()
+        probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = True
+        strides = torch.arange(0,
+                               batch_size * vocab_size,
+                               vocab_size,
+                               device=logits.device)
+        flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
+        valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+        logits_flatten = logits.flatten()
+        valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
+        logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
+        logits[valid_idx] = valid_logits
+    return logits.reshape(batch_size, vocab_size)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_min_p() -> None:
+    logits = torch.randn((128, 7168)).npu()
+    min_p = torch.Tensor([0.01]).npu()
+    logits_new = apply_min_p_new(logits, min_p)
+    sampler = Sampler()
+    logits_old = sampler.apply_min_p(logits, min_p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_top_k_top_p() -> None:
+    logits = torch.randn((128, 7168)).npu()
+    k = torch.Tensor([-1]).int().npu()
+    p = torch.Tensor([1]).int().npu()
+    logits_new = apply_top_k_top_p_new(logits, k, p)
+    logits_old = apply_top_k_top_p(logits, k, p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ASCEND_ENABLE_TOPK_OPTIMZE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMZE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -238,8 +238,7 @@ def forward(

         num_tokens, hidden_size = hidden_states.shape

-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        old_hidden_states = hidden_states.clone()

         if self.tp_size > 1:
             if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
@@ -288,6 +287,9 @@ def forward(
         if num_padding_tokens > 0:
             hidden_states = hidden_states[:-num_padding_tokens]

+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(old_hidden_states)
+
         if shared_output is not None:
             hidden_states = hidden_states + shared_output

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def fused_experts(
                                    num_experts)).to(topk_ids.dtype)

         # Sort by local expert IDs
-        sort_indices = torch.argsort(filtered_experts)
+        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
         sorted_token_indices = token_indices[sort_indices]
         sorted_weights = filtered_weights[sort_indices]
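
The one-line change reinterprets the integer expert IDs as float32 before `argsort`, presumably because the floating-point sort path is faster on the Ascend backend. A standalone sketch (not from this commit) of why the bit reinterpretation preserves ordering for the small, non-negative expert IDs:

    import torch

    # For small non-negative int32 values, the IEEE-754 float32 bit pattern compares
    # in the same order as the integers, so sorting the reinterpreted view yields the
    # same permutation (distinct IDs are used here to avoid tie-order ambiguity).
    expert_ids = torch.randperm(64).to(torch.int32)
    assert torch.equal(torch.argsort(expert_ids),
                       torch.argsort(expert_ids.view(torch.float32)))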

vllm_ascend/patch/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -166,3 +166,30 @@
 #   Future Plan:
 #       Revert it when the ascend support triton kernel.
 #
+# ** File: v1/sample/sampler.py **
+#  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
+#    Why:
+#       We need to use the patched `apply_top_k_top_p` in `sample`.
+#       The main reason to overwrite `apply_top_k_top_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_top_k_top_p` function in PyTorch.
+#    Related PR (if no, explain why):
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the Ascend scatter performance improves.
+#
+# ** File: v1/sample/sampler.py **
+#  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_min_p`
+#    Why:
+#       We need to use the patched `apply_min_p` in `sample`.
+#       The main reason to overwrite `apply_min_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_min_p` function in PyTorch.
+#    Related PR (if no, explain why):
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the Ascend index_put performance improves.
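
Both patches are installed at import time by the new `patch_sampler` module shown below. One illustrative way to confirm they took effect (assumes vllm and vllm-ascend are importable in the current environment):

    import os

    os.environ["VLLM_ASCEND_ENABLE_TOPK_OPTIMZE"] = "1"

    # Importing the patch package applies the monkey patches at module import time.
    import vllm_ascend.patch.worker.patch_common  # noqa: F401
    from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
    from vllm.v1.sample.sampler import Sampler

    # Both methods should now resolve to the vllm-ascend re-implementations.
    print(Sampler.apply_min_p.__module__)
    print(TopKTopPSampler.forward_native.__module__)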

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -23,4 +23,5 @@
 import vllm_ascend.patch.worker.patch_common.patch_metrics  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa

vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import torch
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm.v1.sample.sampler import Sampler
+
+from vllm_ascend import envs
+
+
+def apply_min_p(
+    self,
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def _apply_top_k_top_p(
+    logits: torch.Tensor,
+    p: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+    if k is not None:
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    if p is not None:
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits
+
+
+def topk_topp_forward_native(
+    self,
+    logits: torch.Tensor,
+    generators: dict[int, torch.Generator],
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    PyTorch-native implementation of top-k and top-p sampling.
+
+    The logits tensor may be updated in-place.
+    """
+    logits = _apply_top_k_top_p(logits, k, p)
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    return random_sample(probs, generators)
+
+
+Sampler.apply_min_p = apply_min_p
+if envs.VLLM_ASCEND_ENABLE_TOPK_OPTIMZE:
+    TopKTopPSampler.forward_native = topk_topp_forward_native
