Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 76 additions & 44 deletions tensorrt_llm/_torch/attention_backend/sparse/minimax_m3/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ def _sparse_gqa_masked(
max_k: int,
sm_scale: float,
causal: bool,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Vectorized sparse GQA: mask non-selected blocks to -inf in QK.

Expand All @@ -601,7 +602,9 @@ def _sparse_gqa_masked(
the algorithm is CUDA-graph safe. ``num_q_heads / num_kv_heads`` Q
heads share the same block_mask within a GQA group.

Returns ``[total_q, num_q_heads, head_dim]``.
Returns ``[total_q, num_q_heads * head_dim]``. The attention result
accumulates in FP32; when ``output`` is supplied, the final dtype
conversion writes directly into that tensor.

The per-Q FP32 expansion of ``k_padded``/``v_padded`` is computed
in slabs of ``chunk_q`` rows so the peak working set scales with
Expand Down Expand Up @@ -637,44 +640,51 @@ def _sparse_gqa_masked(
attended = block_mask_per_pos & valid_pos_per_kv # [total_q, num_kv_heads, max_k]
row_has_any = attended.any(dim=-1, keepdim=True) # [total_q, num_kv_heads, 1]

# Output buffer in FP32 to preserve precision until the final cast.
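# Preserve FP32 accumulation until the final cast. When a preallocated
# output is supplied, this avoids allocating a q.dtype intermediate and
# then copying that intermediate into the custom op's output tensor.
# NOTE(review): ``o`` comes from ``torch.empty``; when ``total_q > 0`` but
# ``max_k == 0`` the chunk loop below is skipped, so the values cast or
# copied out are uninitialized. Confirm whether that case can occur; if
# so, zero-fill the buffer.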
o = torch.empty((total_q, num_kv_heads, g, head_dim), dtype=torch.float32, device=q.device)

if total_q == 0 or max_k == 0:
return o.view(total_q, num_q_heads, head_dim).to(q.dtype)

chunk_q = _compute_sparse_gqa_chunk_q(total_q, max_k, num_kv_heads, head_dim, g)

for start in range(0, total_q, chunk_q):
end = min(start + chunk_q, total_q)
q_batch_row_chunk = q_batch_row[start:end]
# Per-Q K/V matrices for the chunk only — the dominant FP32 slab.
k_per_q_chunk = k_padded.to(torch.float32).index_select(
0, q_batch_row_chunk
) # [chunk, max_k, num_kv_heads, head_dim]
v_per_q_chunk = v_padded.to(torch.float32).index_select(0, q_batch_row_chunk)
q_grp_chunk = q_grp[start:end]
attended_chunk = attended[start:end]
row_has_any_chunk = row_has_any[start:end]

# qk: [chunk, num_kv_heads, g, max_k]
qk_chunk = torch.einsum("ihgd,iqhd->ihgq", q_grp_chunk, k_per_q_chunk) * sm_scale
qk_chunk = qk_chunk.masked_fill(~attended_chunk.unsqueeze(2), float("-inf"))

# Apply the mask + softmax via the OpenAI Triton kernel
# ``_sparse_softmax_kernel`` which honors the per-position
# ``attended`` mask, computes softmax in fp32, and folds in the
# all-False-row fix so captured graphs never produce NaN.
attn_chunk = triton_sparse_softmax(qk_chunk, attended_chunk)

# o_chunk: [chunk, num_kv_heads, g, head_dim]
o_chunk = torch.einsum("ihgq,iqhd->ihgd", attn_chunk, v_per_q_chunk)
# Zero out rows that had no valid positions.
keep_chunk = row_has_any_chunk.squeeze(-1) # [chunk, num_kv_heads]
o_chunk = o_chunk * keep_chunk.unsqueeze(-1).unsqueeze(-1)
o[start:end] = o_chunk
if total_q > 0 and max_k > 0:
chunk_q = _compute_sparse_gqa_chunk_q(total_q, max_k, num_kv_heads, head_dim, g)

return o.view(total_q, num_q_heads, head_dim).to(q.dtype)
for start in range(0, total_q, chunk_q):
end = min(start + chunk_q, total_q)
q_batch_row_chunk = q_batch_row[start:end]
# Per-Q K/V matrices for the chunk only — the dominant FP32 slab.
k_per_q_chunk = k_padded.to(torch.float32).index_select(
0, q_batch_row_chunk
) # [chunk, max_k, num_kv_heads, head_dim]
v_per_q_chunk = v_padded.to(torch.float32).index_select(0, q_batch_row_chunk)
q_grp_chunk = q_grp[start:end]
attended_chunk = attended[start:end]
row_has_any_chunk = row_has_any[start:end]

# qk: [chunk, num_kv_heads, g, max_k]
qk_chunk = torch.einsum("ihgd,iqhd->ihgq", q_grp_chunk, k_per_q_chunk) * sm_scale
qk_chunk = qk_chunk.masked_fill(~attended_chunk.unsqueeze(2), float("-inf"))

# Apply the mask + softmax via the OpenAI Triton kernel
# ``_sparse_softmax_kernel`` which honors the per-position
# ``attended`` mask, computes softmax in fp32, and folds in the
# all-False-row fix so captured graphs never produce NaN.
attn_chunk = triton_sparse_softmax(qk_chunk, attended_chunk)

# o_chunk: [chunk, num_kv_heads, g, head_dim]
o_chunk = torch.einsum("ihgq,iqhd->ihgd", attn_chunk, v_per_q_chunk)
# Zero out rows that had no valid positions.
keep_chunk = row_has_any_chunk.squeeze(-1) # [chunk, num_kv_heads]
o_chunk = o_chunk * keep_chunk.unsqueeze(-1).unsqueeze(-1)
o[start:end] = o_chunk

o_flat = o.view(total_q, num_q_heads * head_dim)
if output is None:
return o_flat.to(q.dtype)
expected_shape = (total_q, num_q_heads * head_dim)
if tuple(output.shape) != expected_shape:
raise ValueError(f"output must have shape {expected_shape}, got {tuple(output.shape)}")
output.copy_(o_flat)
return output


# ---------------------------------------------------------------------------
Expand All @@ -695,6 +705,7 @@ def minimax_m3_sparse_decode(
disable_index_value: bool,
sm_scale: Optional[float] = None,
idx_sm_scale: Optional[float] = None,
output: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""MiniMax-M3 sparse decode (CUDA-graph safe).

Expand Down Expand Up @@ -759,7 +770,7 @@ def minimax_m3_sparse_decode(
idx_sm_scale=idx_sm_scale,
causal=False,
)
o_per_head = _sparse_gqa_masked(
o = _sparse_gqa_masked(
q,
k_padded,
v_padded,
Expand All @@ -771,8 +782,8 @@ def minimax_m3_sparse_decode(
max_k=max_k,
sm_scale=sm_scale,
causal=False,
output=output,
)
o = o_per_head.reshape(batch, config.num_q_heads * config.head_dim).contiguous()
return idx_o, o


Expand All @@ -789,6 +800,7 @@ def minimax_m3_sparse_prefill(
disable_index_value: bool,
sm_scale: Optional[float] = None,
idx_sm_scale: Optional[float] = None,
output: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""MiniMax-M3 sparse prefill / chunked-extend (CUDA-graph safe).

Expand Down Expand Up @@ -856,7 +868,7 @@ def minimax_m3_sparse_prefill(
idx_sm_scale=idx_sm_scale,
causal=True,
)
o_per_head = _sparse_gqa_masked(
o = _sparse_gqa_masked(
q,
k_padded,
v_padded,
Expand All @@ -868,9 +880,8 @@ def minimax_m3_sparse_prefill(
max_k=max_k,
sm_scale=sm_scale,
causal=True,
output=output,
)
total_q = int(q.shape[0])
o = o_per_head.reshape(total_q, config.num_q_heads * config.head_dim).contiguous()
return idx_o, o


Expand Down Expand Up @@ -997,7 +1008,12 @@ def get_minimax_m3_attention_backend_cls():
Deferring the :class:`AttentionBackend` import keeps the algorithm
module usable from test paths that do not need the runtime backend.
"""
from ...interface import AttentionBackend, AttentionMetadata
from ...interface import (
AttentionBackend,
AttentionForwardArgs,
AttentionMetadata,
merge_attention_forward_args,
)

metadata_cls = get_minimax_m3_attention_metadata_cls()

Expand Down Expand Up @@ -1082,6 +1098,7 @@ def forward_sparse(
m3_metadata: "MiniMaxM3SparseAttentionMetadata",
sm_scale: Optional[float] = None,
idx_sm_scale: Optional[float] = None,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Execute the MiniMax-M3 sparse path end-to-end.

Expand All @@ -1103,6 +1120,8 @@ def forward_sparse(
token's K/V/idx_K to.
``m3_metadata`` : populated
:class:`MiniMaxM3SparseAttentionMetadata`.
``output`` : optional preallocated final output,
``[num_tokens, num_q_heads * head_dim]``.

Returns ``[num_tokens, num_q_heads * head_dim]``.
"""
Expand Down Expand Up @@ -1187,6 +1206,7 @@ def forward_sparse(
disable_index_value=self.disable_index_value,
sm_scale=sm_scale,
idx_sm_scale=idx_sm_scale,
output=output,
)
else:
_, o = minimax_m3_sparse_decode(
Expand All @@ -1201,6 +1221,7 @@ def forward_sparse(
disable_index_value=self.disable_index_value,
sm_scale=sm_scale,
idx_sm_scale=idx_sm_scale,
output=output,
)
return o

Expand All @@ -1210,8 +1231,9 @@ def forward(
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
metadata=None,
forward_args=None,
forward_args: Optional[AttentionForwardArgs] = None,
*,
output: Optional[torch.Tensor] = None,
idx_q: Optional[torch.Tensor] = None,
idx_k: Optional[torch.Tensor] = None,
idx_v: Optional[torch.Tensor] = None,
Expand All @@ -1223,7 +1245,7 @@ def forward(
m3_metadata: Optional["MiniMaxM3SparseAttentionMetadata"] = None,
sm_scale: Optional[float] = None,
idx_sm_scale: Optional[float] = None,
**_unused,
**kwargs,
) -> torch.Tensor:
"""Standard ``AttentionBackend.forward`` entry point.

Expand All @@ -1237,6 +1259,15 @@ def forward(
generic AttentionBackend dispatch site cannot drive this
backend without supplying the index branch.
"""
forward_args = merge_attention_forward_args(forward_args, kwargs)
if (
output is not None
and forward_args.output is not None
and output is not forward_args.output
):
raise ValueError("output was supplied both directly and through forward_args")
if output is None:
output = forward_args.output
if idx_q is None or idx_k is None or m3_metadata is None:
raise NotImplementedError(
f"MiniMaxM3SparseRuntimeBackend.forward (layer "
Expand Down Expand Up @@ -1277,6 +1308,7 @@ def forward(
m3_metadata=m3_metadata,
sm_scale=sm_scale,
idx_sm_scale=idx_sm_scale,
output=output,
)

return MiniMaxM3SparseRuntimeBackend
Expand Down
11 changes: 9 additions & 2 deletions tensorrt_llm/_torch/compilation/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from collections import OrderedDict
from typing import List, Optional
from weakref import WeakSet

import torch
import torch._inductor.config as inductor_config
Expand All @@ -19,7 +20,7 @@
from .patterns.ar_residual_norm import register_ar_fusions
from .patterns.residual_add_norm import (register_add_norm,
register_add_norm_quant)
from .piecewise_optimizer import piecewise_optimizer
from .piecewise_optimizer import PiecewiseRunner, piecewise_optimizer
from .recover_pass import recover_pass
from .remove_copy_pass import remove_copy_for_mutates_args

Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
self.enable_inductor = enable_inductor
self.capture_num_tokens = sorted(capture_num_tokens or [])
self.piecewise_cuda_graph = enable_piecewise_cuda_graph
self._piecewise_runners: WeakSet[PiecewiseRunner] = WeakSet()
self.no_optimization = False
self.num_streams = max_num_streams
self.events = Backend.Events()
Expand Down Expand Up @@ -107,6 +109,10 @@ def generate_events(self, num_events: int):
torch.cuda.Event() for _ in range(num_events - len(self.events))
]

def clear_piecewise_cuda_graphs(self):
    """Ask every tracked piecewise runner to drop its CUDA graphs.

    The runners are read from ``self._piecewise_runners``. A snapshot of
    that collection is taken before the loop, so each runner's
    ``clear_cuda_graphs()`` is called against a fixed sequence even if the
    underlying collection changes while the calls are in flight.
    """
    # Materialize first: the loop then iterates a private, fixed sequence.
    tracked_runners = tuple(self._piecewise_runners)
    for piecewise_runner in tracked_runners:
        piecewise_runner.clear_cuda_graphs()

def optimize(
self,
gm: GraphModule,
Expand Down Expand Up @@ -141,7 +147,7 @@ def optimize(
gm.recompile()

if self.piecewise_cuda_graph:
gm, num_events = piecewise_optimizer(
gm, num_events, runners = piecewise_optimizer(
gm,
example_inputs,
self.enable_inductor,
Expand All @@ -150,6 +156,7 @@ def optimize(
self._graph_pool_handle,
self.num_streams,
)
self._piecewise_runners.update(runners)
self.generate_events(num_events)
return gm
elif self.enable_inductor:
Expand Down
Loading
Loading