[None][fix] Enable MiniMax M3 piecewise CUDA graphs#15923
Conversation
bfdfb57 to
31496ff
Compare
Wrap the MiniMax M3 metadata- and cache-dependent attention core in an inplace custom op so torch.compile can split it out of piecewise CUDA graphs. Keep QKV/index projections, QK normalization, RoPE, and the output projection visible to the compiled graph. Write dense and sparse attention results into the custom-op output buffer. Preserve FP32 sparse GQA accumulation until the final copy/cast, and expose the output buffer through MiniMaxM3SparseRuntimeBackend.forward. Register attention boundaries and mutation metadata through optional TRT-LLM op lookup, matching the latest GDN registration pattern from PR NVIDIA#15594. This avoids depending on model-specific custom ops being imported when compilation utilities initialize. Track piecewise runners owned by the compile backend and reset their CUDA graphs, captured addresses, outputs, and warmup state when phase-1 KV-cache estimation is released. Phase 2 then recaptures against the final allocations instead of replaying stale graph pointers. Add an 8-GPU MiniMax-M3-MXFP8 torch.compile E2E variant covering TP8/EP8, attention DP, TRTLLM MoE, padding CUDA graphs, multi-stream piecewise capture, and phase-2 recapture. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
31496ff to
c333655
Compare
📝 WalkthroughWalkthroughAdds an optional preallocated output tensor parameter threaded through MiniMax-M3 sparse attention functions, runtime backend, and dense/sparse attention cores, with a compiled custom-op boundary for the dense path. Also adds tracking and cleanup of piecewise CUDA graph runners in the torch.compile backend, invoked from model engine cleanup. ChangesMiniMax-M3 output tensor plumbing
Piecewise CUDA graph runner tracking and cleanup
Estimated code review effort: 4 (Complex) | ~60 minutes Sequence Diagram(s)sequenceDiagram
participant Caller
participant RuntimeBackend as MiniMaxM3SparseRuntimeBackend
participant Decode as minimax_m3_sparse_decode/prefill
participant Masked as _sparse_gqa_masked
Caller->>RuntimeBackend: forward(forward_args, output, kwargs)
RuntimeBackend->>RuntimeBackend: merge forward_args/kwargs, resolve output
RuntimeBackend->>Decode: forward_sparse(output=resolved_output)
Decode->>Masked: _sparse_gqa_masked(output=resolved_output)
Masked-->>Decode: writes result into output tensor
sequenceDiagram
participant ModelEngine as PyTorchModelEngine
participant Backend
participant Runner as PiecewiseRunner
ModelEngine->>Backend: clear_piecewise_cuda_graphs()
Backend->>Runner: clear_cuda_graphs() (for each tracked runner)
Runner-->>Backend: resets captured graphs and cached state
Suggested reviewers: 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
/bot run --disable-fail-fast |
|
PR_Github #57464 [ run ] triggered by Bot. Commit: |
|
PR_Github #57464 [ run ] completed with state
|
Wrap the MiniMax M3 metadata- and cache-dependent attention core in an inplace custom op so torch.compile can split it out of piecewise CUDA graphs. Keep QKV/index projections, QK normalization, RoPE, and the output projection visible to the compiled graph.
Write dense and sparse attention results into the custom-op output buffer. Preserve FP32 sparse GQA accumulation until the final copy/cast, and expose the output buffer through MiniMaxM3SparseRuntimeBackend.forward.
Register attention boundaries and mutation metadata through optional TRT-LLM op lookup, matching the latest GDN registration pattern from PR #15594. This avoids depending on model-specific custom ops being imported when compilation utilities initialize.
Track piecewise runners owned by the compile backend and reset their CUDA graphs, captured addresses, outputs, and warmup state when phase-1 KV-cache estimation is released. Phase 2 then recaptures against the final allocations instead of replaying stale graph pointers.
Add an 8-GPU MiniMax-M3-MXFP8 torch.compile E2E variant covering TP8/EP8, attention DP, TRTLLM MoE, padding CUDA graphs, multi-stream piecewise capture, and phase-2 recapture.
Summary by CodeRabbit
New Features
Bug Fixes
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.