[None][fix] Enable cuda_scaled_mm fast path for FP8 linear on SM121#15928
souvikDevloper wants to merge 5 commits into
The FP8 linear cuda_scaled_mm fast path for small batch sizes (m <= 8) was gated by hard-coded allowlists covering only SM89 and SM120, silently falling back to the slower cublas path on SM121 (GB10 / DGX Spark) even though the hardware supports it. Add SM121 to both allowlists.

Fixes NVIDIA#15673

Signed-off-by: souvikDevloper <gshsouvik01@gmail.com>
📝 Walkthrough
The compute-capability allowlist controlling the CUDA core fast path (
Changes
SM121 Support in cuda_scaled_mm Fast Path
Estimated code review effort: 1 (Trivial) | ~3 minutes
Suggested labels: bug, low-risk
Suggested reviewers: Tracin
Pre-merge checks: 5 passed
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/linear.py (1)
3272-3279: Correctly extends allowlist to SM121. (Maintainability & Code Quality | Trivial | Low value)

The added `(capability[0] == 12 and capability[1] == 1)` clause correctly enables the `cuda_scaled_mm` fast path for SM121, matching the PR objective and the parallel change in `quant.py`.

For consistency with `quant.py`'s more concise `capability in ((8, 9), (12, 0), (12, 1))` pattern, consider simplifying this chained boolean expression, though this is purely stylistic.

♻️ Optional simplification for consistency with quant.py
     self.enable_cuda_core = False
     if torch.cuda.is_available():
         capability = torch.cuda.get_device_capability(
             torch.device('cuda:0'))
-        # enable cuda core for sm89, sm120 and sm121
-        self.enable_cuda_core = (capability[0] == 8 and capability[1] == 9) \
-            or (capability[0] == 12 and capability[1] == 0) \
-            or (capability[0] == 12 and capability[1] == 1)
+        # enable cuda core for sm89, sm120 and sm121
+        self.enable_cuda_core = capability in ((8, 9), (12, 0), (12, 1))

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/linear.py`, around lines 3272-3279: the SM121 allowlist in the cuda core gating logic is already correct, but the boolean chain in the `enable_cuda_core` assignment inside the linear module should be simplified for consistency with `quant.py`. Update the capability check in `linear.py`’s CUDA availability block to use the same concise membership-style pattern as `quant.py`, keeping the allowlist for `(8, 9)`, `(12, 0)`, and `(12, 1)` while preserving the existing `cuda_scaled_mm` fast-path behavior.
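As a quick sanity check of the suggested simplification, the sketch below compares the chained boolean with the membership form over a grid of (major, minor) pairs. The function names `chained` and `membership` and the grid bounds are illustrative assumptions, not code from the PR. The comparison relies on `torch.cuda.get_device_capability` returning a tuple, so the capability values here are tuples as well.

```python
from itertools import product

# Allowlist from the review suggestion: SM89, SM120, SM121.
ALLOWED = ((8, 9), (12, 0), (12, 1))


def chained(capability):
    """Chained boolean form (hypothetical helper mirroring the removed lines)."""
    return (capability[0] == 8 and capability[1] == 9) \
        or (capability[0] == 12 and capability[1] == 0) \
        or (capability[0] == 12 and capability[1] == 1)


def membership(capability):
    """Membership form (hypothetical helper mirroring the added line)."""
    return capability in ALLOWED


# Exhaustively compare both forms on an assumed grid of (major, minor) tuples.
grid = list(product(range(16), range(10)))
mismatches = [c for c in grid if chained(c) != membership(c)]
print("mismatches:", mismatches)
```

One caveat of the membership form: a list such as `[12, 1]` would not match, because a list never compares equal to a tuple in Python. That is harmless here only as long as the capability stays a tuple.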
📒 Files selected for processing (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py
tensorrt_llm/_torch/modules/linear.py
Simplify the chained boolean in Linear.enable_cuda_core to the same 'capability in (...)' pattern used in the auto_deploy quant op, per review feedback.

Signed-off-by: souvikDevloper <gshsouvik01@gmail.com>
Description
Fixes #15673
On SM121 devices (GB10 / DGX Spark), the FP8 linear `cuda_scaled_mm` fast path for small batch sizes (m <= 8) is silently disabled, because the two `enable_cuda_core` allowlists are hard-coded to SM89 and SM120 only. As a result, decode-time GEMMs on SM121 always fall back to the slower generic `cublas_scaled_mm` path, with no warning or log message.

The comment at the first site ("enable cuda core for sm89 and sm120") suggests SM121 was left out unintentionally rather than deliberately excluded; SM121 supports the same fast path as SM120.
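The gating described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the TensorRT-LLM implementation: `select_fp8_gemm_path`, the two allowlist constants, and the string return values are hypothetical names, and only the m <= 8 threshold and the capability pairs come from this description.

```python
# Hypothetical allowlists: before and after adding SM121.
ALLOWLIST_BEFORE = ((8, 9), (12, 0))
ALLOWLIST_AFTER = ((8, 9), (12, 0), (12, 1))

# Small-batch threshold stated in the PR description (m <= 8).
SMALL_M_THRESHOLD = 8


def select_fp8_gemm_path(capability, m, allowlist):
    """Return the name of the GEMM path a device/batch pair would take.

    capability: (major, minor) tuple, e.g. (12, 1) for SM121.
    m: number of rows in the activation matrix (batch dimension).
    allowlist: capability tuples for which the fast path is enabled.
    """
    enable_cuda_core = capability in allowlist
    if enable_cuda_core and m <= SMALL_M_THRESHOLD:
        return "cuda_scaled_mm"
    # No warning is emitted on fallback, which is why the gap went unnoticed.
    return "cublas_scaled_mm"


for label, allowlist in (("before", ALLOWLIST_BEFORE), ("after", ALLOWLIST_AFTER)):
    print(label, select_fp8_gemm_path((12, 1), 4, allowlist))
```

With the old allowlist an SM121 device at m = 4 is routed to the generic path; with SM121 added it takes the fast path, while batches larger than 8 are unaffected on every device.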
Changes
1. `tensorrt_llm/_torch/modules/linear.py`: used by the PyTorch backend `Linear` module.
2. `tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py`: same gate in the AutoDeploy FP8 prequant linear op.

With `enable_cuda_core` now true on SM121, batches with m <= 8 take the optimized `torch.ops.trtllm.cuda_scaled_mm` path instead of falling through to `cublas_scaled_mm`, matching SM120 behavior.

Test Coverage
No new tests added: the change only extends an existing device-capability allowlist, and the affected path is already exercised by the existing FP8 linear unit tests on supported hardware. Verifying the SM121 branch end-to-end requires GB10 / DGX Spark hardware.
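Although the end-to-end path needs GB10 hardware, the allowlist decision itself could in principle be checked on any machine by injecting the device query. The sketch below is one possible shape for such a check, not part of this PR: `compute_enable_cuda_core` and its two callable parameters are hypothetical, standing in for `torch.cuda.is_available` and `torch.cuda.get_device_capability`.

```python
# Capability pairs taken from the PR: SM89, SM120, SM121.
ALLOWED = ((8, 9), (12, 0), (12, 1))


def compute_enable_cuda_core(is_available, get_capability):
    """Hypothetical helper: decide the gate from injected device queries.

    is_available: zero-argument callable returning bool.
    get_capability: zero-argument callable returning (major, minor).
    """
    if not is_available():
        return False
    # Normalize to a tuple so a list-returning fake still compares correctly.
    return tuple(get_capability()) in ALLOWED


# Simulate an SM121 device without any GPU present.
sm121_enabled = compute_enable_cuda_core(lambda: True, lambda: (12, 1))
# Simulate a machine with no CUDA device at all.
no_gpu_enabled = compute_enable_cuda_core(lambda: False, lambda: (12, 1))
print(sm121_enabled, no_gpu_enabled)
```

This only covers the gate, so it complements rather than replaces a run of the existing FP8 linear tests on SM121 hardware.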