Conversation

@LucasWilkinson
Collaborator

Integrate upstream FA4

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request successfully integrates FlashAttention 4 (FA4) by updating CMake configurations, adding necessary dependencies, and extending the attention configuration and version detection logic. The new vllm/vllm_flash_attn/flash_attn_interface.py file correctly centralizes the logic for selecting and using different FlashAttention versions (FA2, FA3, FA4).

However, there's a critical issue with the newly added file vllm/third_party/flashmla/flash_mla_interface.py. It duplicates the general FlashAttention (FA4) functions (_flash_attn_varlen_forward, flash_attn_varlen_func, etc.) that are already implemented and managed in vllm/vllm_flash_attn/flash_attn_interface.py, which creates confusion about which implementation is canonical. The vllm/third_party/flashmla/flash_mla_interface.py file should ideally contain only FlashMLA-specific functionality.
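
As a rough sketch of the centralized-dispatch idea described above (illustrative only: the registry and backend names below are placeholders, not vLLM's actual API), a single interface module can route each call to whichever FlashAttention generation is available:

# Placeholder sketch of version dispatch in one interface module.
# _BACKENDS maps a FlashAttention major version (2, 3, 4) to its varlen entry point.
from typing import Any, Callable, Dict

_BACKENDS: Dict[int, Callable[..., Any]] = {}

def register_backend(version: int, fn: Callable[..., Any]) -> None:
    # Each compiled extension registers its kernel wrapper once at import time.
    _BACKENDS[version] = fn

def flash_attn_varlen_func(*args: Any, fa_version: int = 2, **kwargs: Any) -> Any:
    # Route to the requested FlashAttention generation, failing loudly if the
    # corresponding extension was not built into this wheel.
    backend = _BACKENDS.get(fa_version)
    if backend is None:
        raise ValueError(f"FlashAttention v{fa_version} is not available in this build")
    return backend(*args, **kwargs)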

Comment on lines 215 to 436
def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    qo_total_len, num_qo_heads, head_dim_qk = q.shape
    kv_total_len, num_kv_heads, head_dim_vo = v.shape

    mask_mode_code = 1 if causal else 0
    if softmax_scale is None:
        softmax_scale = head_dim_qk ** (-0.5)

    if out is None:
        out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype)
    if lse is None:
        # Make lse contiguous on seqlen dim
        lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T

    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)
    flash_mla_cuda.dense_prefill_fwd(
        workspace_buffer,
        q,
        k,
        v,
        cu_seqlens_qo,
        cu_seqlens_kv,
        out,
        lse,
        mask_mode_code,
        softmax_scale,
        max_seqlen_qo,
        max_seqlen_kv,
        is_varlen,
    )

    return out, lse


def _flash_attn_varlen_backward(
    do: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    lse: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    qo_total_len, num_qo_heads, head_dim_qk = q.shape
    kv_total_len, num_kv_heads, head_dim_vo = v.shape

    # TODO: fix bwd GQA
    if num_qo_heads != num_kv_heads:
        raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.")

    mask_mode_code = 1 if causal else 0
    if softmax_scale is None:
        softmax_scale = head_dim_qk ** (-0.5)

    if dq is None:
        dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype)
    if dk is None:
        dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype)
    if dv is None:
        dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype)

    max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8
    bs = cu_seqlens_qo.shape[0] - 1
    workspace_bytes = 0
    workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk  # dQ_acc
    workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2  # sum_OdO and scaled_lse
    if num_qo_heads != num_kv_heads:
        workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo)  # dKV_acc
    workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)
    flash_mla_cuda.dense_prefill_bwd(
        workspace_buffer,
        do,
        q,
        k,
        v,
        out,
        lse,
        cu_seqlens_qo,
        cu_seqlens_kv,
        dq,
        dk,
        dv,
        mask_mode_code,
        softmax_scale,
        max_seqlen_qo,
        max_seqlen_kv,
        is_varlen,
    )

    return dq, dk, dv


class FlashAttnVarlenFunc(torch.autograd.Function):
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_qo: torch.Tensor,
        cu_seqlens_kv: torch.Tensor,
        max_seqlen_qo: int,
        max_seqlen_kv: int,
        causal: bool = False,
        softmax_scale: Optional[float] = None,
        is_varlen: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        out, lse = _flash_attn_varlen_forward(
            q, k, v,
            cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
            causal=causal, softmax_scale=softmax_scale,
            is_varlen=is_varlen,
        )
        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv)
        ctx.max_seqlen_qo = max_seqlen_qo
        ctx.max_seqlen_kv = max_seqlen_kv
        ctx.causal = causal
        ctx.softmax_scale = softmax_scale
        ctx.is_varlen = is_varlen
        return out, lse

    def backward(
        ctx,
        do: torch.Tensor,
        dlse: torch.Tensor,
    ):
        del dlse  # LSE doesn't support backward currently
        q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors
        dq, dk, dv = _flash_attn_varlen_backward(
            do, q, k, v, out, lse,
            cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv,
            causal=ctx.causal, softmax_scale=ctx.softmax_scale,
            is_varlen=ctx.is_varlen,
        )
        return dq, dk, dv, None, None, None, None, None, None, None


def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    deterministic: bool = False,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert dropout_p == 0.0
    assert not deterministic
    return FlashAttnVarlenFunc.apply(
        q, k, v,
        cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
        causal, softmax_scale, is_varlen,
    )


def flash_attn_varlen_qkvpacked_func(
    qkv: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    head_dim_qk: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    deterministic: bool = False,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert dropout_p == 0.0
    assert not deterministic
    return FlashAttnVarlenFunc.apply(
        qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:],
        cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
        causal, softmax_scale, is_varlen,
    )


def flash_attn_varlen_kvpacked_func(
    q: torch.Tensor,
    kv: torch.Tensor,
    cu_seqlens_qo: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    max_seqlen_qo: int,
    max_seqlen_kv: int,
    head_dim_qk: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    deterministic: bool = False,
    is_varlen: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert dropout_p == 0.0
    assert not deterministic
    return FlashAttnVarlenFunc.apply(
        q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:],
        cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
        causal, softmax_scale, is_varlen,
    )
Contributor

critical

This file, flash_mla_interface.py, is intended for FlashMLA-specific functions. However, it contains _flash_attn_varlen_forward, _flash_attn_varlen_backward, FlashAttnVarlenFunc, and related flash_attn_varlen_func wrappers. These functions are general FlashAttention implementations (specifically for FA4 in this context) and are already correctly defined and managed in vllm/vllm_flash_attn/flash_attn_interface.py. This duplication creates an architectural inconsistency and could lead to confusion or maintenance issues. These general FlashAttention functions should be removed from this file.
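
One way to address this (sketch only, assuming the canonical definitions live in vllm/vllm_flash_attn/flash_attn_interface.py as stated above) is to re-export the generic API from that module and keep only MLA-specific code here:

# Hypothetical replacement for the duplicated definitions in flash_mla_interface.py:
# import the canonical generic varlen function instead of redefining it.
from vllm.vllm_flash_attn.flash_attn_interface import flash_attn_varlen_func

__all__ = ["flash_attn_varlen_func"]  # plus the FlashMLA-specific symbols defined in this file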

Comment on lines +92 to +95
file(READ \${SRC_FILE} FILE_CONTENTS)
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
Contributor

high

Replacing the string flash_attn.cute with vllm.vllm_flash_attn.cute is a brittle way to modify source code. It may be necessary for integrating an external library, but it can break if the upstream library changes its internal import paths in a way that this simple substitution does not catch. Consider whether there is a more robust approach, such as patching the source or applying a more targeted transformation.
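
A rough illustration of a more targeted transformation (a standalone Python helper written for this note, not part of the build): rewrite only import statements, so unrelated occurrences of the substring are left untouched.

# Hypothetical helper: retarget `import`/`from` statements that reference
# flash_attn.cute, leaving any other occurrence of the string alone.
import re
from pathlib import Path

_IMPORT_RE = re.compile(r"^(\s*(?:from|import)\s+)flash_attn\.cute\b", re.MULTILINE)

def retarget_imports(src: Path, dst: Path) -> None:
    text = src.read_text()
    dst.write_text(_IMPORT_RE.sub(r"\1vllm.vllm_flash_attn.cute", text))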

@cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.


    fa_version = 4
else:
    # Fallback to FA2
    fa_version = 2

FA4 selection causes runtime failure with ALiBi models

High Severity

On Blackwell (SM100+), FA4 is now selected by default, but FA4 doesn't support ALiBi. Models using ALiBi (like MPT, BLOOM) will hit assert alibi_slopes is None, "Alibi is not supported in FA4" at runtime. There's fallback logic for FA3+ALiBi (lines 83-87) but no corresponding fallback for FA4+ALiBi. This is a regression since FA2 (which supports ALiBi) was previously used on Blackwell.
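
A minimal sketch of the kind of guard being suggested (illustrative only; the real selection logic lives in vLLM's FA version helper and uses different names): mirror the existing FA3+ALiBi fallback for FA4.

# Hypothetical fallback: ALiBi requests drop back to FA2, which supports ALiBi.
def pick_fa_version(requested: int, has_alibi: bool) -> int:
    if requested in (3, 4) and has_alibi:
        return 2
    return requested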

Additional Locations (1)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
