[None][perf] Fuse Qwen3.5 GDN input projections#15884
Conversation
Signed-off-by: Mingyang Hao <200044211+mingyangHao@users.noreply.github.com>
|
/bot run --disable-fail-fast |
📝 WalkthroughWalkthroughThis PR adds a combined ChangesCombined GDN Projection and Strided Kernel Support
Estimated code review effort: 4 (Complex) | ~60 minutes Sequence Diagram(s)sequenceDiagram
participant Checkpoint as HF Checkpoint
participant Mapper as Qwen3NextHfWeightMapper
participant Config as ModelConfig
participant GDN as Qwen3NextGatedDeltaNet
Checkpoint->>Mapper: preprocess_weights(weights)
Mapper->>Config: check enable_attention_dp
alt attention-DP enabled
Mapper->>Mapper: _combine_gdn_input_projections
Mapper-->>Checkpoint: weights with in_proj_qkvzba
else attention-DP disabled
Mapper-->>Checkpoint: weights with separate in_proj_qkvz/in_proj_ba
end
Checkpoint->>GDN: load state_dict
GDN->>GDN: init use_combined_qkvzba_projection
GDN->>GDN: forward(hidden_states)
alt combined projection
GDN->>GDN: slice projected_states_qkvzba into mixed_qkv, z, b, a
else split projection
GDN->>GDN: compute qkvz + ba, fuse/split into mixed_qkv, z, b, a
end
GDN-->>Checkpoint: attention output
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py (1)
46-55: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick winCast
dst_offsetstotl.int64too. The store index has the same overflow risk assrc_offsets; with largenum_prefill_tokens * conv_dim,conv_offsets[:, None] * num_prefill_tokenscan wrap and write to the wrong address.Proposed fix
- dst_offsets = conv_offsets[:, None] * num_prefill_tokens + seq_offsets[None, :] + dst_offsets = conv_offsets[:, None].to(tl.int64) * num_prefill_tokens + seq_offsets[None, :]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py` around lines 46 - 55, The dst_offsets calculation in fuse_elementwise_ops has the same overflow risk as src_offsets because conv_offsets[:, None] * num_prefill_tokens can exceed INT32_MAX and corrupt the store address. Update the dst_offsets expression in the same kernel path to cast the operands to tl.int64 before the multiply/add, mirroring the existing src_offsets fix, and keep the tl.store call using the widened offsets.
🧹 Nitpick comments (1)
tests/unittest/_torch/models/checkpoints/hf/test_qwen3_next_weight_mapper.py (1)
35-73: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winTest coverage is good for the happy path but misses key error branches.
_combine_gdn_input_projections(per the upstream snippet) also raisesValueErrorwhen only one ofqkvz/bais present ("Expected both QKVZ and BA tensors...") and when trailing shapes mismatch between qkvz/ba row tensors. Neither path is covered here.Suggest adding:
test_combine_gdn_input_projections_missing_projection_raises— onlyin_proj_qkvz.*present, expectValueErrormatching"Expected both QKVZ and BA".test_combine_gdn_input_projections_rejects_trailing_shape_mismatch— row-tensors with mismatchedshape[1:], expectValueErrormatching"trailing shapes do not match".Coverage for the two implemented cases (consumer-order combination, scalar-metadata rejection) is sufficient and correctly verified against the actual reshape/slice math.
As per path instructions, "Act as a QA engineer reviewing test changes and coverage for TensorRT-LLM" and "suggest concrete list file names and whether coverage is sufficient, insufficient, or needs follow-up."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/models/checkpoints/hf/test_qwen3_next_weight_mapper.py` around lines 35 - 73, Add the missing error-path coverage for _combine_gdn_input_projections in test_qwen3_next_weight_mapper.py: create a test that passes only in_proj_qkvz.* entries and asserts a ValueError matching the “Expected both QKVZ and BA” message, and another that uses row-tensor inputs with mismatched trailing shapes and asserts the “trailing shapes do not match” error. Keep the existing consumer-order and scalar-metadata tests as-is, since they already verify the happy path and non-row metadata rejection.Source: Path instructions
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py`:
- Around line 47-51: The combine logic in qwen3_next_weight_mapper should not
require paired qkvz/ba tensors until FP8 BA scale metadata has been handled.
Update the grouping/combine flow in the superclass method used by
Qwen3_5MoeHfWeightMapper so orphan in_proj_ba.weight_scale_inv entries are
dequantized or dropped alongside the BA projection, and only then enforce the
“Expected both QKVZ and BA” check for the relevant symbols in the combine path
and _dequantize_linear_attn_fp8_qkvz.
- Around line 55-90: The combined projection packing in qwen3_next_weight_mapper
should fail fast if an existing in_proj_qkvzba.<suffix> key is already present,
instead of silently overwriting it. Add the same duplicate-key check used by the
split packing path before assigning into combined_weights in the combine logic,
and raise an error when both the combined key and the split in_proj_qkvz /
in_proj_ba inputs would map to the same output.
---
Outside diff comments:
In `@tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py`:
- Around line 46-55: The dst_offsets calculation in fuse_elementwise_ops has the
same overflow risk as src_offsets because conv_offsets[:, None] *
num_prefill_tokens can exceed INT32_MAX and corrupt the store address. Update
the dst_offsets expression in the same kernel path to cast the operands to
tl.int64 before the multiply/add, mirroring the existing src_offsets fix, and
keep the tl.store call using the widened offsets.
---
Nitpick comments:
In
`@tests/unittest/_torch/models/checkpoints/hf/test_qwen3_next_weight_mapper.py`:
- Around line 35-73: Add the missing error-path coverage for
_combine_gdn_input_projections in test_qwen3_next_weight_mapper.py: create a
test that passes only in_proj_qkvz.* entries and asserts a ValueError matching
the “Expected both QKVZ and BA” message, and another that uses row-tensor inputs
with mismatched trailing shapes and asserts the “trailing shapes do not match”
error. Keep the existing consumer-order and scalar-metadata tests as-is, since
they already verify the happy path and non-row metadata rejection.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 589c1b79-8970-4b4a-8d07-a2c4afc35ea8
📒 Files selected for processing (8)
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_5_weight_mapper.pytensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.pytensorrt_llm/_torch/models/modeling_qwen3_5.pytensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.pytensorrt_llm/_torch/modules/mamba/gdn_mixer.pytensorrt_llm/_torch/modules/mamba/layernorm_gated.pytests/unittest/_torch/models/checkpoints/hf/test_qwen3_next_weight_mapper.pytests/unittest/_torch/modules/mamba/test_gdn_kernel_optimizations.py
| for (prefix, suffix), tensors in grouped.items(): | ||
| if tensors.keys() != {"qkvz", "ba"}: | ||
| raise ValueError( | ||
| f"Expected both QKVZ and BA tensors for {prefix}.{suffix}, " | ||
| f"got {sorted(tensors)}") |
There was a problem hiding this comment.
🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win
Handle Qwen3.5 FP8 BA scales before requiring paired suffixes.
With attention-DP enabled, Qwen3_5MoeHfWeightMapper calls this superclass combine after _dequantize_linear_attn_fp8_qkvz() removes only in_proj_qkvz.weight_scale_inv; in_proj_ba.weight_scale_inv can remain alone and hit this Expected both QKVZ and BA error. Dequantize/drop BA scale metadata for the combined non-quantized projection, or skip orphan metadata only after ensuring the corresponding BA weight has been dequantized too.
Also applies to: 195-196
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py` around
lines 47 - 51, The combine logic in qwen3_next_weight_mapper should not require
paired qkvz/ba tensors until FP8 BA scale metadata has been handled. Update the
grouping/combine flow in the superclass method used by Qwen3_5MoeHfWeightMapper
so orphan in_proj_ba.weight_scale_inv entries are dequantized or dropped
alongside the BA projection, and only then enforce the “Expected both QKVZ and
BA” check for the relevant symbols in the combine path and
_dequantize_linear_attn_fp8_qkvz.
| combined_name = f"{prefix}.in_proj_qkvzba.{suffix}" | ||
|
|
||
| # Scalar/per-tensor metadata is shared by the two projections. It | ||
| # cannot be row-reordered, so retain one copy after validating it. | ||
| if (qkvz.ndim == 0 or ba.ndim == 0 or qkvz.shape[0] != expected_qkvz | ||
| or ba.shape[0] != expected_ba): | ||
| if qkvz.shape != ba.shape or not torch.equal(qkvz, ba): | ||
| raise ValueError( | ||
| f"Cannot combine non-row GDN projection metadata " | ||
| f"{prefix}.{suffix}: QKVZ shape={tuple(qkvz.shape)}, " | ||
| f"BA shape={tuple(ba.shape)}") | ||
| combined_weights[combined_name] = qkvz | ||
| continue | ||
|
|
||
| if qkvz.shape[1:] != ba.shape[1:]: | ||
| raise ValueError( | ||
| f"GDN projection trailing shapes do not match for " | ||
| f"{prefix}.{suffix}: {tuple(qkvz.shape)} vs {tuple(ba.shape)}" | ||
| ) | ||
|
|
||
| trailing_shape = qkvz.shape[1:] | ||
| qkvz = qkvz.reshape(num_k_heads, qkvz_group_dim, *trailing_shape) | ||
| ba = ba.reshape(num_k_heads, ba_group_dim, *trailing_shape) | ||
|
|
||
| q_end = head_k_dim | ||
| k_end = q_end + head_k_dim | ||
| v_end = k_end + heads_ratio * head_v_dim | ||
| z_end = v_end + heads_ratio * head_v_dim | ||
| q = qkvz[:, :q_end].reshape(-1, *trailing_shape) | ||
| k = qkvz[:, q_end:k_end].reshape(-1, *trailing_shape) | ||
| v = qkvz[:, k_end:v_end].reshape(-1, *trailing_shape) | ||
| z = qkvz[:, v_end:z_end].reshape(-1, *trailing_shape) | ||
| b = ba[:, :heads_ratio].reshape(-1, *trailing_shape) | ||
| a = ba[:, heads_ratio:].reshape(-1, *trailing_shape) | ||
| combined_weights[combined_name] = torch.cat((q, k, v, z, b, a), | ||
| dim=0).contiguous() |
There was a problem hiding this comment.
🗄️ Data Integrity & Integration | 🟡 Minor | ⚡ Quick win
Reject existing combined projection keys before overwriting them.
If a checkpoint already contains in_proj_qkvzba.<suffix> plus split in_proj_qkvz/in_proj_ba, line 89 silently overwrites the original combined tensor. Mirror the duplicate-key guard used by the split packer and fail fast.
Proposed guard
qkvz = tensors["qkvz"]
ba = tensors["ba"]
combined_name = f"{prefix}.in_proj_qkvzba.{suffix}"
+ if combined_name in combined_weights:
+ raise ValueError(f"Combined projection {combined_name} already exists")
# Scalar/per-tensor metadata is shared by the two projections. It📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| combined_name = f"{prefix}.in_proj_qkvzba.{suffix}" | |
| # Scalar/per-tensor metadata is shared by the two projections. It | |
| # cannot be row-reordered, so retain one copy after validating it. | |
| if (qkvz.ndim == 0 or ba.ndim == 0 or qkvz.shape[0] != expected_qkvz | |
| or ba.shape[0] != expected_ba): | |
| if qkvz.shape != ba.shape or not torch.equal(qkvz, ba): | |
| raise ValueError( | |
| f"Cannot combine non-row GDN projection metadata " | |
| f"{prefix}.{suffix}: QKVZ shape={tuple(qkvz.shape)}, " | |
| f"BA shape={tuple(ba.shape)}") | |
| combined_weights[combined_name] = qkvz | |
| continue | |
| if qkvz.shape[1:] != ba.shape[1:]: | |
| raise ValueError( | |
| f"GDN projection trailing shapes do not match for " | |
| f"{prefix}.{suffix}: {tuple(qkvz.shape)} vs {tuple(ba.shape)}" | |
| ) | |
| trailing_shape = qkvz.shape[1:] | |
| qkvz = qkvz.reshape(num_k_heads, qkvz_group_dim, *trailing_shape) | |
| ba = ba.reshape(num_k_heads, ba_group_dim, *trailing_shape) | |
| q_end = head_k_dim | |
| k_end = q_end + head_k_dim | |
| v_end = k_end + heads_ratio * head_v_dim | |
| z_end = v_end + heads_ratio * head_v_dim | |
| q = qkvz[:, :q_end].reshape(-1, *trailing_shape) | |
| k = qkvz[:, q_end:k_end].reshape(-1, *trailing_shape) | |
| v = qkvz[:, k_end:v_end].reshape(-1, *trailing_shape) | |
| z = qkvz[:, v_end:z_end].reshape(-1, *trailing_shape) | |
| b = ba[:, :heads_ratio].reshape(-1, *trailing_shape) | |
| a = ba[:, heads_ratio:].reshape(-1, *trailing_shape) | |
| combined_weights[combined_name] = torch.cat((q, k, v, z, b, a), | |
| dim=0).contiguous() | |
| qkvz = tensors["qkvz"] | |
| ba = tensors["ba"] | |
| combined_name = f"{prefix}.in_proj_qkvzba.{suffix}" | |
| if combined_name in combined_weights: | |
| raise ValueError(f"Combined projection {combined_name} already exists") | |
| # Scalar/per-tensor metadata is shared by the two projections. It | |
| # cannot be row-reordered, so retain one copy after validating it. | |
| if (qkvz.ndim == 0 or ba.ndim == 0 or qkvz.shape[0] != expected_qkvz | |
| or ba.shape[0] != expected_ba): | |
| if qkvz.shape != ba.shape or not torch.equal(qkvz, ba): | |
| raise ValueError( | |
| f"Cannot combine non-row GDN projection metadata " | |
| f"{prefix}.{suffix}: QKVZ shape={tuple(qkvz.shape)}, " | |
| f"BA shape={tuple(ba.shape)}") | |
| combined_weights[combined_name] = qkvz | |
| continue | |
| if qkvz.shape[1:] != ba.shape[1:]: | |
| raise ValueError( | |
| f"GDN projection trailing shapes do not match for " | |
| f"{prefix}.{suffix}: {tuple(qkvz.shape)} vs {tuple(ba.shape)}" | |
| ) | |
| trailing_shape = qkvz.shape[1:] | |
| qkvz = qkvz.reshape(num_k_heads, qkvz_group_dim, *trailing_shape) | |
| ba = ba.reshape(num_k_heads, ba_group_dim, *trailing_shape) | |
| q_end = head_k_dim | |
| k_end = q_end + head_k_dim | |
| v_end = k_end + heads_ratio * head_v_dim | |
| z_end = v_end + heads_ratio * head_v_dim | |
| q = qkvz[:, :q_end].reshape(-1, *trailing_shape) | |
| k = qkvz[:, q_end:k_end].reshape(-1, *trailing_shape) | |
| v = qkvz[:, k_end:v_end].reshape(-1, *trailing_shape) | |
| z = qkvz[:, v_end:z_end].reshape(-1, *trailing_shape) | |
| b = ba[:, :heads_ratio].reshape(-1, *trailing_shape) | |
| a = ba[:, heads_ratio:].reshape(-1, *trailing_shape) | |
| combined_weights[combined_name] = torch.cat((q, k, v, z, b, a), | |
| dim=0).contiguous() |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py` around
lines 55 - 90, The combined projection packing in qwen3_next_weight_mapper
should fail fast if an existing in_proj_qkvzba.<suffix> key is already present,
instead of silently overwriting it. Add the same duplicate-key check used by the
split packing path before assigning into combined_weights in the combine logic,
and raise an error when both the combined key and the split in_proj_qkvz /
in_proj_ba inputs would map to the same output.
|
/bot run --disable-fail-fast |
|
PR_Github #57284 [ run ] triggered by Bot. Commit: |
|
PR_Github #57284 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #57328 [ run ] triggered by Bot. Commit: |
|
PR_Github #57328 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #57389 [ run ] triggered by Bot. Commit: |
|
PR_Github #57389 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #57457 [ run ] triggered by Bot. Commit: |
|
PR_Github #57457 [ run ] completed with state |
Signed-off-by: Mingyang Hao <200044211+mingyangHao@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #57538 [ run ] triggered by Bot. Commit: |
|
PR_Github #57538 [ run ] completed with state |
Summary by CodeRabbit
New Features
Bug Fixes
Perf for TP4:
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.