From c333655eb793b41325ec8eba41c92547f683fe11 Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Fri, 3 Jul 2026 05:22:08 -0700 Subject: [PATCH] [None][fix] Enable MiniMax M3 piecewise CUDA graphs Wrap the MiniMax M3 metadata- and cache-dependent attention core in an inplace custom op so torch.compile can split it out of piecewise CUDA graphs. Keep QKV/index projections, QK normalization, RoPE, and the output projection visible to the compiled graph. Write dense and sparse attention results into the custom-op output buffer. Preserve FP32 sparse GQA accumulation until the final copy/cast, and expose the output buffer through MiniMaxM3SparseRuntimeBackend.forward. Register attention boundaries and mutation metadata through optional TRT-LLM op lookup, matching the latest GDN registration pattern from PR #15594. This avoids depending on model-specific custom ops being imported when compilation utilities initialize. Track piecewise runners owned by the compile backend and reset their CUDA graphs, captured addresses, outputs, and warmup state when phase-1 KV-cache estimation is released. Phase 2 then recaptures against the final allocations instead of replaying stale graph pointers. Add an 8-GPU MiniMax-M3-MXFP8 torch.compile E2E variant covering TP8/EP8, attention DP, TRTLLM MoE, padding CUDA graphs, multi-stream piecewise capture, and phase-2 recapture. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- .../sparse/minimax_m3/backend.py | 120 +++++++---- tensorrt_llm/_torch/compilation/backend.py | 11 +- .../_torch/compilation/piecewise_optimizer.py | 57 ++++-- tensorrt_llm/_torch/compilation/utils.py | 41 +++- .../_torch/models/modeling_minimaxm3.py | 191 +++++++++++++----- .../_torch/pyexecutor/model_engine.py | 2 + .../defs/accuracy/test_llm_api_pytorch.py | 33 +++ .../test_lists/qa/llm_function_core.txt | 1 + 8 files changed, 338 insertions(+), 118 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/minimax_m3/backend.py b/tensorrt_llm/_torch/attention_backend/sparse/minimax_m3/backend.py index d6619ce59639..090c6d7956fd 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/minimax_m3/backend.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/minimax_m3/backend.py @@ -593,6 +593,7 @@ def _sparse_gqa_masked( max_k: int, sm_scale: float, causal: bool, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Vectorized sparse GQA: mask non-selected blocks to -inf in QK. @@ -601,7 +602,9 @@ def _sparse_gqa_masked( the algorithm is CUDA-graph safe. ``num_q_heads / num_kv_heads`` Q heads share the same block_mask within a GQA group. - Returns ``[total_q, num_q_heads, head_dim]``. + Returns ``[total_q, num_q_heads * head_dim]``. The attention result + accumulates in FP32; when ``output`` is supplied, the final dtype + conversion writes directly into that tensor. The per-Q FP32 expansion of ``k_padded``/``v_padded`` is computed in slabs of ``chunk_q`` rows so the peak working set scales with @@ -637,44 +640,51 @@ def _sparse_gqa_masked( attended = block_mask_per_pos & valid_pos_per_kv # [total_q, num_kv_heads, max_k] row_has_any = attended.any(dim=-1, keepdim=True) # [total_q, num_kv_heads, 1] - # Output buffer in FP32 to preserve precision until the final cast. + # Preserve FP32 accumulation until the final cast. When a preallocated + # output is supplied this avoids allocating a q.dtype intermediate and + # then copying that intermediate into the custom op's output tensor. o = torch.empty((total_q, num_kv_heads, g, head_dim), dtype=torch.float32, device=q.device) - if total_q == 0 or max_k == 0: - return o.view(total_q, num_q_heads, head_dim).to(q.dtype) - - chunk_q = _compute_sparse_gqa_chunk_q(total_q, max_k, num_kv_heads, head_dim, g) - - for start in range(0, total_q, chunk_q): - end = min(start + chunk_q, total_q) - q_batch_row_chunk = q_batch_row[start:end] - # Per-Q K/V matrices for the chunk only — the dominant FP32 slab. - k_per_q_chunk = k_padded.to(torch.float32).index_select( - 0, q_batch_row_chunk - ) # [chunk, max_k, num_kv_heads, head_dim] - v_per_q_chunk = v_padded.to(torch.float32).index_select(0, q_batch_row_chunk) - q_grp_chunk = q_grp[start:end] - attended_chunk = attended[start:end] - row_has_any_chunk = row_has_any[start:end] - - # qk: [chunk, num_kv_heads, g, max_k] - qk_chunk = torch.einsum("ihgd,iqhd->ihgq", q_grp_chunk, k_per_q_chunk) * sm_scale - qk_chunk = qk_chunk.masked_fill(~attended_chunk.unsqueeze(2), float("-inf")) - - # Apply the mask + softmax via the OpenAI Triton kernel - # ``_sparse_softmax_kernel`` which honors the per-position - # ``attended`` mask, computes softmax in fp32, and folds in the - # all-False-row fix so captured graphs never produce NaN. - attn_chunk = triton_sparse_softmax(qk_chunk, attended_chunk) - - # o_chunk: [chunk, num_kv_heads, g, head_dim] - o_chunk = torch.einsum("ihgq,iqhd->ihgd", attn_chunk, v_per_q_chunk) - # Zero out rows that had no valid positions. - keep_chunk = row_has_any_chunk.squeeze(-1) # [chunk, num_kv_heads] - o_chunk = o_chunk * keep_chunk.unsqueeze(-1).unsqueeze(-1) - o[start:end] = o_chunk + if total_q > 0 and max_k > 0: + chunk_q = _compute_sparse_gqa_chunk_q(total_q, max_k, num_kv_heads, head_dim, g) - return o.view(total_q, num_q_heads, head_dim).to(q.dtype) + for start in range(0, total_q, chunk_q): + end = min(start + chunk_q, total_q) + q_batch_row_chunk = q_batch_row[start:end] + # Per-Q K/V matrices for the chunk only — the dominant FP32 slab. + k_per_q_chunk = k_padded.to(torch.float32).index_select( + 0, q_batch_row_chunk + ) # [chunk, max_k, num_kv_heads, head_dim] + v_per_q_chunk = v_padded.to(torch.float32).index_select(0, q_batch_row_chunk) + q_grp_chunk = q_grp[start:end] + attended_chunk = attended[start:end] + row_has_any_chunk = row_has_any[start:end] + + # qk: [chunk, num_kv_heads, g, max_k] + qk_chunk = torch.einsum("ihgd,iqhd->ihgq", q_grp_chunk, k_per_q_chunk) * sm_scale + qk_chunk = qk_chunk.masked_fill(~attended_chunk.unsqueeze(2), float("-inf")) + + # Apply the mask + softmax via the OpenAI Triton kernel + # ``_sparse_softmax_kernel`` which honors the per-position + # ``attended`` mask, computes softmax in fp32, and folds in the + # all-False-row fix so captured graphs never produce NaN. + attn_chunk = triton_sparse_softmax(qk_chunk, attended_chunk) + + # o_chunk: [chunk, num_kv_heads, g, head_dim] + o_chunk = torch.einsum("ihgq,iqhd->ihgd", attn_chunk, v_per_q_chunk) + # Zero out rows that had no valid positions. + keep_chunk = row_has_any_chunk.squeeze(-1) # [chunk, num_kv_heads] + o_chunk = o_chunk * keep_chunk.unsqueeze(-1).unsqueeze(-1) + o[start:end] = o_chunk + + o_flat = o.view(total_q, num_q_heads * head_dim) + if output is None: + return o_flat.to(q.dtype) + expected_shape = (total_q, num_q_heads * head_dim) + if tuple(output.shape) != expected_shape: + raise ValueError(f"output must have shape {expected_shape}, got {tuple(output.shape)}") + output.copy_(o_flat) + return output # --------------------------------------------------------------------------- @@ -695,6 +705,7 @@ def minimax_m3_sparse_decode( disable_index_value: bool, sm_scale: Optional[float] = None, idx_sm_scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """MiniMax-M3 sparse decode (CUDA-graph safe). @@ -759,7 +770,7 @@ def minimax_m3_sparse_decode( idx_sm_scale=idx_sm_scale, causal=False, ) - o_per_head = _sparse_gqa_masked( + o = _sparse_gqa_masked( q, k_padded, v_padded, @@ -771,8 +782,8 @@ def minimax_m3_sparse_decode( max_k=max_k, sm_scale=sm_scale, causal=False, + output=output, ) - o = o_per_head.reshape(batch, config.num_q_heads * config.head_dim).contiguous() return idx_o, o @@ -789,6 +800,7 @@ def minimax_m3_sparse_prefill( disable_index_value: bool, sm_scale: Optional[float] = None, idx_sm_scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """MiniMax-M3 sparse prefill / chunked-extend (CUDA-graph safe). @@ -856,7 +868,7 @@ def minimax_m3_sparse_prefill( idx_sm_scale=idx_sm_scale, causal=True, ) - o_per_head = _sparse_gqa_masked( + o = _sparse_gqa_masked( q, k_padded, v_padded, @@ -868,9 +880,8 @@ def minimax_m3_sparse_prefill( max_k=max_k, sm_scale=sm_scale, causal=True, + output=output, ) - total_q = int(q.shape[0]) - o = o_per_head.reshape(total_q, config.num_q_heads * config.head_dim).contiguous() return idx_o, o @@ -997,7 +1008,12 @@ def get_minimax_m3_attention_backend_cls(): Deferring the :class:`AttentionBackend` import keeps the algorithm module usable from test paths that do not need the runtime backend. """ - from ...interface import AttentionBackend, AttentionMetadata + from ...interface import ( + AttentionBackend, + AttentionForwardArgs, + AttentionMetadata, + merge_attention_forward_args, + ) metadata_cls = get_minimax_m3_attention_metadata_cls() @@ -1082,6 +1098,7 @@ def forward_sparse( m3_metadata: "MiniMaxM3SparseAttentionMetadata", sm_scale: Optional[float] = None, idx_sm_scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Execute the MiniMax-M3 sparse path end-to-end. @@ -1103,6 +1120,8 @@ def forward_sparse( token's K/V/idx_K to. ``m3_metadata`` : populated :class:`MiniMaxM3SparseAttentionMetadata`. + ``output`` : optional preallocated final output, + ``[num_tokens, num_q_heads * head_dim]``. Returns ``[num_tokens, num_q_heads * head_dim]``. """ @@ -1187,6 +1206,7 @@ def forward_sparse( disable_index_value=self.disable_index_value, sm_scale=sm_scale, idx_sm_scale=idx_sm_scale, + output=output, ) else: _, o = minimax_m3_sparse_decode( @@ -1201,6 +1221,7 @@ def forward_sparse( disable_index_value=self.disable_index_value, sm_scale=sm_scale, idx_sm_scale=idx_sm_scale, + output=output, ) return o @@ -1210,8 +1231,9 @@ def forward( k: Optional[torch.Tensor], v: Optional[torch.Tensor], metadata=None, - forward_args=None, + forward_args: Optional[AttentionForwardArgs] = None, *, + output: Optional[torch.Tensor] = None, idx_q: Optional[torch.Tensor] = None, idx_k: Optional[torch.Tensor] = None, idx_v: Optional[torch.Tensor] = None, @@ -1223,7 +1245,7 @@ def forward( m3_metadata: Optional["MiniMaxM3SparseAttentionMetadata"] = None, sm_scale: Optional[float] = None, idx_sm_scale: Optional[float] = None, - **_unused, + **kwargs, ) -> torch.Tensor: """Standard ``AttentionBackend.forward`` entry point. @@ -1237,6 +1259,15 @@ def forward( generic AttentionBackend dispatch site cannot drive this backend without supplying the index branch. """ + forward_args = merge_attention_forward_args(forward_args, kwargs) + if ( + output is not None + and forward_args.output is not None + and output is not forward_args.output + ): + raise ValueError("output was supplied both directly and through forward_args") + if output is None: + output = forward_args.output if idx_q is None or idx_k is None or m3_metadata is None: raise NotImplementedError( f"MiniMaxM3SparseRuntimeBackend.forward (layer " @@ -1277,6 +1308,7 @@ def forward( m3_metadata=m3_metadata, sm_scale=sm_scale, idx_sm_scale=idx_sm_scale, + output=output, ) return MiniMaxM3SparseRuntimeBackend diff --git a/tensorrt_llm/_torch/compilation/backend.py b/tensorrt_llm/_torch/compilation/backend.py index 064e982c06f7..5c27ce1d7661 100644 --- a/tensorrt_llm/_torch/compilation/backend.py +++ b/tensorrt_llm/_torch/compilation/backend.py @@ -1,6 +1,7 @@ import os from collections import OrderedDict from typing import List, Optional +from weakref import WeakSet import torch import torch._inductor.config as inductor_config @@ -19,7 +20,7 @@ from .patterns.ar_residual_norm import register_ar_fusions from .patterns.residual_add_norm import (register_add_norm, register_add_norm_quant) -from .piecewise_optimizer import piecewise_optimizer +from .piecewise_optimizer import PiecewiseRunner, piecewise_optimizer from .recover_pass import recover_pass from .remove_copy_pass import remove_copy_for_mutates_args @@ -56,6 +57,7 @@ def __init__( self.enable_inductor = enable_inductor self.capture_num_tokens = sorted(capture_num_tokens or []) self.piecewise_cuda_graph = enable_piecewise_cuda_graph + self._piecewise_runners: WeakSet[PiecewiseRunner] = WeakSet() self.no_optimization = False self.num_streams = max_num_streams self.events = Backend.Events() @@ -107,6 +109,10 @@ def generate_events(self, num_events: int): torch.cuda.Event() for _ in range(num_events - len(self.events)) ] + def clear_piecewise_cuda_graphs(self): + for runner in list(self._piecewise_runners): + runner.clear_cuda_graphs() + def optimize( self, gm: GraphModule, @@ -141,7 +147,7 @@ def optimize( gm.recompile() if self.piecewise_cuda_graph: - gm, num_events = piecewise_optimizer( + gm, num_events, runners = piecewise_optimizer( gm, example_inputs, self.enable_inductor, @@ -150,6 +156,7 @@ def optimize( self._graph_pool_handle, self.num_streams, ) + self._piecewise_runners.update(runners) self.generate_events(num_events) return gm elif self.enable_inductor: diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py index ceaa657c8fb3..5472c674c448 100644 --- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py +++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py @@ -16,7 +16,22 @@ get_piecewise_cuda_graph_flag, make_weak_ref, set_piecewise_running) from .multi_stream.auto_multi_stream import multi_stream_schedule -from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function +from .utils import (get_capture_piecewise_cuda_graph_flag, + get_optional_trtllm_op, is_call_function) + + +def _piecewise_boundary_ops(): + op_names = [ + "attn_custom_op_inplace", + "mla_custom_op_inplace", + "mla_dsa_attn_inplace", + "gdn_custom_op_inplace", + "minimax_m3_attn_custom_op_inplace", + ] + return [ + op for op in (get_optional_trtllm_op(op_name) for op_name in op_names) + if op is not None + ] class PiecewiseInterpreter(Interpreter): @@ -47,6 +62,7 @@ def __init__( self.enable_inductor = enable_inductor self.num_events = 0 self.max_num_streams = max_num_streams + self.runners: List["PiecewiseRunner"] = [] def run(self, *args): fake_args = [ @@ -89,7 +105,7 @@ def call_module(self, target, args, kwargs): self.num_events = max(self.num_events, num_events) submod.recompile() - self.module.__dict__[target] = PiecewiseRunner( + runner = PiecewiseRunner( submod, target, self.compile_time_num_tokens, @@ -102,6 +118,8 @@ def call_module(self, target, args, kwargs): self.piecewise_runner_idx == 0, self.piecewise_runner_idx == self.piecewise_runner_num - 1, ) + self.module.__dict__[target] = runner + self.runners.append(runner) self.piecewise_runner_idx += 1 return output @@ -161,6 +179,17 @@ def __init__( callable=default_callable, ) + def clear_cuda_graphs(self): + """Release captures while retaining buckets for a later warmup.""" + for entry in self.entries.values(): + if entry.cuda_graph is not None: + entry.cuda_graph.reset() + entry.cuda_graph = None + entry.warmup_count = 0 + entry.input_addresses = None + entry.output_addresses = None + entry.output = None + def __call__(self, *args): runtime_num_of_token = None if self.runtime_num_tokens_idx != None: @@ -248,7 +277,7 @@ def piecewise_optimizer( capture_num_tokens: Sequence[int], graph_pool_handle: tuple[int, int], max_num_streams: int = 1, -) -> tuple[GraphModule, int]: +) -> tuple[GraphModule, int, List[PiecewiseRunner]]: graph_pool_handle = torch.cuda.graph_pool_handle() graph = gm.graph @@ -256,25 +285,21 @@ def piecewise_optimizer( node_to_graph_id = {} idx = 0 exclude_modules_id = [] + piecewise_boundary_ops = _piecewise_boundary_ops() for node in graph.nodes: if node.op in ("output", "placeholder"): continue - if (not stop_partition and is_call_function(node, [ - torch.ops.trtllm.attn_custom_op_inplace.default, - torch.ops.trtllm.mla_custom_op_inplace.default, - torch.ops.trtllm.mla_dsa_attn_inplace.default, - torch.ops.aten.index.Tensor, - torch.ops.aten.cumsum.default, - ])): + is_boundary = is_call_function(node, piecewise_boundary_ops) + stop_target = is_call_function(node, [ + torch.ops.aten.index.Tensor, + torch.ops.aten.cumsum.default, + ]) + if not stop_partition and (is_boundary or stop_target): idx += 1 node_to_graph_id[node] = idx exclude_modules_id.append(idx) - if (node.target != torch.ops.trtllm.attn_custom_op_inplace.default - and node.target - != torch.ops.trtllm.mla_custom_op_inplace.default - and node.target - != torch.ops.trtllm.mla_dsa_attn_inplace.default): + if not is_boundary: # We only know it is safe to continue splitting after attention stop_partition = True else: @@ -300,4 +325,4 @@ def piecewise_optimizer( interpreter.run(*example_inputs) - return gm, interpreter.num_events + return gm, interpreter.num_events, interpreter.runners diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py index 9ebea9cfe20f..16f5c0fd026d 100644 --- a/tensorrt_llm/_torch/compilation/utils.py +++ b/tensorrt_llm/_torch/compilation/utils.py @@ -1,5 +1,5 @@ import contextlib -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import torch from torch.fx import Node @@ -33,6 +33,13 @@ def is_call_function(node: Node, target: Union[List[Callable], Callable]): return node.op == "call_function" and node.target == target +def get_optional_trtllm_op(op_name: str) -> Optional[Callable]: + try: + return getattr(torch.ops.trtllm, op_name).default + except AttributeError: + return None + + _enable_piecewise_cuda_graph_capture = False @@ -66,16 +73,6 @@ def inplace_info(): 1: "out", 2: "residual" }, - torch.ops.trtllm.attn_custom_op_inplace.default: { - 1: "output", - 2: "output_sf" - }, - torch.ops.trtllm.mla_custom_op_inplace.default: { - 1: "output" - }, - torch.ops.trtllm.mla_dsa_attn_inplace.default: { - 1: "output" - }, torch.ops.trtllm.fused_qk_norm_rope.default: { 1: "qkv" }, @@ -177,4 +174,26 @@ def inplace_info(): 1: "x", 2: "residual" } + optional_inplace_infos = { + "attn_custom_op_inplace": { + 1: "output", + 2: "output_sf" + }, + "mla_custom_op_inplace": { + 1: "output" + }, + "mla_dsa_attn_inplace": { + 1: "output" + }, + "gdn_custom_op_inplace": { + 1: "output" + }, + "minimax_m3_attn_custom_op_inplace": { + 1: "output" + }, + } + for op_name, mutates_args in optional_inplace_infos.items(): + op = get_optional_trtllm_op(op_name) + if op is not None: + inplace_map[op] = mutates_args return inplace_map diff --git a/tensorrt_llm/_torch/models/modeling_minimaxm3.py b/tensorrt_llm/_torch/models/modeling_minimaxm3.py index 0b8b9f3e00ad..e5db0fd19856 100644 --- a/tensorrt_llm/_torch/models/modeling_minimaxm3.py +++ b/tensorrt_llm/_torch/models/modeling_minimaxm3.py @@ -43,7 +43,13 @@ from ..modules.linear import Linear, TensorParallelMode, copy_weight, load_weight_shard from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm -from ..utils import ActivationType, AuxStreamType, EventType +from ..utils import ( + ActivationType, + AuxStreamType, + EventType, + get_model_extra_attrs, + is_torch_compiling, +) from .modeling_utils import DecoderModel, DecoderModelForCausalLM, ModelConfig, register_auto_model # --------------------------------------------------------------------------- @@ -552,6 +558,49 @@ def forward(self, hidden_states: torch.Tensor): return self.minimax_all_reduce_rms(hidden_states, self.weight, self.eps) +def _extract_minimax_m3_attention_extra_attrs(layer_idx: str): + """Resolve runtime metadata and the registered MiniMax-M3 layer.""" + extra_attrs = get_model_extra_attrs() + assert extra_attrs is not None, "Model extra attrs is not set" + + metadata_ref = extra_attrs.get("attention_metadata") + assert metadata_ref is not None, "Attention metadata is not set" + metadata = metadata_ref() + assert isinstance(metadata, AttentionMetadata), "Invalid MiniMax-M3 attention metadata" + + attn_layers = extra_attrs.get("attn_layers") + assert attn_layers is not None, "Attention layer is not registered" + attn_layer_ref = attn_layers.get(layer_idx) + assert attn_layer_ref is not None, f"Cannot find attention layer for layer {layer_idx}" + attn_layer = attn_layer_ref() + assert isinstance(attn_layer, MiniMaxM3Attention), "Invalid MiniMax-M3 attention layer" + return metadata, attn_layer + + +@torch.library.custom_op("trtllm::minimax_m3_attn_custom_op_inplace", mutates_args=("output",)) +def minimax_m3_attn_custom_op_inplace( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + idx_q: Optional[torch.Tensor], + idx_k: Optional[torch.Tensor], + layer_idx: str, + output: torch.Tensor, +) -> None: + """Run MiniMax-M3 cache and attention work behind a compile boundary.""" + attn_metadata, attn_layer = _extract_minimax_m3_attention_extra_attrs(layer_idx) + num_tokens = attn_metadata.num_tokens + attn_layer._attention_core( + q[:num_tokens], + k[:num_tokens], + v[:num_tokens], + idx_q[:num_tokens] if idx_q is not None else None, + idx_k[:num_tokens] if idx_k is not None else None, + attn_metadata, + output[:num_tokens], + ) + + class MiniMaxM3Attention(Attention): """M3 attention: dense (layers 0-2) or sparse (layers 3-59). @@ -803,22 +852,11 @@ def _dense_forward( (causal for prefill, no-causal for decode). 8. Apply ``o_proj``. """ - from ..attention_backend.sparse.minimax_m3 import ( - _gather_paged_batched, - _write_main_kv_slots_to_pool, - ) - if attn_metadata is None: raise RuntimeError( f"MiniMax-M3 dense forward (layer {self.layer_idx}) requires " "attn_metadata; received None." ) - kv_cache_manager = getattr(attn_metadata, "kv_cache_manager", None) - if kv_cache_manager is None: - raise RuntimeError( - f"MiniMax-M3 dense forward (layer {self.layer_idx}) requires " - "attn_metadata.kv_cache_manager to be a MiniMaxM3KVCacheManagerV2." - ) # 1. Projections (no index branch). qkv = self.qkv_proj(hidden_states) @@ -831,7 +869,34 @@ def _dense_forward( if self.rotary_emb is not None and position_ids is not None: q, k = self.rotary_emb(position_ids, [q, k]) - num_tokens = int(hidden_states.shape[0]) + # Keep token-wise projections and the output projection visible to + # torch.compile. Only the metadata/cache-dependent attention core is + # hidden behind the inplace custom op. + o = self._forward_attention_core(q, k, v, None, None, attn_metadata) + return self.o_proj(o) + + def _dense_attention_core( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_metadata: AttentionMetadata, + output: torch.Tensor, + ) -> torch.Tensor: + """Run dense cache updates and attention into ``output``.""" + from ..attention_backend.sparse.minimax_m3 import ( + _gather_paged_batched, + _write_main_kv_slots_to_pool, + ) + + kv_cache_manager = getattr(attn_metadata, "kv_cache_manager", None) + if kv_cache_manager is None: + raise RuntimeError( + f"MiniMax-M3 dense forward (layer {self.layer_idx}) requires " + "attn_metadata.kv_cache_manager to be a MiniMaxM3KVCacheManagerV2." + ) + + num_tokens = int(q.shape[0]) q_view = q.view(num_tokens, self.num_heads, self.head_dim) k_view = k.view(num_tokens, self.num_key_value_heads, self.head_dim) v_view = v.view(num_tokens, self.num_key_value_heads, self.head_dim) @@ -937,7 +1002,7 @@ def _dense_forward( # Build per-batch attention by routing each Q to its K via q_batch_row. # Easiest path: SDPA expects [batch, num_heads, q_len, head_dim]. # We process each batch row independently to keep the math straightforward. - attn_outputs = [] + output_view = output.view(-1, self.num_heads, self.head_dim) cu = m3_meta.cu_seqlens_q.to(torch.long).tolist() for b in range(batch): start, end = cu[b], cu[b + 1] @@ -956,18 +1021,7 @@ def _dense_forward( dropout_p=0.0, is_causal=False, ) # [1, H, q, d] - attn_outputs.append(out_b.squeeze(0).transpose(0, 1)) # [q, H, d] - o = ( - torch.cat(attn_outputs, dim=0) - if attn_outputs - else torch.empty( - 0, - self.num_heads, - self.head_dim, - device=q.device, - dtype=q.dtype, - ) - ) + output_view[start:end].copy_(out_b.squeeze(0).transpose(0, 1)) else: # Decode: one Q token per request at position seq_lens - 1. # Every input tensor here is already on q.device (set up by @@ -985,9 +1039,8 @@ def _dense_forward( dropout_p=0.0, is_causal=False, ) # [batch, H, 1, d] - # Drop the singleton Q-length axis. The result is already the - # ``[batch, num_heads, head_dim]`` layout the prefill branch - # produces, so the shared flatten on line below is sufficient. + # Drop the singleton Q-length axis and write the resulting + # ``[batch, num_heads, head_dim]`` tensor into the final buffer. # The prior ``.transpose(1, 2).reshape(batch, H, d)`` pattern # was wrong: with ``H != head_dim`` (M3 TP=8 has H=8, d=128) # the non-contiguous transpose forces ``reshape`` to copy the @@ -997,13 +1050,49 @@ def _dense_forward( # into ``o_proj``. Prefill is unaffected because its # ``transpose(0, 1)`` runs between q-len and num_heads axes # which the per-batch loop already laid out correctly. - o = out_b.squeeze(2) # [batch, num_heads, head_dim] + output.view(batch, self.num_heads, self.head_dim).copy_(out_b.squeeze(2)) - # Flatten to [num_tokens, num_heads * head_dim] for o_proj. - o = o.reshape(-1, self.num_heads * self.head_dim) + return output - # 8. Output projection. - return self.o_proj(o) + def _forward_attention_core( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + idx_q: Optional[torch.Tensor], + idx_k: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + output = q.new_empty((q.shape[0], self.num_heads * self.head_dim)) + if self.register_to_config and is_torch_compiling(): + minimax_m3_attn_custom_op_inplace( + q, + k, + v, + idx_q, + idx_k, + self.layer_idx_str, + output, + ) + else: + self._attention_core(q, k, v, idx_q, idx_k, attn_metadata, output) + return output + + def _attention_core( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + idx_q: Optional[torch.Tensor], + idx_k: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + output: torch.Tensor, + ) -> torch.Tensor: + if self.is_sparse_attention_layer: + assert idx_q is not None and idx_k is not None + return self._sparse_attention_core(q, k, v, idx_q, idx_k, attn_metadata, output) + assert idx_q is None and idx_k is None + return self._dense_attention_core(q, k, v, attn_metadata, output) def _sparse_forward( self, @@ -1050,13 +1139,6 @@ def _sparse_forward( f"MiniMax-M3 sparse forward (layer {self.layer_idx}) requires " "attn_metadata; received None." ) - kv_cache_manager = getattr(attn_metadata, "kv_cache_manager", None) - if kv_cache_manager is None: - raise RuntimeError( - f"MiniMax-M3 sparse forward (layer {self.layer_idx}) requires " - "attn_metadata.kv_cache_manager to be a MiniMaxM3KVCacheManagerV2." - ) - # 1. Projections. qkv = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -1074,6 +1156,27 @@ def _sparse_forward( q, k = self.rotary_emb(position_ids, [q, k]) idx_q, idx_k = self.rotary_emb(position_ids, [idx_q, idx_k]) + o = self._forward_attention_core(q, k, v, idx_q, idx_k, attn_metadata) + return self.o_proj(o) + + def _sparse_attention_core( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + idx_q: torch.Tensor, + idx_k: torch.Tensor, + attn_metadata: AttentionMetadata, + output: torch.Tensor, + ) -> torch.Tensor: + """Run sparse cache updates and attention into ``output``.""" + kv_cache_manager = getattr(attn_metadata, "kv_cache_manager", None) + if kv_cache_manager is None: + raise RuntimeError( + f"MiniMax-M3 sparse forward (layer {self.layer_idx}) requires " + "attn_metadata.kv_cache_manager to be a MiniMaxM3KVCacheManagerV2." + ) + # 4. Get the paged-block main K/V cache + flat side index-K cache. # The base KVCacheManagerV2 layout is # ``[num_pages, kv_factor, tokens_per_block, num_kv_heads, head_dim]`` @@ -1148,11 +1251,12 @@ def _sparse_forward( "ModelConfig so the standard attention-backend dispatch selects " "the M3 sparse runtime." ) - o = self.attn.forward( + return self.attn.forward( q, k, v, None, + output=output, idx_q=idx_q, idx_k=idx_k, idx_v=None, # disable_index_value=True for M3 checkpoint @@ -1164,9 +1268,6 @@ def _sparse_forward( m3_metadata=m3_meta, ) - # 8. Output projection. - return self.o_proj(o) - class MiniMaxM3DecoderLayer(DecoderLayer): """One M3 transformer block (dense or sparse, MLP or MoE).""" diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 9781bea589e8..ac0a329ea6e0 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2124,6 +2124,8 @@ def _init_model_capacity(self): self._init_max_num_tokens() def _release_cuda_graphs(self): + if self._torch_compile_backend is not None: + self._torch_compile_backend.clear_piecewise_cuda_graphs() if hasattr(self, 'cuda_graph_runner') and self.cuda_graph_runner is not None: self.cuda_graph_runner.clear() diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 641ed8c11ca4..e0bde811aff7 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -7286,6 +7286,39 @@ def test_mxfp8(self, tp_size, ep_size): task = GSM8K(model_name) task.evaluate(llm) + @pytest.mark.skip_less_device(8) + @pytest.mark.skip_less_device_memory(140000) + @parametrize_with_ids("tp_size,ep_size", [(8, 8)]) + def test_mxfp8_piecewise_cuda_graph(self, tp_size, ep_size): + model_name = "MiniMaxAI/MiniMax-M3-MXFP8" + model_path = f"{llm_models_root()}/MiniMax-M3-MXFP8" + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5, + enable_block_reuse=False) + sparse_attention_config = MiniMaxM3SparseAttentionConfig() + cuda_graph_config = CudaGraphConfig( + enable_padding=True, batch_sizes=[1, 2, 4, 8, 12, 16, 24, 32]) + torch_compile_config = TorchCompileConfig( + enable_piecewise_cuda_graph=True, + capture_num_tokens=[1, 8192], + max_num_streams=3) + + with LLM(model_path, + tensor_parallel_size=tp_size, + moe_expert_parallel_size=ep_size, + enable_attention_dp=True, + moe_config=MoeConfig(backend="TRTLLM"), + kv_cache_config=kv_cache_config, + sparse_attention_config=sparse_attention_config, + cuda_graph_config=cuda_graph_config, + torch_compile_config=torch_compile_config, + max_seq_len=2048, + max_num_tokens=8192, + max_batch_size=32, + trust_remote_code=True) as llm: + assert llm.args.quant_config.quant_algo == QuantAlgo.MXFP8 + task = GSM8K(model_name) + task.evaluate(llm) + @skip_pre_blackwell class TestGLM5FP8(LlmapiAccuracyTestHarness): diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index d849cc496d61..e6cb5ea32ccf 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -668,6 +668,7 @@ accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-c accuracy/test_llm_api_pytorch.py::TestMiniMaxM2_5::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] accuracy/test_llm_api_pytorch.py::TestMiniMaxM3::test_auto_dtype[tp_size=8-ep_size=8] TIMEOUT (180) accuracy/test_llm_api_pytorch.py::TestMiniMaxM3::test_mxfp8[tp_size=8-ep_size=8] TIMEOUT (180) +accuracy/test_llm_api_pytorch.py::TestMiniMaxM3::test_mxfp8_piecewise_cuda_graph[tp_size=8-ep_size=8] TIMEOUT (180) accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_auto_dtype accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8 accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm]