[Coalescing] Failure to support scf::while loops. #4291

Open
etiotto opened this issue May 23, 2025 · 0 comments · May be fixed by #4290

etiotto commented May 23, 2025

Problem
The Coalescing pass for block pointers does not handle scf.while loops. To reproduce the problem, compile the following test case; the pass fails with an assertion.

Test Case:

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  tt.func public @kernel_make_tensor_descriptor_loop_carried(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i64 {tt.divisibility = 16 : i32}, %arg2: i64 {tt.divisibility = 16 : i32}) {
    %c1_i64 = arith.constant 1 : i64
    %c0_i32 = arith.constant 0 : i32
    %c2_i32 = arith.constant 2 : i32
    %4 = tt.make_tensor_ptr %arg0, [%arg1, %arg2], [%arg2, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x128xf32, #blocked>>
    %5 = tt.advance %4, [%c2_i32, %c0_i32] : <tensor<8x128xf32, #blocked>>
    %7 = arith.cmpi slt, %arg1, %arg2 : i64
    %6:2 = scf.while (%arg3 = %4, %arg4 = %5) : (!tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>) -> (!tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>) {
      scf.condition(%7) %arg3, %arg4 : !tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>
    } do {
    ^bb0(%arg3: !tt.ptr<tensor<8x128xf32, #blocked>>, %arg4: !tt.ptr<tensor<8x128xf32, #blocked>>):
      %12 = arith.select %7, %arg4, %arg3 : !tt.ptr<tensor<8x128xf32, #blocked>>
      %13 = tt.advance %12, [%c0_i32, %c2_i32] : <tensor<8x128xf32, #blocked>>
      %15 = tt.load %12 : !tt.ptr<tensor<8x128xf32, #blocked>>
      tt.store %13, %15 : !tt.ptr<tensor<8x128xf32, #blocked>>
      scf.yield %12, %13 : !tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>
    }
    tt.return
  }
}
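
For readers less familiar with scf.while, the sketch below is a toy Python model of how values are forwarded through the loop in the test case. It is not the pass's implementation. It only encodes the documented scf.while semantics: init operands feed the "before" region arguments, scf.condition forwards its operands to the "after" region arguments and to the loop results, and scf.yield feeds the "before" region arguments of the next iteration. The helper names (while_forwarding_edges, reachable) and the "before."/"after." prefixes are hypothetical and used here only to tell apart the two regions' %arg3/%arg4.

```python
from collections import defaultdict, deque


def while_forwarding_edges(inits, before_args, cond_args, after_args, yields, results):
    """List (source, destination) pairs for the implicit value flow of one scf.while."""
    edges = []
    # Init operands seed the "before" region block arguments.
    edges += list(zip(inits, before_args))
    # scf.condition forwards its operands to the "after" region arguments
    # and, when the loop exits, to the results of the scf.while op.
    edges += list(zip(cond_args, after_args))
    edges += list(zip(cond_args, results))
    # scf.yield feeds the "before" region arguments of the next iteration.
    edges += list(zip(yields, before_args))
    return edges


def reachable(edges, start):
    """Return every value that `start` can flow into (breadth-first search)."""
    succ = defaultdict(list)
    for src, dst in edges:
        succ[src].append(dst)
    seen, work = set(), deque([start])
    while work:
        for nxt in succ[work.popleft()]:
            if nxt not in seen:
                seen.add(nxt)
                work.append(nxt)
    return seen


# Value names taken from the test case above.
loop_edges = while_forwarding_edges(
    inits=["%4", "%5"],
    before_args=["before.%arg3", "before.%arg4"],
    cond_args=["before.%arg3", "before.%arg4"],
    after_args=["after.%arg3", "after.%arg4"],
    yields=["%12", "%13"],
    results=["%6#0", "%6#1"],
)

# Explicit data flow inside the "after" region: arith.select and tt.advance.
body_edges = [("after.%arg4", "%12"), ("after.%arg3", "%12"), ("%12", "%13")]

flows = reachable(loop_edges + body_edges, "%4")
print(sorted(flows))
```

In this model the pointer created by tt.make_tensor_ptr (%4) flows into both region arguments, both yielded values, and both loop results, so a rewrite that changes its type presumably has to follow all three kinds of forwarding edges.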
@etiotto etiotto self-assigned this May 23, 2025