Skip to content

Commit 0c00b49

Browse files
[ET-VK] De-vectorise conv2d pw shader to improve perf. (#11182)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11108 by @trivedivivek
^ Please use this as the source of truth for the PR details, comments, and reviews.
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/89/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/89/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/89/orig
@diff-train-skip-merge
Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 196bc55 commit 0c00b49

File tree

1 file changed

+36
-14
lines changed

1 file changed

+36
-14
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,18 @@ void main() {
8888
ipos[i] = pos[i] * stride - padding;
8989
}
9090

91-
vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
92-
sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
93-
for (int i = 1; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
94-
sum[i] = sum[0];
91+
// Final output array where each element is a tensor value.
92+
// Each group of 4 consecutive elements represents a single output texel.
93+
float sum[TILE_SIZE_X * TILE_SIZE_Y * 4];
94+
95+
const vec4 bias = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
96+
97+
// Initialize the output array with the bias value
98+
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i += 4) {
99+
sum[i] = bias.x;
100+
sum[i + 1] = bias.y;
101+
sum[i + 2] = bias.z;
102+
sum[i + 3] = bias.w;
95103
}
96104

97105
int z4 = 0;
@@ -100,14 +108,26 @@ void main() {
100108
// During prepacking, the weight tensor has been permuted so that the
101109
// channel (IC) dim is along the x-axis, and the batch (OC) dim is along
102110
// the z-axis.
103-
const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(0, 0));
104-
const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(1, 0));
105-
const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(2, 0));
106-
const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));
111+
float kernel_values[4 * 4]; // 4 channels, 4 elements per channel
112+
113+
// Load kernel values from texels to array
114+
for (int i = 0; i < 4; ++i) {
115+
const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, gpos.z), 0);
116+
kernel_values[i * 4 + 0] = k_tex.x;
117+
kernel_values[i * 4 + 1] = k_tex.y;
118+
kernel_values[i * 4 + 2] = k_tex.z;
119+
kernel_values[i * 4 + 3] = k_tex.w;
120+
}
107121

108-
#pragma unroll
109122
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
110123
const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
124+
// Load the input texel into an array
125+
float tex_values[4];
126+
tex_values[0] = in_tex.x;
127+
tex_values[1] = in_tex.y;
128+
tex_values[2] = in_tex.z;
129+
tex_values[3] = in_tex.w;
130+
111131
// For a 2x2 tile size, the algorithm works as follows.
112132
// To explain the calculations below, the contents of one in_tex and the
113133
// group of 4 texels loaded from t_kernel are shown:
@@ -141,18 +161,20 @@ void main() {
141161
//
142162
// which is what is expressed in the following calculations. This is done
143163
// for each output position.
144-
sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
145-
sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
146-
sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
147-
sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
164+
for (int j = 0; j < 4; ++j) {
165+
sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j];
166+
sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j];
167+
sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j];
168+
sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j];
169+
}
148170
}
149171
}
150172

151173
for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) {
152174
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
153175
const ivec3 pos = pos_shared[offset_pos_index(index)];
154176
if (all(lessThan(pos, out_limits.xyz))) {
155-
imageStore(t_out, pos, op(sum[i], out_min, out_max));
177+
imageStore(t_out, pos, op(vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]), out_min, out_max));
156178
}
157179
}
158180
}

0 commit comments

Comments
 (0)