@@ -88,10 +88,18 @@ void main() {
88
88
ipos[i] = pos[i] * stride - padding;
89
89
}
90
90
91
- vec4 sum[TILE_SIZE_X * TILE_SIZE_Y];
92
- sum[0 ] = texelFetch(t_bias, ivec2 (gpos.z, 0 ), 0 );
93
- for (int i = 1 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
94
- sum[i] = sum[0 ];
91
+ // Final output array where each element is a tensor value.
92
+ // Each group of 4 consecutive elements represents a single output texel.
93
+ float sum[TILE_SIZE_X * TILE_SIZE_Y * 4 ];
94
+
95
+ const vec4 bias = texelFetch(t_bias, ivec2 (gpos.z, 0 ), 0 );
96
+
97
+ // Initialize the output array with the bias value
98
+ for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y * 4 ; i += 4 ) {
99
+ sum[i] = bias.x;
100
+ sum[i + 1 ] = bias.y;
101
+ sum[i + 2 ] = bias.z;
102
+ sum[i + 3 ] = bias.w;
95
103
}
96
104
97
105
int z4 = 0 ;
@@ -100,14 +108,26 @@ void main() {
100
108
// During prepacking, the weight tensor has been permuted so that the
101
109
// channel (IC) dim is along the x-axis, and the batch (OC) dim is along
102
110
// the z-axis.
103
- const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (0 , 0 ));
104
- const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (1 , 0 ));
105
- const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (2 , 0 ));
106
- const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2 (z, gpos.z), 0 , ivec2 (3 , 0 ));
111
+ float kernel_values[4 * 4 ]; // 4 channels, 4 elements per channel
112
+
113
+ // Load kernel values from texels to array
114
+ for (int i = 0 ; i < 4 ; ++ i) {
115
+ const vec4 k_tex = texelFetch(t_kernel, ivec2 (z + i, gpos.z), 0 );
116
+ kernel_values[i * 4 + 0 ] = k_tex.x;
117
+ kernel_values[i * 4 + 1 ] = k_tex.y;
118
+ kernel_values[i * 4 + 2 ] = k_tex.z;
119
+ kernel_values[i * 4 + 3 ] = k_tex.w;
120
+ }
107
121
108
- #pragma unroll
109
122
for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
110
123
const vec4 in_tex = texelFetch(t_in, ivec3 (ipos[i], z4), 0 );
124
+ // Load the input texel into an array
125
+ float tex_values[4 ];
126
+ tex_values[0 ] = in_tex.x;
127
+ tex_values[1 ] = in_tex.y;
128
+ tex_values[2 ] = in_tex.z;
129
+ tex_values[3 ] = in_tex.w;
130
+
111
131
// For a 2x2 tile size, the algorithm works as follows.
112
132
// To explain the calculations below, the contents of one in_tex and the
113
133
// group of 4 texels loaded from t_kernel are shown:
@@ -141,18 +161,20 @@ void main() {
141
161
//
142
162
// which is what is expressed in the following calculations. This is done
143
163
// for each output position.
144
- sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
145
- sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
146
- sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
147
- sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
164
+ for (int j = 0 ; j < 4 ; ++ j) {
165
+ sum[i * 4 + j] = tex_values[0 ] * kernel_values[0 + j] + sum[i * 4 + j];
166
+ sum[i * 4 + j] = tex_values[1 ] * kernel_values[4 + j] + sum[i * 4 + j];
167
+ sum[i * 4 + j] = tex_values[2 ] * kernel_values[8 + j] + sum[i * 4 + j];
168
+ sum[i * 4 + j] = tex_values[3 ] * kernel_values[12 + j] + sum[i * 4 + j];
169
+ }
148
170
}
149
171
}
150
172
151
173
for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
152
174
const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
153
175
const ivec3 pos = pos_shared[offset_pos_index(index)];
154
176
if (all (lessThan (pos, out_limits.xyz))) {
155
- imageStore(t_out, pos, op(sum[i] , out_min, out_max));
177
+ imageStore(t_out, pos, op(vec4 ( sum[i * 4 ], sum[i * 4 + 1 ], sum[i * 4 + 2 ], sum[i * 4 + 3 ]) , out_min, out_max));
156
178
}
157
179
}
158
180
}
0 commit comments