From adfe3389eb04a83755ae03b1c87405bb56ba72bb Mon Sep 17 00:00:00 2001 From: gs Date: Fri, 11 Jul 2025 12:28:55 -0700 Subject: [PATCH 1/5] add sliding window support for webgpu gqa --- .../contrib_ops/webgpu/bert/attention.cc | 84 ++++++++++++++----- .../contrib_ops/webgpu/bert/attention.h | 8 +- .../webgpu/bert/attention_common.h | 2 +- .../webgpu/bert/group_query_attention.cc | 5 +- 4 files changed, 73 insertions(+), 26 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 55bcf42f2f04b..c48703230ea0f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -98,6 +98,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var tileQ: array;\n" << "var tileK: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n" << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n" @@ -224,6 +225,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o } Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + bool has_sliding_window = local_window_size_ != -1; + if (has_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } @@ -241,15 +244,33 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { std::ostringstream oss; InitVarStub(oss, has_seqlen_k_); shader.MainFunctionBody() << oss.str() - << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n" << "let seq_causal_length = " << (has_seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" - << "}\n" - << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" - << "workgroupBarrier();\n"; + << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"; + if (has_sliding_window) { + // Sliding window + shader.MainFunctionBody() + << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > u32(uniforms.local_window_size) + 1;\n" + << "let start_offset = select(0, seq_causal_length - u32(uniforms.local_window_size), should_apply_local_window);\n" + << "let effective_seq_length = select(seq_causal_length, u32(uniforms.local_window_size), should_apply_local_window);\n"; + } else { + // No sliding window: we keep the code for sliding window in the shader but + // using const for start_offset and should_apply_local_window will make the compiler optimize it out. + shader.MainFunctionBody() + << "const start_offset = 0;\n" + << "const should_apply_local_window = false;\n" + << "let effective_seq_length = seq_causal_length;\n"; + } + shader.MainFunctionBody() + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i + start_offset]), thread_max_vector);\n" + << " }\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n"; if (has_head_sink_) { // Handle head sink @@ -265,8 +286,11 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << " max_value = max(thread_max[i], max_value);\n" << "}\n" << "var sum_vector = f32_val_t(0);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " if (!should_apply_local_window ||actual_pos < seq_causal_length) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);\n" + << " }\n" << "}\n" << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" << "workgroupBarrier();\n" @@ -282,15 +306,33 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.MainFunctionBody() << "if (sum == 0) {\n" - << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " if (actual_pos < seq_causal_length) {\n" + << " x[offset + i + start_offset] = x_value_t(x_element_t(1.0)/x_element_t(effective_seq_length));\n" + << " }\n" << " }\n" << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " let pos = offset + i + start_offset;\n" + << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" + << " var f32input = f32_val_t(x[pos]);\n" + << " x[pos] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << " }\n" + << "}\n"; + + // zero out elements outsize the sliding window + shader.MainFunctionBody() << "if (should_apply_local_window) {\n" << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " var f32input = f32_val_t(x[offset + i]);\n" - << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " let global_pos = i + local_offset;\n" + << " if (global_pos < start_offset) {\n" + << " x[offset + i] = x_value_t(x_element_t(0));\n" + << " }\n" << " }\n" << "}\n"; + if (has_seqlen_k_) { shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" @@ -301,7 +343,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, - const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) { + const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink, int local_window_size) { const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1)); int work_group_size = 64; const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; @@ -310,7 +352,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; - InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr}; + InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr, local_window_size}; if (seqlen_k != nullptr) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -318,7 +360,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .CacheHint(work_group_size, use_smooth_softmax) + .CacheHint(work_group_size, use_smooth_softmax, local_window_size != -1) .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, @@ -327,7 +369,8 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso {static_cast(sequence_length)}, {static_cast(total_sequence_length_comp)}, {static_cast(elementsPerThread)}, - {static_cast(is_first_prompt ? 1 : 0)}}); + {static_cast(is_first_prompt ? 1 : 0)}, + {static_cast(local_window_size)}}); return context.RunProgram(program); } @@ -467,7 +510,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) { + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, + const Tensor* seqlen_k, int local_window_size) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0; const int total_sequence_length = @@ -481,7 +525,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T parameters, past_sequence_length, total_sequence_length, seqlen_k)); ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, - parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink)); + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink, local_window_size)); ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, parameters, past_sequence_length, total_sequence_length, seqlen_k)); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index e64ca3539c23d..864624a891e89 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -69,8 +69,8 @@ class AttentionProbsProgram final : public Program { class InPlaceSoftmaxProgram final : public Program { public: - InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink) - : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) { + InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink, int local_window_size) + : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink), local_window_size_(local_window_size) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -81,7 +81,8 @@ class InPlaceSoftmaxProgram final : public Program { {"sequence_length", ProgramUniformVariableDataType::Uint32}, {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, {"elements_per_thread", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"local_window_size", ProgramUniformVariableDataType::Int32}); private: int work_group_size_; @@ -89,6 +90,7 @@ class InPlaceSoftmaxProgram final : public Program { bool use_smooth_softmax_; bool has_seqlen_k_; bool has_head_sink_; + int local_window_size_; }; class VxAttentionScoreProgram final : public Program { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 9d4740ede7143..71161c120a306 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -124,7 +124,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, - const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr); + const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr, int local_window_size = -1); } // namespace webgpu } // namespace contrib diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index f3334b13dc645..858673ad5b44a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -196,6 +196,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (!do_rotary_ && head_sink == nullptr && !use_smooth_softmax_ && + local_window_size_ == -1 && CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); @@ -237,7 +238,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, - present_value, parameters, context, head_sink, seqlen_k); + present_value, parameters, context, head_sink, seqlen_k, local_window_size_); } TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, @@ -254,7 +255,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &V)); return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, - present_value, parameters, context, head_sink, seqlen_k); + present_value, parameters, context, head_sink, seqlen_k, local_window_size_); } } // namespace webgpu From f31c4777f8c859dd1b0ec24fe5b0205d5708c3a2 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 14 Jul 2025 17:02:17 -0700 Subject: [PATCH 2/5] review feedback --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 4 ++-- onnxruntime/contrib_ops/webgpu/bert/attention.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index c48703230ea0f..6f50058cdd6b9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -250,7 +250,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_sliding_window) { // Sliding window shader.MainFunctionBody() - << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > u32(uniforms.local_window_size) + 1;\n" + << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n" << "let start_offset = select(0, seq_causal_length - u32(uniforms.local_window_size), should_apply_local_window);\n" << "let effective_seq_length = select(seq_causal_length, u32(uniforms.local_window_size), should_apply_local_window);\n"; } else { @@ -288,7 +288,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var sum_vector = f32_val_t(0);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" << " let actual_pos = local_offset + i + start_offset;\n" - << " if (!should_apply_local_window ||actual_pos < seq_causal_length) {\n" + << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" << " sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);\n" << " }\n" << "}\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 864624a891e89..3450705b04908 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -82,7 +82,7 @@ class InPlaceSoftmaxProgram final : public Program { {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, {"elements_per_thread", ProgramUniformVariableDataType::Uint32}, {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, - {"local_window_size", ProgramUniformVariableDataType::Int32}); + {"local_window_size", ProgramUniformVariableDataType::Uint32}); private: int work_group_size_; From be8942026992818f06aed8c2f8f351953102db67 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 16 Jul 2025 12:29:50 -0700 Subject: [PATCH 3/5] Update onnxruntime/contrib_ops/webgpu/bert/attention.cc Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 6f50058cdd6b9..dca41c29dc750 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -252,7 +252,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n" << "let start_offset = select(0, seq_causal_length - u32(uniforms.local_window_size), should_apply_local_window);\n" - << "let effective_seq_length = select(seq_causal_length, u32(uniforms.local_window_size), should_apply_local_window);\n"; + << "let effective_seq_length = select(seq_causal_length, uniforms.local_window_size, should_apply_local_window);\n"; } else { // No sliding window: we keep the code for sliding window in the shader but // using const for start_offset and should_apply_local_window will make the compiler optimize it out. From 952495f91522ccb6ef1570407a943ba090e8afa2 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 16 Jul 2025 12:30:01 -0700 Subject: [PATCH 4/5] Update onnxruntime/contrib_ops/webgpu/bert/attention.cc Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index dca41c29dc750..59ff79c1e7dbe 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -370,7 +370,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso {static_cast(total_sequence_length_comp)}, {static_cast(elementsPerThread)}, {static_cast(is_first_prompt ? 1 : 0)}, - {static_cast(local_window_size)}}); + {static_cast(local_window_size)}}); return context.RunProgram(program); } From 8ff3fb285ff8a13729b795f2af3fd8ec92a5a3d9 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 16 Jul 2025 15:45:52 -0700 Subject: [PATCH 5/5] Update onnxruntime/contrib_ops/webgpu/bert/attention.cc Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 59ff79c1e7dbe..dbea308e0b08c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -251,7 +251,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { // Sliding window shader.MainFunctionBody() << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n" - << "let start_offset = select(0, seq_causal_length - u32(uniforms.local_window_size), should_apply_local_window);\n" + << "let start_offset = select(0, seq_causal_length - uniforms.local_window_size, should_apply_local_window);\n" << "let effective_seq_length = select(seq_causal_length, uniforms.local_window_size, should_apply_local_window);\n"; } else { // No sliding window: we keep the code for sliding window in the shader but