diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl index 4b18abbb1c5..1a2c257baec 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl @@ -42,47 +42,25 @@ const lowp int packed_dim = unhash_packed_dim(t_layout); * Extends sign of int8 */ int extend_sign(int x) { - if (x >> 7 == 1) { - return x | 0xFFFFFF00; - } - return x; + return x | mix(0, 0xFFFFFF00, x >= (1 << 7)); } ivec4 read_texel(ivec4 tidx) { - ivec4 tidx_to_use = tidx; - ivec4 sizes_to_use = sizes; - int packed_dim_to_use = packed_dim; - if (transpose_hw == 1) { - sizes_to_use.xy = sizes_to_use.yx; - tidx_to_use.xy = tidx.yx; - - if (packed_dim == 1) { - packed_dim_to_use = 0; - } - if (packed_dim == 0) { - packed_dim_to_use = 1; - } - } + const ivec4 tidx_to_use = ivec4(mix(tidx.xy, tidx.yx, bvec2(transpose_hw == 1)), tidx.zw); + const ivec4 sizes_to_use = ivec4(mix(sizes.xy, sizes.yx, bvec2(transpose_hw == 1)), sizes.zw); + const int packed_dim_to_use = mix(packed_dim, packed_dim ^ transpose_hw, packed_dim < 2); const ivec4 buf_indices = tidx_to_nchwi( tidx_to_use, sizes_to_use, packed_dim_to_use); - int shift = (1 << 8) - 1; - ivec4 masks; - // Masks used to unpack 4x 8-bit values from a 32 bit integer. Note that - // little endian is assumed, as most processors use little endian. Thus the - // most significant bytes correspond to the "latter" packed values. - masks.x = shift << (8 * (buf_indices.x % 4)); - masks.y = shift << (8 * (buf_indices.y % 4)); - masks.z = shift << (8 * (buf_indices.z % 4)); - masks.w = shift << (8 * (buf_indices.w % 4)); + const int mask = (1 << 8) - 1; ivec4 out_tex = ivec4(0); [[unroll]] for (int i = 0; i < 4; ++i) { if (tidx[packed_dim] + i < sizes[packed_dim]) { - int in_texel = nchw_in[buf_indices[i] / 4]; - int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4)); + const int in_texel = nchw_in[buf_indices[i] >> 2]; + int extracted_val = (in_texel >> (8 * (buf_indices[i] & 3))) & mask; extracted_val = extend_sign(extracted_val); out_tex[i] = extracted_val; }