-
Notifications
You must be signed in to change notification settings - Fork 44
Description
Target use case:
// Input IR: reads a 3-D bf16 buffer through a 5-D expand_shape view,
// applies elementwise absf on a per-block 4-D vector slice, and writes
// the result back through an identical view of the output buffer.
module {
  func.func @main(%arg0: memref<32x2x192xbf16>, %arg1: memref<32x2x192xbf16>) {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : bf16
    %block_id_y = gpu.block_id y
    // 32x2x192 -> 8x4x2x6x32: dim 0 splits 32 = 8*4, dim 2 splits 192 = 6*32.
    %expand_shape = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [8, 4, 2, 6, 32] : memref<32x2x192xbf16> into memref<8x4x2x6x32xbf16>
    // Each block reads one 4x2x6x32 tile, selected by %block_id_y on the outer dim.
    %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
    %1 = math.absf %0 : vector<4x2x6x32xbf16>
    %expand_shape_0 = memref.expand_shape %arg1 [[0, 1], [2], [3, 4]] output_shape [8, 4, 2, 6, 32] : memref<32x2x192xbf16> into memref<8x4x2x6x32xbf16>
    vector.transfer_write %1, %expand_shape_0[%block_id_y, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<4x2x6x32xbf16>, memref<8x4x2x6x32xbf16>
    return
  }
}
Below is the targeted code sequence. The #xegpu.slice attributes would normally be attached only after the lowering, but they are written inline here for illustration purposes.
// Targeted lowering of:
// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
// Illustrative pseudocode (not verified MLIR): per-dim index vectors are built
// with vector.step, scaled by the row-major strides of the 8x4x2x6x32 view
// (384, 192, 32, 1), broadcast to the full 4-D shape, summed into per-lane
// offsets, and fed to a scattered tensor descriptor + load_gather.
//
// Subgroup layout for the full 4-D result; sg_data = [1, 1, 1, 32] means each
// subgroup owns a single 32-wide row along the innermost dim.
#abcd = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>
// 1-D layouts for each dim, obtained by slicing away the other three dims.
#a = #xegpu.slice<#xegpu.slice<#xegpu.slice<#abcd, 3>, 2>, 1>
#b = #xegpu.slice<#xegpu.slice<#xegpu.slice<#abcd, 3>, 2>, 0>
#c = #xegpu.slice<#xegpu.slice<#xegpu.slice<#abcd, 3>, 1>, 0>
#d = #xegpu.slice<#xegpu.slice<#xegpu.slice<#abcd, 2>, 1>, 0>
%6 = vector.step #a : vector<4xindex>
%7 = vector.step #b : vector<2xindex>
%8 = vector.step #c : vector<6xindex>
%9 = vector.step #d : vector<32xindex>
// Scale each per-dim index by the corresponding stride of the 5-D view
// (strides of the trailing 4x2x6x32 dims: 384, 192, 32, 1).
%10 = arith.mul %6, 384 #a
%11 = arith.mul %7, 192 #b
%12 = arith.mul %8, 32 #c
%13 = arith.mul %9, 1 #d
// Reshape each 1-D stride vector so it broadcasts along the other three dims.
%14 = vector.shape_cast %10 #abcd : vector<4xindex> -> vector<4x1x1x1xindex>
%15 = vector.shape_cast %11 #abcd : vector<2xindex> -> vector<1x2x1x1xindex>
%16 = vector.shape_cast %12 #abcd : vector<6xindex> -> vector<1x1x6x1xindex>
%17 = vector.shape_cast %13 #abcd : vector<32xindex> -> vector<1x1x1x32xindex>
%18 = vector.broadcast %14 #abcd : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
%19 = vector.broadcast %15 #abcd : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
%20 = vector.broadcast %16 #abcd : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
%21 = vector.broadcast %17 #abcd : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
%22 = arith.add %18, %19 #abcd
%23 = arith.add %20, %21 #abcd
%local_offsets = arith.add %22, %23 #abcd
// Base offset of the tile selected by the block id: 1536 = 4*2*6*32 elements per tile.
%orig_offset = %block_id_y * 1536
%offsets = orig_offset + local_offsets #abcd
%tdesc = xegpu.create_tdesc %expand_shape %offsets : memref<8x4x2x6x32xbf16>, vector<4x2x6x32xindex> -> !xegpu.tensor_desc<4x2x6x32xbf16, #xegpu.scatter_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = false>, #abcd>
%vec = xegpu.load_gather %tdesc #abcd : !xegpu.tensor_desc<4x2x6x32xbf16, #xegpu.scatter_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = false>, #abcd> -> vector<4x2x6x32xbf16>