Skip to content

Commit f31c477

Browse files
committed
review feedback
1 parent adfe338 commit f31c477

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
250250
if (has_sliding_window) {
251251
// Sliding window
252252
shader.MainFunctionBody()
253-
<< "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > u32(uniforms.local_window_size) + 1;\n"
253+
<< "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n"
254254
<< "let start_offset = select(0, seq_causal_length - u32(uniforms.local_window_size), should_apply_local_window);\n"
255255
<< "let effective_seq_length = select(seq_causal_length, u32(uniforms.local_window_size), should_apply_local_window);\n";
256256
} else {
@@ -288,7 +288,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
288288
<< "var sum_vector = f32_val_t(0);\n"
289289
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n"
290290
<< " let actual_pos = local_offset + i + start_offset;\n"
291-
<< " if (!should_apply_local_window ||actual_pos < seq_causal_length) {\n"
291+
<< " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n"
292292
<< " sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);\n"
293293
<< " }\n"
294294
<< "}\n"

onnxruntime/contrib_ops/webgpu/bert/attention.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
8282
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
8383
{"elements_per_thread", ProgramUniformVariableDataType::Uint32},
8484
{"is_first_prompt", ProgramUniformVariableDataType::Uint32},
85-
{"local_window_size", ProgramUniformVariableDataType::Int32});
85+
{"local_window_size", ProgramUniformVariableDataType::Uint32});
8686

8787
private:
8888
int work_group_size_;

0 commit comments

Comments
 (0)