-
Notifications
You must be signed in to change notification settings - Fork 68
Closed
Description
Problem
The Coalescing pass for block pointers fails to handle `scf.while` loops (`scf::WhileOp`). To reproduce the problem, compile the following test case; the pass fails with an assertion.
Test Case:
// Reproducer: a kernel whose scf.while loop carries two block pointers
// (!tt.ptr<tensor<8x128xf32, #blocked>>) as loop-carried values.
// Layout alias shared by every 8x128xf32 tensor type below.
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.func public @kernel_make_tensor_descriptor_loop_carried(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i64 {tt.divisibility = 16 : i32}, %arg2: i64 {tt.divisibility = 16 : i32}) {
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c2_i32 = arith.constant 2 : i32
// %4: block pointer built from the base pointer %arg0.
// NOTE(review): the three bracketed operand groups are presumably
// shape, strides and offsets -- confirm against the tt.make_tensor_ptr op definition.
%4 = tt.make_tensor_ptr %arg0, [%arg1, %arg2], [%arg2, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x128xf32, #blocked>>
// %5: %4 advanced by [2, 0]; a second block pointer derived from the first.
%5 = tt.advance %4, [%c2_i32, %c0_i32] : <tensor<8x128xf32, #blocked>>
// %7 is computed once, outside the loop, so the loop condition is loop-invariant.
%7 = arith.cmpi slt, %arg1, %arg2 : i64
// Both block pointers (%4 and %5) enter the loop as iteration arguments.
%6:2 = scf.while (%arg3 = %4, %arg4 = %5) : (!tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>) -> (!tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>) {
// "before" region: forwards both carried pointers unchanged, guarded by %7.
scf.condition(%7) %arg3, %arg4 : !tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>
} do {
// "after" region: receives the two pointers forwarded by scf.condition.
^bb0(%arg3: !tt.ptr<tensor<8x128xf32, #blocked>>, %arg4: !tt.ptr<tensor<8x128xf32, #blocked>>):
// %12 picks one of the two carried pointers based on the loop-invariant %7.
%12 = arith.select %7, %arg4, %arg3 : !tt.ptr<tensor<8x128xf32, #blocked>>
// %13: %12 advanced by [0, 2].
%13 = tt.advance %12, [%c0_i32, %c2_i32] : <tensor<8x128xf32, #blocked>>
// Load through %12, then store the loaded tensor through the advanced pointer %13.
%15 = tt.load %12 : !tt.ptr<tensor<8x128xf32, #blocked>>
tt.store %13, %15 : !tt.ptr<tensor<8x128xf32, #blocked>>
// %12 and %13 are yielded back as the next iteration's carried pointers.
scf.yield %12, %13 : !tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>
}
tt.return
}
}
Metadata
Metadata
Assignees
Labels
No labels