File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed
onnxruntime/contrib_ops/webgpu/bert Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -250,7 +250,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
250250 if (has_sliding_window) {
251251 // Sliding window
252252 shader.MainFunctionBody ()
253- << " let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > u32( uniforms.local_window_size) + 1;\n "
253+ << " let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n "
254254 << " let start_offset = select(0, seq_causal_length - u32(uniforms.local_window_size), should_apply_local_window);\n "
255255 << " let effective_seq_length = select(seq_causal_length, u32(uniforms.local_window_size), should_apply_local_window);\n " ;
256256 } else {
@@ -288,7 +288,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
288288 << " var sum_vector = f32_val_t(0);\n "
289289 << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n "
290290 << " let actual_pos = local_offset + i + start_offset;\n "
291- << " if (!should_apply_local_window ||actual_pos < seq_causal_length) {\n "
291+ << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n "
292292 << " sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);\n "
293293 << " }\n "
294294 << " }\n "
Original file line number Diff line number Diff line change @@ -82,7 +82,7 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
8282 {" total_sequence_length_comp" , ProgramUniformVariableDataType::Uint32},
8383 {" elements_per_thread" , ProgramUniformVariableDataType::Uint32},
8484 {" is_first_prompt" , ProgramUniformVariableDataType::Uint32},
85- {" local_window_size" , ProgramUniformVariableDataType::Int32 });
85+ {" local_window_size" , ProgramUniformVariableDataType::Uint32 });
8686
8787 private:
8888 int work_group_size_;
You can’t perform that action at this time.
0 commit comments