Commit 7cc93cf

[webgpu] Apply Flash Attention if sliding window exceeds KV cache length (#25594)
### Description

#25372 adds sliding window support for Group Query Attention but disables Flash Attention for it, since Flash Attention does not yet support sliding windows. This PR adds a check for the sliding window and applies Flash Attention whenever the window size is at least the total sequence length (which is enforced to be no greater than the KV cache length), because the window then has no effect.

### Motivation and Context

See above.
1 parent: a120b4b

File tree

1 file changed (+4 −1)

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 4 additions & 1 deletion
@@ -198,9 +198,12 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   Tensor* present_value = context.Output(2, present_kv_shape);
   parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
 
+  ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length.");
+  // Use a sliding window if the total sequence exceeds the window's length.
+  bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_);
   if (!do_rotary_ &&
       head_sink == nullptr && !use_smooth_softmax_ &&
-      local_window_size_ == -1 &&
+      !use_sliding_window &&
       CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
                                present_value, parameters, context);
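
For illustration, here is a minimal standalone sketch of the new dispatch rule. The helper name `UseSlidingWindow` and the sample lengths are hypothetical; the real check lives inline in `GroupQueryAttention::ComputeInternal` as shown in the diff above.

```cpp
#include <iostream>

// Mirrors the condition added in this commit: a sliding window is only
// needed when it is enabled (!= -1) and strictly smaller than the total
// sequence length.
bool UseSlidingWindow(int local_window_size, int total_sequence_length) {
  return local_window_size != -1 && local_window_size < total_sequence_length;
}

int main() {
  // Window covers the whole sequence: the window is a no-op, so Flash
  // Attention remains eligible. This is the case this commit enables.
  std::cout << UseSlidingWindow(/*window*/ 4096, /*total_seq*/ 1024) << "\n";  // 0

  // Window smaller than the sequence: Flash Attention is still skipped.
  std::cout << UseSlidingWindow(/*window*/ 256, /*total_seq*/ 1024) << "\n";   // 1

  // Window disabled entirely: Flash Attention eligible, as before.
  std::cout << UseSlidingWindow(/*window*/ -1, /*total_seq*/ 1024) << "\n";    // 0
  return 0;
}
```

Under the previous `local_window_size_ == -1` check, the first case would have fallen back to the non-flash path even though the window has no effect on the result.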
