Conversation


@Monishver11 Monishver11 commented Jan 24, 2026

Purpose

CUDA kernel and torch.compile fusion-pass code for a fused SiluMul + groupwise FP8 quantization op. For #27847
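To make the op concrete, here is a minimal, unoptimized reference sketch of the computation being fused: silu(gate) * up followed by per-(token, group) FP8 quantization. This is illustrative only and not the PR's kernel; it assumes float32 inputs, E4M3 outputs with max magnitude 448, a 1e-10 epsilon floor on the group absmax, and one thread per (token, group).

#include <cuda_fp8.h>

__device__ __forceinline__ float silu(float x) {
  return x / (1.0f + expf(-x));
}

// input:  [num_tokens, 2*d], laid out per token as (gate || up)
// out:    [num_tokens, d] in FP8 E4M3
// scales: [num_tokens, d / group_size], one scale per (token, group)
__global__ void silu_mul_group_quant_ref(const float* __restrict__ input,
                                         __nv_fp8_e4m3* __restrict__ out,
                                         float* __restrict__ scales,
                                         int d, int group_size) {
  int token = blockIdx.x;
  int group = blockIdx.y * blockDim.x + threadIdx.x;
  int num_groups = d / group_size;
  if (group >= num_groups) return;

  const float* gate = input + (size_t)token * 2 * d;
  const float* up = gate + d;

  // Group absmax of silu(gate) * up, with a small epsilon floor.
  float amax = 1e-10f;
  for (int i = group * group_size; i < (group + 1) * group_size; ++i) {
    float v = silu(gate[i]) * up[i];
    amax = fmaxf(amax, fabsf(v));
  }
  float scale = amax / 448.0f;  // 448 = max finite E4M3 magnitude
  scales[(size_t)token * num_groups + group] = scale;

  // Quantize each element of the group with its group scale.
  for (int i = group * group_size; i < (group + 1) * group_size; ++i) {
    float v = silu(gate[i]) * up[i];
    out[(size_t)token * d + i] = __nv_fp8_e4m3(v / scale);
  }
}

The fused kernel in this PR computes the same quantities in a single pass, avoiding the intermediate silu-mul tensor that the unfused path materializes.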

Test Result

The experiments were run on an NVIDIA GeForce RTX 4070 with CUDA 13.0.

Test the fused op:

pytest tests/kernels/core/test_fused_silu_mul_block_quant.py

(vllm-dev) [mc10322@cuda5 vllm]$ pytest tests/kernels/core/test_fused_silu_mul_block_quant.py
============================================================ test session starts ============================================================
platform linux -- Python 3.10.19, pytest-8.3.5, pluggy-1.6.0
rootdir: /scratch/mc10322/vllm
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 330 items

tests/kernels/core/test_fused_silu_mul_block_quant.py ............................................................................... [ 23%]
..................................................................................................................................... [ 64%]
......................................................................................................................                [100%]

============================================================= warnings summary ==============================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================ 330 passed, 2 warnings in 98.97s (0:01:38) =================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Microbenchmark of the isolated op:

python benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py

[------------------------------------------------------ silu-mul-block-quant ------------------------------------------------------]
                                                     |  unfused_fp8_impl  |  unfused_groupwise_fp8_impl  |  fused_groupwise_fp8_impl
1 threads: -------------------------------------------------------------------------------------------------------------------------
      N 16 x D 1024 x DT torch.float16 x GS 64       |       218.6        |            234.3             |             92.2
      N 16 x D 1024 x DT torch.float16 x GS 128      |       218.6        |            234.4             |             93.5
      N 16 x D 1024 x DT torch.bfloat16 x GS 64      |       219.0        |            235.1             |             93.2
      N 16 x D 1024 x DT torch.bfloat16 x GS 128     |       218.9        |            234.2             |             93.4
      N 16 x D 2048 x DT torch.float16 x GS 64       |       219.1        |            235.4             |             93.6
      N 16 x D 2048 x DT torch.float16 x GS 128      |       219.3        |            234.0             |             93.4
      N 16 x D 2048 x DT torch.bfloat16 x GS 64      |       219.4        |            235.0             |             93.4
      N 16 x D 2048 x DT torch.bfloat16 x GS 128     |       219.1        |            235.4             |             93.0
      N 16 x D 4096 x DT torch.float16 x GS 64       |       219.1        |            235.2             |             93.1
      N 16 x D 4096 x DT torch.float16 x GS 128      |       218.5        |            235.7             |             93.0
      N 16 x D 4096 x DT torch.bfloat16 x GS 64      |       218.8        |            234.7             |             93.4
      N 16 x D 4096 x DT torch.bfloat16 x GS 128     |       218.2        |            234.9             |             92.6
      N 16 x D 5120 x DT torch.float16 x GS 64       |       218.7        |            233.9             |            113.5
      N 16 x D 5120 x DT torch.float16 x GS 128      |       219.4        |            233.8             |             92.4
      N 16 x D 5120 x DT torch.bfloat16 x GS 64      |       218.7        |            234.4             |            113.4
      N 16 x D 5120 x DT torch.bfloat16 x GS 128     |       218.3        |            234.0             |             92.3
      N 16 x D 14336 x DT torch.float16 x GS 64      |       219.0        |            235.8             |            315.8
      N 16 x D 14336 x DT torch.float16 x GS 128     |       219.2        |            234.5             |            157.3
      N 16 x D 14336 x DT torch.bfloat16 x GS 64     |       218.7        |            233.8             |            316.2
      N 16 x D 14336 x DT torch.bfloat16 x GS 128    |       218.3        |            233.9             |            156.7
      N 128 x D 1024 x DT torch.float16 x GS 64      |       215.7        |            231.5             |             91.4
      N 128 x D 1024 x DT torch.float16 x GS 128     |       216.0        |            232.1             |             91.3
      N 128 x D 1024 x DT torch.bfloat16 x GS 64     |       217.1        |            231.8             |             91.2
      N 128 x D 1024 x DT torch.bfloat16 x GS 128    |       216.0        |            231.6             |             91.0
      N 128 x D 2048 x DT torch.float16 x GS 64      |       215.2        |            230.3             |            136.2
      N 128 x D 2048 x DT torch.float16 x GS 128     |       215.5        |            229.7             |             91.4
      N 128 x D 2048 x DT torch.bfloat16 x GS 64     |       214.5        |            230.7             |            136.3
      N 128 x D 2048 x DT torch.bfloat16 x GS 128    |       216.0        |            229.2             |             91.0
      N 128 x D 4096 x DT torch.float16 x GS 64      |       214.3        |            230.7             |            271.7
      N 128 x D 4096 x DT torch.float16 x GS 128     |       215.1        |            230.1             |            135.4
      N 128 x D 4096 x DT torch.bfloat16 x GS 64     |       213.5        |            228.2             |            271.4
      N 128 x D 4096 x DT torch.bfloat16 x GS 128    |       214.6        |            243.4             |            134.9
      N 128 x D 5120 x DT torch.float16 x GS 64      |       213.8        |            228.5             |            339.1
      N 128 x D 5120 x DT torch.float16 x GS 128     |       213.7        |            228.1             |            168.6
      N 128 x D 5120 x DT torch.bfloat16 x GS 64     |       213.4        |            228.1             |            338.7
      N 128 x D 5120 x DT torch.bfloat16 x GS 128    |       213.4        |            228.4             |            168.2
      N 128 x D 14336 x DT torch.float16 x GS 64     |       213.8        |            230.0             |            950.2
      N 128 x D 14336 x DT torch.float16 x GS 128    |       214.2        |            229.3             |            471.7
      N 128 x D 14336 x DT torch.bfloat16 x GS 64    |       215.1        |            228.8             |            949.6
      N 128 x D 14336 x DT torch.bfloat16 x GS 128   |       215.5        |            229.6             |            469.0
      N 512 x D 1024 x DT torch.float16 x GS 64      |       213.2        |            228.4             |            271.1
      N 512 x D 1024 x DT torch.float16 x GS 128     |       213.3        |            228.3             |            136.3
      N 512 x D 1024 x DT torch.bfloat16 x GS 64     |       212.6        |            226.6             |            271.0
      N 512 x D 1024 x DT torch.bfloat16 x GS 128    |       212.2        |            226.5             |            135.9
      N 512 x D 2048 x DT torch.float16 x GS 64      |       212.1        |            226.7             |            538.8
      N 512 x D 2048 x DT torch.float16 x GS 128     |       212.3        |            227.6             |            268.0
      N 512 x D 2048 x DT torch.bfloat16 x GS 64     |       212.4        |            226.7             |            538.5
      N 512 x D 2048 x DT torch.bfloat16 x GS 128    |       213.0        |            226.4             |            267.9
      N 512 x D 4096 x DT torch.float16 x GS 64      |       211.9        |            227.1             |           1074.8
      N 512 x D 4096 x DT torch.float16 x GS 128     |       212.2        |            227.2             |            533.7
      N 512 x D 4096 x DT torch.bfloat16 x GS 64     |       214.0        |            227.3             |           1074.2
      N 512 x D 4096 x DT torch.bfloat16 x GS 128    |       212.3        |            226.8             |            532.6
      N 512 x D 5120 x DT torch.float16 x GS 64      |       211.6        |            225.5             |           1342.4
      N 512 x D 5120 x DT torch.float16 x GS 128     |       211.9        |            226.3             |            665.7
      N 512 x D 5120 x DT torch.bfloat16 x GS 64     |       211.3        |            225.5             |           1341.2
      N 512 x D 5120 x DT torch.bfloat16 x GS 128    |       211.9        |            227.0             |            664.3
      N 512 x D 14336 x DT torch.float16 x GS 64     |       215.0        |            233.2             |           3856.7
      N 512 x D 14336 x DT torch.float16 x GS 128    |       215.1        |            229.5             |           1908.3
      N 512 x D 14336 x DT torch.bfloat16 x GS 64    |       214.3        |            229.1             |           3858.7
      N 512 x D 14336 x DT torch.bfloat16 x GS 128   |       216.2        |            229.1             |           1898.8
      N 2048 x D 1024 x DT torch.float16 x GS 64     |       213.5        |            228.0             |           1015.9
      N 2048 x D 1024 x DT torch.float16 x GS 128    |       213.4        |            228.2             |            509.6
      N 2048 x D 1024 x DT torch.bfloat16 x GS 64    |       214.3        |            227.7             |           1015.7
      N 2048 x D 1024 x DT torch.bfloat16 x GS 128   |       214.7        |            227.5             |            509.3
      N 2048 x D 2048 x DT torch.float16 x GS 64     |       213.4        |            228.8             |           2022.7
      N 2048 x D 2048 x DT torch.float16 x GS 128    |       213.6        |            228.6             |           1007.1
      N 2048 x D 2048 x DT torch.bfloat16 x GS 64    |       213.6        |            228.6             |           2021.8
      N 2048 x D 2048 x DT torch.bfloat16 x GS 128   |       213.7        |            227.9             |           1005.9
      N 2048 x D 4096 x DT torch.float16 x GS 64     |       228.7        |            274.0             |           4369.1
      N 2048 x D 4096 x DT torch.float16 x GS 128    |       229.3        |            230.9             |           2181.3
      N 2048 x D 4096 x DT torch.bfloat16 x GS 64    |       229.2        |            231.5             |           4384.8
      N 2048 x D 4096 x DT torch.bfloat16 x GS 128   |       229.7        |            229.0             |           2177.9
      N 2048 x D 5120 x DT torch.float16 x GS 64     |       292.9        |            333.1             |           5611.7
      N 2048 x D 5120 x DT torch.float16 x GS 128    |       294.2        |            282.5             |           2785.2
      N 2048 x D 5120 x DT torch.bfloat16 x GS 64    |       293.7        |            286.9             |           5604.8
      N 2048 x D 5120 x DT torch.bfloat16 x GS 128   |       294.2        |            264.7             |           2787.5
      N 2048 x D 14336 x DT torch.float16 x GS 64    |       997.0        |            969.7             |          15687.4
      N 2048 x D 14336 x DT torch.float16 x GS 128   |       996.6        |            846.6             |           7829.7
      N 2048 x D 14336 x DT torch.bfloat16 x GS 64   |       996.1        |            854.2             |          15697.8
      N 2048 x D 14336 x DT torch.bfloat16 x GS 128  |       996.1        |            845.9             |           7787.2

Times are in microseconds (us).

Note: This PR is a work in progress.

Monishver Chandrasekaran and others added 30 commits January 13, 2026 02:00
…llm-project#27847)

Signed-off-by: Monishver Chandrasekaran <monishver@Monishvers-MacBook-Air.local>
…MakeLists.txt

Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <mc10322@cuda5.cims.nyu.edu>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
… small batches getting 512 threads/block

Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@Monishver11 Monishver11 marked this pull request as draft January 24, 2026 06:36
@mergify mergify bot added the ci/build and performance (Performance-related issues) labels Jan 24, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused CUDA kernel for SiLU, multiplication, and block-wise FP8 quantization, along with corresponding benchmarks, tests, and integration into the torch.compile fusion passes. The new kernel shows significant performance improvements in the provided benchmarks.

My review has identified a couple of important issues:

  1. A critical issue in the torch.compile fusion pass where the pattern for the new fused kernel is hardcoded for a single group_size, which will prevent fusion for other supported sizes.
  2. A high-severity issue in the CUDA kernel implementation regarding a hardcoded shared memory size, which makes the code brittle and prone to future bugs.

Addressing these points will improve the correctness and maintainability of the new feature. The rest of the changes, including the tests and benchmark code, look solid.

Comment on lines +177 to +240
class SiluMulBlockQuantPattern:
    """
    This pattern fuses silu_and_mul & block quantization.
    Handles all group_size, col_major, and e8m0 variants in one pattern.
    """
    def __init__(self):
        self.quant_dtype = FP8_DTYPE

        from vllm.config import get_current_vllm_config
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None

        from .matcher_utils import MatcherSiluAndMul, MatcherQuantFP8
        self.silu_and_mul_matcher = MatcherSiluAndMul()

        # Create a single matcher for group_size=128 as the pattern template
        # The actual replacement will handle all variants
        scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
        quant_key = QuantKey(dtype=FP8_DTYPE, scale=scale, symmetric=True)
        self.quant_matcher = MatcherQuantFP8(quant_key, has_col_major_scales=False, is_e8m0=False)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            print(f"PATTERN EXECUTING - input device: {input.device}, dtype: {input.dtype}, shape: {input.shape}")

            # Write the FULL pattern explicitly - no matchers
            d = input.shape[-1] // 2
            gate = input[..., :d]
            silu = torch.nn.functional.silu(gate)
            up = input[..., d:]
            silu_out = silu * up

            # Match the in-place quantization pattern
            x_q = torch.empty(silu_out.shape, dtype=FP8_DTYPE, device=input.device)
            num_groups = silu_out.shape[-1] // 128
            x_s = torch.empty((silu_out.shape[0], num_groups), dtype=torch.float32, device=input.device)

            torch.ops._C.per_token_group_fp8_quant(silu_out, x_q, x_s, 128, 1e-10, -448.0, 448.0, False)

            return x_q, x_s

        def replacement(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            print(f"FUSED KERNEL TRIGGERED! input.shape={input.shape}")

            output_shape = list(input.shape)
            output_shape[-1] = output_shape[-1] // 2

            result = torch.empty(output_shape, device=input.device, dtype=self.quant_dtype)
            num_groups = output_shape[-1] // 128
            scale = torch.empty((output_shape[0], num_groups), dtype=torch.float32, device=input.device)

            torch.ops._C.silu_and_mul_per_block_quant.default(
                result, input, scale, 128, None, False
            )

            return result, scale

        print("About to trace pattern...")
        input = torch.empty(5, 256, dtype=torch.float16, device='cuda')
        pattern(input)
        print("Pattern traced, registering replacement...")

        register_replacement(pattern, replacement, [input], fwd_only, pm_pass)
    print("Replacement registered!")

critical

The SiluMulBlockQuantPattern is hardcoded to only work for group_size=128. The docstring on line 180 claims it handles all variants, but both the pattern and replacement functions explicitly use 128. This will prevent the fusion from triggering for other supported group sizes like 64.

To fix this, you should either create separate patterns for each supported group_size or generalize the pattern to capture the group_size from the graph and use it in the replacement.

For example, you could create a SiluMulBlockQuantPattern for each group size:

class SiluMulBlockQuantPattern:
    def __init__(self, group_size: int):
        self.group_size = group_size
        ...

    def register(self, pm_pass: PatternMatcherPass) -> None:
        def pattern(...):
            # use self.group_size
            ...
            torch.ops._C.per_token_group_fp8_quant(..., self.group_size, ...)
            ...
        
        def replacement(...):
            # use self.group_size
            ...
            torch.ops._C.silu_and_mul_per_block_quant.default(..., self.group_size, ...)
            ...

# In ActivationQuantFusionPass
...
for group_size in [64, 128]:
    pattern = SiluMulBlockQuantPattern(group_size)
    pattern.register(self.patterns)
...

This would ensure that the fusion works for all supported configurations.

: scales + token_idx * num_groups;

// Shared memory
__shared__ float shared_max[1024]; // Keep hardcoded for now

high

The size of shared_max is hardcoded to 1024. This is brittle and could lead to out-of-bounds memory access if the block size is ever increased beyond 1024 in the dispatch logic. While the current dispatch logic limits the block size, this dependency is implicit and can be easily broken.

It would be safer to make this size dependent on the block size. One way to achieve this is by using extern __shared__ and calculating the required shared memory size in the host dispatch code.

For example:
In the kernel:

extern __shared__ float shared_mem[];
// ...
float* shared_max = shared_mem;

In the dispatch function:

size_t shared_mem_size = block_size * sizeof(float);
kernel<<<grid, block, shared_mem_size, stream>>>(...);

Note that since there is another __shared__ array, you'd need to combine them into a single extern __shared__ allocation.
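A minimal sketch of that combined allocation (illustrative only, not the PR's kernel: shared_scale is a hypothetical name for the second shared array, and the per-group scale math is a placeholder):

__global__ void example_kernel(const float* in, float* scales_out, int num_groups) {
  // One dynamic allocation, carved into the two arrays the kernel needs.
  extern __shared__ char smem_raw[];
  float* shared_max = reinterpret_cast<float*>(smem_raw);  // blockDim.x floats
  float* shared_scale = shared_max + blockDim.x;           // num_groups floats

  int tid = threadIdx.x;
  // Assumes the grid exactly tiles the input, for brevity.
  shared_max[tid] = fabsf(in[blockIdx.x * blockDim.x + tid]);
  __syncthreads();

  // Tree reduction over shared_max (blockDim.x assumed to be a power of
  // two here, matching the current dispatch logic).
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
    }
    __syncthreads();
  }

  // Placeholder per-group scale derived from the block max; the real
  // kernel computes per-group maxima.
  if (tid < num_groups) {
    shared_scale[tid] = shared_max[0] / 448.0f;
    scales_out[blockIdx.x * num_groups + tid] = shared_scale[tid];
  }
}

// Host side: size the single dynamic allocation for both arrays.
// size_t smem_bytes = (block_size + num_groups) * sizeof(float);
// example_kernel<<<grid, block_size, smem_bytes, stream>>>(d_in, d_scales, num_groups);

With this layout, growing the block size in the dispatch logic automatically grows the allocation, and a request above the device limit fails at launch instead of silently overrunning a fixed-size array.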

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

print("Pattern traced, registering replacement...")

register_replacement(pattern, replacement, [input], fwd_only, pm_pass)
print("Replacement registered!")

Indentation bug causes print outside method

Medium Severity

The print("Replacement registered!") statement on line 240 is incorrectly indented at the class level instead of inside the register method. This causes it to execute when the class is defined rather than when register is called, which breaks the intended logging behavior and may cause confusion during debugging.


      shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
    }
    __syncthreads();
  }

Reduction algorithm assumes power-of-2 block size

Medium Severity

The reduction algorithm on lines 85-91 assumes blockDim.x is a power of 2. If the block size is not a power of 2, some threads' local_max values won't be included in the final reduction, leading to incorrect scale computation. While the current code always uses power-of-2 block sizes (256, 512, 1024), this assumption is fragile and could break if the block size selection logic changes.
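For illustration, here is one way to lift that assumption (a standalone sketch with illustrative names, not the PR's kernel): start the reduction at the smallest power of two that is >= blockDim.x and guard the partner index, so odd block sizes still fold every element into the result.

__global__ void block_absmax_kernel(const float* in, float* out, int n) {
  __shared__ float shared_max[1024];  // upper bound on blockDim.x, as in the original
  int tid = threadIdx.x;
  int idx = blockIdx.x * blockDim.x + tid;
  shared_max[tid] = (idx < n) ? fabsf(in[idx]) : 0.0f;
  __syncthreads();

  // Round the starting stride up to a power of two so every element is
  // paired with a valid partner at least once, even for odd block sizes.
  unsigned int stride = 1;
  while (stride < blockDim.x) stride <<= 1;
  for (stride >>= 1; stride > 0; stride >>= 1) {
    if (tid < stride && tid + stride < blockDim.x) {
      shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]);
    }
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = shared_max[0];
}

The extra tid + stride < blockDim.x guard costs one comparison per step and keeps the reduction correct for any block size up to the shared array's capacity.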


@Monishver11 (Author)

Hello @ProExpertProg. I've created the kernel for the SiluMul + BlockQuant fusion, and it's working correctly, though not yet performant enough. I'm still having some issues with the fusion pass and pattern matching, which I'll keep working on. I'd like some feedback on the kernel and on how you think it can be made more optimized and efficient.

@ElizaWszola, I used your PR #27883 as a reference to understand the internal workings; thanks for that. If you could also review the kernel and point out what I missed, that would be really helpful.

I'll look into and fix the issues raised by the bots shortly.
