
Conversation

@Pr0Wh1teGivee (Contributor) commented Jun 20, 2025

What this PR does / why we need it?

Use the fused op torch_npu.npu_top_k_top_p(logits, p, k) when p and k are not None; otherwise fall back to the original implementation. The replacement takes place automatically when `VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE=1`.

This patch uses `npu_top_k_top_p`, which requires torch_npu >= 2.5.1.post1.dev20250619.
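
A minimal sketch of the patched path (the wrapper shape and the unfused fallback helper name are illustrative, not the literal patch code):

```python
import torch
import torch_npu  # fused op requires torch_npu >= 2.5.1.post1.dev20250619


def apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,  # per-request top-p thresholds, or None
    k: torch.Tensor,  # per-request top-k counts, or None
) -> torch.Tensor:
    # Fast path: a single fused NPU kernel when both filters are active.
    if p is not None and k is not None:
        return torch_npu.npu_top_k_top_p(logits, p, k)
    # Otherwise fall back to the original unfused implementation
    # (hypothetical helper name standing in for the pre-patch code path).
    return _apply_top_k_top_p_unfused(logits, p, k)
```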

Does this PR introduce any user-facing change?

No

How was this patch tested?

Tested with DeepSeek R1.

github-actions bot added the documentation label Jun 20, 2025
    k: torch.Tensor,
) -> torch.Tensor:
    if p is not None and k is not None:
        return torch_npu.npu_top_k_top_p(logits, p, k)
Member:

Which torch_npu version supports this call?

Contributor Author:

since rc2 b050

@Yikun (Member) commented Jun 23, 2025:

Please update the commit message with the published torch_npu version.

such as: https://mirrors.huaweicloud.com/ascend/repos/pypi/torch-npu/

Member:

This API was introduced in torch_npu-2.5.1.post1.dev20250619.

Contributor Author:

fixed

@Yikun (Member) commented Jun 23, 2025

1. Please do a rebase because we updated the torch version this morning; better to do an e2e test as well.
2. Please add a UT to make sure npu_top_k_top_p is called and returns the expected results, under: https://github.com/vllm-project/vllm-ascend/tree/main/tests/ut (a sketch follows below).
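
A rough sketch of such a UT (the mock target, module path, and patched function name are assumptions for illustration; the real test lives under tests/ut):

```python
from unittest import mock

import torch


def test_fused_topk_topp_called():
    logits = torch.randn(4, 128)
    p = torch.full((4,), 0.9)
    k = torch.full((4,), 50, dtype=torch.int32)
    # Stub the NPU op so the test runs without Ascend hardware.
    with mock.patch("torch_npu.npu_top_k_top_p",
                    return_value=logits) as mock_npu_op:
        from vllm_ascend.patch.worker.patch_common import patch_sampler
        out = patch_sampler.apply_top_k_top_p(logits, p, k)
    # The fused op must be invoked exactly once with the given tensors.
    mock_npu_op.assert_called_once_with(logits, p, k)
    assert out is logits
```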

    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
    p: torch.Tensor,
Member:

Hmm... so the order of _apply_top_k_top_p was wrong.

I suggest keeping the same order as upstream: https://github.com/vllm-project/vllm/blob/9a3b88328f7e434cac35b90ee463de6689f9a833/vllm/model_executor/layers/sampler.py#L398

Please change L98.
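
For reference, a stub in the suggested order (assuming the upstream helper's (logits, p, k) parameter order, as the linked sampler.py suggests; illustrative only):

```python
import torch


def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    ...  # top-k/top-p filtering over logits
```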

Contributor Author:

fixed

@Pr0Wh1teGivee (Contributor Author), replying to the rebase/UT requests above:

fixed

Pr0Wh1teGivee changed the title "use fused ops npu_top_k_top_p" to "[Perf] Use fused ops npu_top_k_top_p" Jun 24, 2025
mock_npu_op.assert_called_once_with(logits, p, k)


if __name__ == "__main__":
Collaborator:

no need for this

Contributor Author:

fixed

github-actions bot removed the documentation label Jun 24, 2025
codecov bot commented Jun 24, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 27.55%. Comparing base (c30ddb8) to head (d88eb37).
⚠️ Report is 550 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1308      +/-   ##
==========================================
+ Coverage   27.39%   27.55%   +0.16%     
==========================================
  Files          56       57       +1     
  Lines        6191     6238      +47     
==========================================
+ Hits         1696     1719      +23     
- Misses       4495     4519      +24     
Flag Coverage Δ
unittests 27.55% <100.00%> (+0.16%) ⬆️


@Yikun (Member) left a comment:

LGTM, please address comments in a new PR.

Comment on lines +15 to +16
import vllm_ascend.patch.worker.patch_common.patch_sampler
importlib.reload(vllm_ascend.patch.worker.patch_common.patch_sampler)
Member:

Please refer to this [1] to patch the test and base TestTopKTopPSamplerOptimize on TestBase.

[1] https://github.com/vllm-project/vllm-ascend/pull/1386/files#diff-eae86bf6e7a9a6ef5d079fa80ca12e946ecff4e587e5b66d3761f2cc7f6bb9c5R4
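
A hedged sketch of the requested shape (the TestBase import path and method name are assumptions):

```python
import importlib

from tests.ut.base import TestBase  # assumed location of TestBase


class TestTopKTopPSamplerOptimize(TestBase):

    def test_patch_reapplied(self):
        # Reload so the module-level monkey-patching runs again under
        # this test's environment and mocks.
        import vllm_ascend.patch.worker.patch_common.patch_sampler as ps
        importlib.reload(ps)
```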

    return logits


def _apply_top_k_top_p(
Member:

I tend to rename this _apply_top_k_top_p to apply_top_k_top_p to avoid confusion and keep the same name as:
https://github.com/vllm-project/vllm/blob/c53fec1fcb27aca9475e55c2d1e74c532f5f0364/vllm/v1/sample/ops/topk_topp_sampler.py#L165

Yikun merged commit 2fda604 into vllm-project:main Jun 25, 2025
24 checks passed
Yikun added this to the v0.9.1 milestone Jun 26, 2025
weijinqian0 pushed a commit to weijinqian0/vllm-ascend that referenced this pull request Jun 30, 2025. (Commit message duplicates this PR's description, noting "Tested by DeepSeek R1 and UT passed"; signed off by Pr0Wh1teGivee.)
@wangxiyuan (Collaborator):

apply_min_p has been removed from vllm main (vllm-project/vllm@48fb076); we'll clean up this patch code once vLLM 0.9.2 comes out. Please update to the newest code if you still want this feature.

wangxiyuan pushed a commit that referenced this pull request Aug 1, 2025
### What this PR does / why we need it?
Fixed a 310P failure when using the sampler feature.
The root cause: torch_npu.npu_top_k_top_p uses the operator
aclnnApplyTopKTopP, which currently does not support 310P.
The first PR with the issue is #1308.

### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.10.0
- vLLM main:
vllm-project/vllm@207b750

Signed-off-by: leo-pony <[email protected]>
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025. (Commit message duplicates the 310P fix above.)
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Oct 16, 2025. (Commit message duplicates this PR's description.)
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025. (Commit message duplicates this PR's description.)
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025. (Commit message duplicates the 310P fix above.)
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025. (Commit message duplicates the 310P fix above.)