[Attention][WIP] FA4 integration #32974
Conversation
Code Review
This pull request successfully integrates FlashAttention 4 (FA4) by updating CMake configurations, adding necessary dependencies, and extending the attention configuration and version detection logic. The new vllm/vllm_flash_attn/flash_attn_interface.py file correctly centralizes the logic for selecting and using different FlashAttention versions (FA2, FA3, FA4).
However, there's a critical issue with the newly added file vllm/third_party/flashmla/flash_mla_interface.py. This file appears to duplicate the general FlashAttention (FA4) functions (_flash_attn_varlen_forward, flash_attn_varlen_func, etc.) which are already correctly implemented and managed in vllm/vllm_flash_attn/flash_attn_interface.py. This leads to code duplication and potential confusion regarding which implementation is canonical. The vllm/third_party/flashmla/flash_mla_interface.py file should ideally only contain FlashMLA-specific functionalities.
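For illustration, here is a minimal sketch of the kind of version-selection logic such a central interface can hold. The function name, signature, and capability thresholds are assumptions for illustration, not the actual contents of vllm/vllm_flash_attn/flash_attn_interface.py.

# Hypothetical sketch of centralized FlashAttention version selection.
# Names and thresholds are assumptions; the real logic lives in
# vllm/vllm_flash_attn/flash_attn_interface.py in this PR.
from typing import Optional


def select_fa_version(compute_capability: tuple[int, int],
                      requested_version: Optional[int] = None) -> int:
    """Pick a FlashAttention major version for the current GPU."""
    if requested_version is not None:
        # Respect an explicit override from configuration.
        return requested_version
    major, _ = compute_capability
    if major >= 10:  # Blackwell (SM100+)
        return 4
    if major == 9:   # Hopper (SM90)
        return 3
    return 2         # Older architectures fall back to FA2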
def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    qo_total_len, num_qo_heads, head_dim_qk = q.shape
    kv_total_len, num_kv_heads, head_dim_vo = v.shape

    mask_mode_code = 1 if causal else 0
    if softmax_scale is None:
        softmax_scale = head_dim_qk ** (-0.5)

    if out is None:
        out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype)
    if lse is None:
        # Make lse contiguous on seqlen dim
        lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T

    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)
    flash_mla_cuda.dense_prefill_fwd(
        workspace_buffer,
        q,
        k,
        v,
        cu_seqlens_qo,
        cu_seqlens_kv,
        out,
        lse,
        mask_mode_code,
        softmax_scale,
        max_seqlen_qo,
        max_seqlen_kv,
        is_varlen,
    )

    return out, lse


def _flash_attn_varlen_backward(
    do: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    qo_total_len, num_qo_heads, head_dim_qk = q.shape
    kv_total_len, num_kv_heads, head_dim_vo = v.shape

    # TODO: fix bwd GQA
    if num_qo_heads != num_kv_heads:
        raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.")

    mask_mode_code = 1 if causal else 0
    if softmax_scale is None:
        softmax_scale = head_dim_qk ** (-0.5)

    if dq is None:
        dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype)
    if dk is None:
        dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype)
    if dv is None:
        dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype)

    max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8
    bs = cu_seqlens_qo.shape[0] - 1
    workspace_bytes = 0
    workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk  # dQ_acc
    workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2  # sum_OdO and scaled_lse
    if num_qo_heads != num_kv_heads:
        workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo)  # dKV_acc
    workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)
    flash_mla_cuda.dense_prefill_bwd(
        workspace_buffer,
        do,
        q,
        k,
        v,
        out,
        lse,
        cu_seqlens_qo,
        cu_seqlens_kv,
        dq,
        dk,
        dv,
        mask_mode_code,
        softmax_scale,
        max_seqlen_qo,
        max_seqlen_kv,
        is_varlen,
    )

    return dq, dk, dv


class FlashAttnVarlenFunc(torch.autograd.Function):
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_qo: torch.Tensor,
        cu_seqlens_kv: torch.Tensor,
        max_seqlen_qo: int,
        max_seqlen_kv: int,
        causal: bool = False,
        softmax_scale: Optional[float] = None,
        is_varlen: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        out, lse = _flash_attn_varlen_forward(
            q, k, v,
            cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
            causal=causal, softmax_scale=softmax_scale,
            is_varlen=is_varlen,
        )
        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv)
        ctx.max_seqlen_qo = max_seqlen_qo
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.causal = causal
        ctx.softmax_scale = softmax_scale
        ctx.is_varlen = is_varlen
        return out, lse

    def backward(
        ctx,
        do: torch.Tensor,
        dlse: torch.Tensor,
    ):
        del dlse  # LSE doesn't support backward currently
        q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors
        dq, dk, dv = _flash_attn_varlen_backward(
            do, q, k, v, out, lse,
            cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv,
            causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            is_varlen=ctx.is_varlen,
        )
        return dq, dk, dv, None, None, None, None, None, None, None


def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    deterministic: bool = False,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert dropout_p == 0.0
    assert not deterministic
    return FlashAttnVarlenFunc.apply(
        q, k, v,
        cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
        causal, softmax_scale, is_varlen,
    )


def flash_attn_varlen_qkvpacked_func(
    qkv: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    head_dim_qk: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    deterministic: bool = False,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert dropout_p == 0.0
    assert not deterministic
    return FlashAttnVarlenFunc.apply(
        qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:],
        cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
        causal, softmax_scale, is_varlen,
    )


def flash_attn_varlen_kvpacked_func(
    q: torch.Tensor,
    kv: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    head_dim_qk: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    deterministic: bool = False,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert dropout_p == 0.0
    assert not deterministic
    return FlashAttnVarlenFunc.apply(
        q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:],
        cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
        causal, softmax_scale, is_varlen,
    )
This file, flash_mla_interface.py, is intended for FlashMLA-specific functions. However, it contains _flash_attn_varlen_forward, _flash_attn_varlen_backward, FlashAttnVarlenFunc, and related flash_attn_varlen_func wrappers. These functions are general FlashAttention implementations (specifically for FA4 in this context) and are already correctly defined and managed in vllm/vllm_flash_attn/flash_attn_interface.py. This duplication creates an architectural inconsistency and could lead to confusion or maintenance issues. These general FlashAttention functions should be removed from this file.
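If flash_mla_interface.py still needs to expose these entry points, one lighter-weight option, sketched below, is to re-export them from the canonical module instead of re-implementing them. This assumes vllm.vllm_flash_attn.flash_attn_interface keeps these public names; it is a suggestion, not code from the PR.

# Sketch: re-export the general varlen attention entry points from the
# canonical module rather than duplicating their implementations here.
# Assumes vllm.vllm_flash_attn.flash_attn_interface exposes these names.
from vllm.vllm_flash_attn.flash_attn_interface import (  # noqa: F401
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
)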
  file(READ \${SRC_FILE} FILE_CONTENTS)
  string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
  file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
The string replacement of flash_attn.cute with vllm.vllm_flash_attn.cute is a brittle way to modify source code. While it may be necessary for integrating an external library, it can silently break if the upstream library changes its internal import paths in a way this simple substring replacement does not catch. Consider whether there is a more robust approach, for example patching the source or applying a more targeted transformation.
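One possible direction, sketched below, is to rewrite only import statements that reference flash_attn.cute rather than every occurrence of the substring. This is an illustrative Python helper that could be invoked from CMake; it is not part of the PR.

# Sketch: rewrite only import statements that reference flash_attn.cute,
# instead of replacing every occurrence of the substring in the file.
# Illustrative only; the PR currently does a plain string(REPLACE ...) in CMake.
import re
import sys


def rewrite_imports(src_path: str, dst_path: str) -> None:
    with open(src_path, "r", encoding="utf-8") as f:
        text = f.read()
    # Match "import flash_attn.cute..." and "from flash_attn.cute... import ..."
    pattern = re.compile(r"^(\s*(?:from|import)\s+)flash_attn\.cute", re.MULTILINE)
    text = pattern.sub(r"\1vllm.vllm_flash_attn.cute", text)
    with open(dst_path, "w", encoding="utf-8") as f:
        f.write(text)


if __name__ == "__main__":
    rewrite_imports(sys.argv[1], sys.argv[2])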
Cursor Bugbot has reviewed your changes and found 1 potential issue.
        fa_version = 4
    else:
        # Fallback to FA2
        fa_version = 2
FA4 selection causes runtime failure with ALiBi models
High Severity
On Blackwell (SM100+), FA4 is now selected by default, but FA4 does not support ALiBi. Models that use ALiBi (such as MPT and BLOOM) will fail at runtime on the assertion alibi_slopes is None, "Alibi is not supported in FA4". There is fallback logic for FA3 + ALiBi (lines 83-87), but no corresponding fallback for FA4 + ALiBi. This is a regression, since FA2, which supports ALiBi, was previously used on Blackwell.
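A possible shape for the missing guard, mirroring the FA3 + ALiBi fallback described above. This is a sketch only; the variable names and logger call are assumptions, not code from the PR.

# Sketch of an FA4 + ALiBi fallback. Variable names (fa_version,
# alibi_slopes) and the logger call are assumed for illustration.
if fa_version == 4 and alibi_slopes is not None:
    logger.warning(
        "ALiBi is not supported by FA4; falling back to FA2.")
    fa_version = 2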
Integrate upstream FA4