Feature/silu block quant fusion v1 #32996
base: main
Conversation
…llm-project#27847) Signed-off-by: Monishver Chandrasekaran <monishver@Monishvers-MacBook-Air.local>
…MakeLists.txt Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
… small batches getting 512 threads/block Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
This reverts commit 8946ae0.
…variable in shared_memory
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of tests runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces a fused CUDA kernel for SiLU, multiplication, and block-wise FP8 quantization, along with corresponding benchmarks, tests, and integration into the torch.compile fusion passes. The new kernel shows significant performance improvements in the provided benchmarks.
My review has identified a couple of important issues:

- A critical issue in the `torch.compile` fusion pass, where the pattern for the new fused kernel is hardcoded for a single `group_size`, which will prevent fusion for other supported sizes.
- A high-severity issue in the CUDA kernel implementation regarding a hardcoded shared memory size, which makes the code brittle and prone to future bugs.

Addressing these points will improve the correctness and maintainability of the new feature. The rest of the changes, including the tests and benchmark code, look solid.
```python
class SiluMulBlockQuantPattern:
    """
    This pattern fuses silu_and_mul & block quantization.
    Handles all group_size, col_major, and e8m0 variants in one pattern.
    """

    def __init__(self):
        self.quant_dtype = FP8_DTYPE

        from vllm.config import get_current_vllm_config
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None

        from .matcher_utils import MatcherSiluAndMul, MatcherQuantFP8
        self.silu_and_mul_matcher = MatcherSiluAndMul()

        # Create a single matcher for group_size=128 as the pattern template
        # The actual replacement will handle all variants
        scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
        quant_key = QuantKey(dtype=FP8_DTYPE, scale=scale, symmetric=True)
        self.quant_matcher = MatcherQuantFP8(quant_key, has_col_major_scales=False, is_e8m0=False)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            print(f"PATTERN EXECUTING - input device: {input.device}, dtype: {input.dtype}, shape: {input.shape}")

            # Write the FULL pattern explicitly - no matchers
            d = input.shape[-1] // 2
            gate = input[..., :d]
            silu = torch.nn.functional.silu(gate)
            up = input[..., d:]
            silu_out = silu * up

            # Match the in-place quantization pattern
            x_q = torch.empty(silu_out.shape, dtype=FP8_DTYPE, device=input.device)
            num_groups = silu_out.shape[-1] // 128
            x_s = torch.empty((silu_out.shape[0], num_groups), dtype=torch.float32, device=input.device)

            torch.ops._C.per_token_group_fp8_quant(silu_out, x_q, x_s, 128, 1e-10, -448.0, 448.0, False)

            return x_q, x_s

        def replacement(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            print(f"FUSED KERNEL TRIGGERED! input.shape={input.shape}")

            output_shape = list(input.shape)
            output_shape[-1] = output_shape[-1] // 2

            result = torch.empty(output_shape, device=input.device, dtype=self.quant_dtype)
            num_groups = output_shape[-1] // 128
            scale = torch.empty((output_shape[0], num_groups), dtype=torch.float32, device=input.device)

            torch.ops._C.silu_and_mul_per_block_quant.default(
                result, input, scale, 128, None, False
            )

            return result, scale

        print("About to trace pattern...")
        input = torch.empty(5, 256, dtype=torch.float16, device='cuda')
        pattern(input)
        print("Pattern traced, registering replacement...")

        register_replacement(pattern, replacement, [input], fwd_only, pm_pass)
    print("Replacement registered!")
```
The `SiluMulBlockQuantPattern` is hardcoded to only work for `group_size=128`. The docstring on line 180 claims it handles all variants, but both the pattern and replacement functions explicitly use `128`. This will prevent the fusion from triggering for other supported group sizes like `64`.

To fix this, you should either create separate patterns for each supported `group_size` or generalize the pattern to capture the `group_size` from the graph and use it in the replacement.
For example, you could create a `SiluMulBlockQuantPattern` for each group size:

```python
class SiluMulBlockQuantPattern:
    def __init__(self, group_size: int):
        self.group_size = group_size
        ...

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(...):
            # use self.group_size
            ...
            torch.ops._C.per_token_group_fp8_quant(..., self.group_size, ...)
            ...

        def replacement(...):
            # use self.group_size
            ...
            torch.ops._C.silu_and_mul_per_block_quant.default(..., self.group_size, ...)
            ...

# In ActivationQuantFusionPass
...
for group_size in [64, 128]:
    pattern = SiluMulBlockQuantPattern(group_size)
    pattern.register(self.patterns)
...
```

This would ensure that the fusion works for all supported configurations.
```cuda
                                  : scales + token_idx * num_groups;

  // Shared memory
  __shared__ float shared_max[1024];  // Keep hardcoded for now
```
The size of `shared_max` is hardcoded to 1024. This is brittle and could lead to out-of-bounds memory access if the block size is ever increased beyond 1024 in the dispatch logic. While the current dispatch logic limits the block size, this dependency is implicit and can be easily broken.

It would be safer to make this size dependent on the block size. One way to achieve this is by using `extern __shared__` and calculating the required shared memory size in the host dispatch code.

For example, in the kernel:

```cuda
extern __shared__ float shared_mem[];
// ...
float* shared_max = shared_mem;
```

In the dispatch function:

```cuda
size_t shared_mem_size = block_size * sizeof(float);
kernel<<<grid, block, shared_mem_size, stream>>>(...);
```

Note that since there is another `__shared__` array, you'd need to combine them into a single `extern __shared__` allocation.
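A minimal sketch of what such a combined allocation could look like; the second array's name and element count are assumptions for illustration, not taken from this PR:

```cuda
// Kernel side: carve both arrays out of one dynamic shared-memory block.
extern __shared__ unsigned char smem_raw[];
float* shared_max = reinterpret_cast<float*>(smem_raw);  // blockDim.x entries
float* shared_scale = shared_max + blockDim.x;           // hypothetical second array

// Host side: request space for both arrays at launch time.
// size_t smem_bytes = (block_size + num_scale_elems) * sizeof(float);  // num_scale_elems is assumed
// kernel<<<grid, block, smem_bytes, stream>>>(...);
```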
Cursor Bugbot has reviewed your changes and found 2 potential issues.
Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.
Comment `@cursor review` or `bugbot run` to trigger another review on this PR
```python
        print("Pattern traced, registering replacement...")

        register_replacement(pattern, replacement, [input], fwd_only, pm_pass)
    print("Replacement registered!")
```
Indentation bug causes print outside method
Medium Severity
The `print("Replacement registered!")` statement on line 240 is incorrectly indented at the class level instead of inside the `register` method. This causes it to execute when the class is defined rather than when `register` is called, which breaks the intended logging behavior and may cause confusion during debugging.
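For reference, a sketch of the intended placement, using the names from the quoted diff; the final `print` belongs inside `register` so it runs per call rather than at class definition:

```python
    def register(self, pm_pass: PatternMatcherPass) -> None:
        ...
        print("Pattern traced, registering replacement...")
        register_replacement(pattern, replacement, [input], fwd_only, pm_pass)
        print("Replacement registered!")  # now executes when register() is called
```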
| shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]); | ||
| } | ||
| __syncthreads(); | ||
| } |
Reduction algorithm assumes power-of-2 block size
Medium Severity
The reduction algorithm on lines 85-91 assumes `blockDim.x` is a power of 2. If the block size is not a power of 2, some threads' `local_max` values won't be included in the final reduction, leading to incorrect scale computation. While the current code always uses power-of-2 block sizes (256, 512, 1024), this assumption is fragile and could break if the block size selection logic changes.
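If the block-size selection ever produces a non-power-of-2 value, one way to keep the reduction correct is to fold the active region by halves while guarding the upper index. A minimal sketch under that assumption (not the PR's kernel; `shared_max` is assumed to hold one partial max per thread):

```cuda
// Block-wide max reduction that does not require blockDim.x to be a power of 2.
__syncthreads();  // make every thread's partial max visible
for (unsigned int n = blockDim.x; n > 1; ) {
  unsigned int half = (n + 1) / 2;  // size of the upper half, rounded up
  if (threadIdx.x < n / 2) {
    shared_max[threadIdx.x] =
        fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + half]);
  }
  __syncthreads();
  n = half;  // the lower half (plus any unpaired middle element) stays active
}
// shared_max[0] now holds the block-wide maximum.
```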
Hello @ProExpertProg. I've created the kernel for SiluMul+BlockQuant fusion, and it's working correctly (though not performant enough yet). I'm still having some issues with the fusion pass and pattern matching, which I'll keep working on. I'd like to get some feedback on the kernel and how you think it can be made more optimized and efficient. @ElizaWszola, I used your #27883 PR as a reference to understand the internal workings — thanks for it. If you could also review the kernel and point out what I missed, that would be really helpful. I'll look at and fix the issues raised by the bots shortly.
Purpose
CUDA kernel and fusion code for fused SiluMul + group-wise FP8 quantization. For #27847.
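For context, here is a minimal PyTorch sketch (reference only, not the CUDA kernel) of the two unfused steps being combined: SiLU-and-mul followed by group-wise FP8 quantization. The ±448 range and 1e-10 epsilon mirror the constants passed to `per_token_group_fp8_quant` in the pattern above; the helper name and the 2D-input assumption are mine.

```python
import torch

def silu_mul_block_quant_ref(x: torch.Tensor, group_size: int = 128,
                             fp8_max: float = 448.0, eps: float = 1e-10):
    """Reference: y = SiLU(x[:, :d]) * x[:, d:], then per-(token, group) FP8 quantization."""
    d = x.shape[-1] // 2
    y = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
    g = y.reshape(y.shape[0], d // group_size, group_size)     # one scale per group
    amax = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps)   # per-group max magnitude
    scale = (amax / fp8_max).to(torch.float32)
    q = (g / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape_as(y), scale.squeeze(-1)
```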
Test Result
The experiments were run on an NVIDIA GeForce RTX 4070 with CUDA 13.0.
Test fused op:
Microbenchmark isolated op:
```
python benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py
```

```
[------------------------------------------------------ silu-mul-block-quant ------------------------------------------------------]
                                                 |  unfused_fp8_impl  |  unfused_groupwise_fp8_impl  |  fused_groupwise_fp8_impl
1 threads: -------------------------------------------------------------------------------------------------------------------------
      N 16 x D 1024 x DT torch.float16 x GS 64     | 218.6 | 234.3 | 92.2
      N 16 x D 1024 x DT torch.float16 x GS 128    | 218.6 | 234.4 | 93.5
      N 16 x D 1024 x DT torch.bfloat16 x GS 64    | 219.0 | 235.1 | 93.2
      N 16 x D 1024 x DT torch.bfloat16 x GS 128   | 218.9 | 234.2 | 93.4
      N 16 x D 2048 x DT torch.float16 x GS 64     | 219.1 | 235.4 | 93.6
      N 16 x D 2048 x DT torch.float16 x GS 128    | 219.3 | 234.0 | 93.4
      N 16 x D 2048 x DT torch.bfloat16 x GS 64    | 219.4 | 235.0 | 93.4
      N 16 x D 2048 x DT torch.bfloat16 x GS 128   | 219.1 | 235.4 | 93.0
      N 16 x D 4096 x DT torch.float16 x GS 64     | 219.1 | 235.2 | 93.1
      N 16 x D 4096 x DT torch.float16 x GS 128    | 218.5 | 235.7 | 93.0
      N 16 x D 4096 x DT torch.bfloat16 x GS 64    | 218.8 | 234.7 | 93.4
      N 16 x D 4096 x DT torch.bfloat16 x GS 128   | 218.2 | 234.9 | 92.6
      N 16 x D 5120 x DT torch.float16 x GS 64     | 218.7 | 233.9 | 113.5
      N 16 x D 5120 x DT torch.float16 x GS 128    | 219.4 | 233.8 | 92.4
      N 16 x D 5120 x DT torch.bfloat16 x GS 64    | 218.7 | 234.4 | 113.4
      N 16 x D 5120 x DT torch.bfloat16 x GS 128   | 218.3 | 234.0 | 92.3
      N 16 x D 14336 x DT torch.float16 x GS 64    | 219.0 | 235.8 | 315.8
      N 16 x D 14336 x DT torch.float16 x GS 128   | 219.2 | 234.5 | 157.3
      N 16 x D 14336 x DT torch.bfloat16 x GS 64   | 218.7 | 233.8 | 316.2
      N 16 x D 14336 x DT torch.bfloat16 x GS 128  | 218.3 | 233.9 | 156.7
      N 128 x D 1024 x DT torch.float16 x GS 64    | 215.7 | 231.5 | 91.4
      N 128 x D 1024 x DT torch.float16 x GS 128   | 216.0 | 232.1 | 91.3
      N 128 x D 1024 x DT torch.bfloat16 x GS 64   | 217.1 | 231.8 | 91.2
      N 128 x D 1024 x DT torch.bfloat16 x GS 128  | 216.0 | 231.6 | 91.0
      N 128 x D 2048 x DT torch.float16 x GS 64    | 215.2 | 230.3 | 136.2
      N 128 x D 2048 x DT torch.float16 x GS 128   | 215.5 | 229.7 | 91.4
      N 128 x D 2048 x DT torch.bfloat16 x GS 64   | 214.5 | 230.7 | 136.3
      N 128 x D 2048 x DT torch.bfloat16 x GS 128  | 216.0 | 229.2 | 91.0
      N 128 x D 4096 x DT torch.float16 x GS 64    | 214.3 | 230.7 | 271.7
      N 128 x D 4096 x DT torch.float16 x GS 128   | 215.1 | 230.1 | 135.4
      N 128 x D 4096 x DT torch.bfloat16 x GS 64   | 213.5 | 228.2 | 271.4
      N 128 x D 4096 x DT torch.bfloat16 x GS 128  | 214.6 | 243.4 | 134.9
      N 128 x D 5120 x DT torch.float16 x GS 64    | 213.8 | 228.5 | 339.1
      N 128 x D 5120 x DT torch.float16 x GS 128   | 213.7 | 228.1 | 168.6
      N 128 x D 5120 x DT torch.bfloat16 x GS 64   | 213.4 | 228.1 | 338.7
      N 128 x D 5120 x DT torch.bfloat16 x GS 128  | 213.4 | 228.4 | 168.2
      N 128 x D 14336 x DT torch.float16 x GS 64   | 213.8 | 230.0 | 950.2
      N 128 x D 14336 x DT torch.float16 x GS 128  | 214.2 | 229.3 | 471.7
      N 128 x D 14336 x DT torch.bfloat16 x GS 64  | 215.1 | 228.8 | 949.6
      N 128 x D 14336 x DT torch.bfloat16 x GS 128 | 215.5 | 229.6 | 469.0
      N 512 x D 1024 x DT torch.float16 x GS 64    | 213.2 | 228.4 | 271.1
      N 512 x D 1024 x DT torch.float16 x GS 128   | 213.3 | 228.3 | 136.3
      N 512 x D 1024 x DT torch.bfloat16 x GS 64   | 212.6 | 226.6 | 271.0
      N 512 x D 1024 x DT torch.bfloat16 x GS 128  | 212.2 | 226.5 | 135.9
      N 512 x D 2048 x DT torch.float16 x GS 64    | 212.1 | 226.7 | 538.8
      N 512 x D 2048 x DT torch.float16 x GS 128   | 212.3 | 227.6 | 268.0
      N 512 x D 2048 x DT torch.bfloat16 x GS 64   | 212.4 | 226.7 | 538.5
      N 512 x D 2048 x DT torch.bfloat16 x GS 128  | 213.0 | 226.4 | 267.9
      N 512 x D 4096 x DT torch.float16 x GS 64    | 211.9 | 227.1 | 1074.8
      N 512 x D 4096 x DT torch.float16 x GS 128   | 212.2 | 227.2 | 533.7
      N 512 x D 4096 x DT torch.bfloat16 x GS 64   | 214.0 | 227.3 | 1074.2
      N 512 x D 4096 x DT torch.bfloat16 x GS 128  | 212.3 | 226.8 | 532.6
      N 512 x D 5120 x DT torch.float16 x GS 64    | 211.6 | 225.5 | 1342.4
      N 512 x D 5120 x DT torch.float16 x GS 128   | 211.9 | 226.3 | 665.7
      N 512 x D 5120 x DT torch.bfloat16 x GS 64   | 211.3 | 225.5 | 1341.2
      N 512 x D 5120 x DT torch.bfloat16 x GS 128  | 211.9 | 227.0 | 664.3
      N 512 x D 14336 x DT torch.float16 x GS 64   | 215.0 | 233.2 | 3856.7
      N 512 x D 14336 x DT torch.float16 x GS 128  | 215.1 | 229.5 | 1908.3
      N 512 x D 14336 x DT torch.bfloat16 x GS 64  | 214.3 | 229.1 | 3858.7
      N 512 x D 14336 x DT torch.bfloat16 x GS 128 | 216.2 | 229.1 | 1898.8
      N 2048 x D 1024 x DT torch.float16 x GS 64   | 213.5 | 228.0 | 1015.9
      N 2048 x D 1024 x DT torch.float16 x GS 128  | 213.4 | 228.2 | 509.6
      N 2048 x D 1024 x DT torch.bfloat16 x GS 64  | 214.3 | 227.7 | 1015.7
      N 2048 x D 1024 x DT torch.bfloat16 x GS 128 | 214.7 | 227.5 | 509.3
      N 2048 x D 2048 x DT torch.float16 x GS 64   | 213.4 | 228.8 | 2022.7
      N 2048 x D 2048 x DT torch.float16 x GS 128  | 213.6 | 228.6 | 1007.1
      N 2048 x D 2048 x DT torch.bfloat16 x GS 64  | 213.6 | 228.6 | 2021.8
      N 2048 x D 2048 x DT torch.bfloat16 x GS 128 | 213.7 | 227.9 | 1005.9
      N 2048 x D 4096 x DT torch.float16 x GS 64   | 228.7 | 274.0 | 4369.1
      N 2048 x D 4096 x DT torch.float16 x GS 128  | 229.3 | 230.9 | 2181.3
      N 2048 x D 4096 x DT torch.bfloat16 x GS 64  | 229.2 | 231.5 | 4384.8
      N 2048 x D 4096 x DT torch.bfloat16 x GS 128 | 229.7 | 229.0 | 2177.9
      N 2048 x D 5120 x DT torch.float16 x GS 64   | 292.9 | 333.1 | 5611.7
      N 2048 x D 5120 x DT torch.float16 x GS 128  | 294.2 | 282.5 | 2785.2
      N 2048 x D 5120 x DT torch.bfloat16 x GS 64  | 293.7 | 286.9 | 5604.8
      N 2048 x D 5120 x DT torch.bfloat16 x GS 128 | 294.2 | 264.7 | 2787.5
      N 2048 x D 14336 x DT torch.float16 x GS 64  | 997.0 | 969.7 | 15687.4
      N 2048 x D 14336 x DT torch.float16 x GS 128 | 996.6 | 846.6 | 7829.7
      N 2048 x D 14336 x DT torch.bfloat16 x GS 64 | 996.1 | 854.2 | 15697.8
      N 2048 x D 14336 x DT torch.bfloat16 x GS 128| 996.1 | 845.9 | 7787.2

Times are in microseconds (us).
```

Note: This PR is a work in progress.