[Helion] FlashAttention kernel containing 3-dim block-ptrs not optimized #5272

@etiotto

Description

Helion currently produces the kernel below, which contains 3-dim block pointers, for its FlashAttention example. The XPU Triton compiler can optimize 2-dim block pointer loads by lowering them to efficient 2D block operations (on Xe2 targets). Given that the 3-dim block pointers have an outermost dimension extent of one, the Triton compiler can "canonicalize" them to 2-dim block pointers by "fusing" the reshape operation with the associated load operation.
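
A minimal sketch of the intended rewrite, assuming the squeezed dimension is the leading one and is not boundary-checked (as is the case for every block pointer in the dump below); the SSA names here are hypothetical, and the 1x512x64 block shape is borrowed from the q load:

// Before: 3-dim block pointer whose block shape has a leading extent of 1,
// loaded and then immediately squeezed to 2-dim.
%ptr3d = tt.make_tensor_ptr %base, [%d0, %d1, %d2], [%s0, %s1, %s2], [%o0, %o1, %o2] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%ld3d = tt.load %ptr3d {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
%val = tt.reshape %ld3d : tensor<1x512x64xf16> -> tensor<512x64xf16>

// After: the leading offset is folded into the base pointer, the extent-1
// dimension is dropped from the shape/stride/offset/order lists, boundaryCheck
// is renumbered (dims 1 and 2 become dims 0 and 1), and the tt.reshape is
// fused into the load.
%o0_64 = arith.extsi %o0 : i32 to i64
%fold = arith.muli %o0_64, %s0 : i64
%base2d = tt.addptr %base, %fold : !tt.ptr<f16>, i64
%ptr2d = tt.make_tensor_ptr %base2d, [%d1, %d2], [%s1, %s2], [%o1, %o2] {order = array<i32: 1, 0>} : <tensor<512x64xf16>>
%val2d = tt.load %ptr2d {boundaryCheck = array<i32: 0, 1>, padding = 1 : i32} : !tt.ptr<tensor<512x64xf16>>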

Motivating example:

module {
  tt.func public @_helion_attention(%q_view: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %k_view: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %v_view: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %out: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %total_pids = tt.call @"triton.language.standard.cdiv____(0,)cconstexpr_1024__(1,)cconstexpr_512_"() : () -> i32
    %total_pids_0 = arith.constant 512 : i32
    %total_pids_1 = arith.constant 512 : i32
    %total_pids_2 = arith.muli %total_pids_1, %total_pids : i32
    %0 = tt.get_program_id x : i32
    %c20_i32 = arith.constant 20 : i32
    %1 = arith.bitcast %0 : i32 to i32
    %2 = arith.bitcast %total_pids_2 : i32 to i32
    %3 = arith.bitcast %c20_i32 : i32 to i32
    %4 = ub.poison : i32
    scf.for %virtual_pid = %1 to %2 step %3  : i32 {
      %num_pid_m = arith.constant 512 : i32
      %num_pid_n = tt.call @"triton.language.standard.cdiv____(0,)cconstexpr_1024__(1,)cconstexpr_512_"() : () -> i32
      %num_pid_in_group = arith.constant 32 : i32
      %num_pid_in_group_3 = arith.constant 32 : i32
      %num_pid_in_group_4 = arith.muli %num_pid_in_group_3, %num_pid_n : i32
      %group_id = arith.divsi %virtual_pid, %num_pid_in_group_4 : i32
      %first_pid_m = arith.constant 32 : i32
      %first_pid_m_5 = arith.constant 32 : i32
      %first_pid_m_6 = arith.muli %group_id, %first_pid_m_5 : i32
      %group_size_m = arith.subi %num_pid_m, %first_pid_m_6 : i32
      %group_size_m_7 = arith.constant 32 : i32
      %group_size_m_8 = arith.minsi %group_size_m, %group_size_m_7 : i32
      %pid_0 = arith.remsi %virtual_pid, %num_pid_in_group_4 : i32
      %pid_0_9 = arith.remsi %pid_0, %group_size_m_8 : i32
      %pid_0_10 = arith.addi %first_pid_m_6, %pid_0_9 : i32
      %pid_1 = arith.remsi %virtual_pid, %num_pid_in_group_4 : i32
      %pid_1_11 = arith.divsi %pid_1, %group_size_m_8 : i32
      %offset_1 = arith.constant 512 : i32
      %offset_1_12 = arith.constant 512 : i32
      %offset_1_13 = arith.muli %pid_1_11, %offset_1_12 : i32
      %indices_1 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
      %indices_1_14 = tt.splat %offset_1_13 : i32 -> tensor<512xi32>
      %indices_1_15 = arith.addi %indices_1_14, %indices_1 : tensor<512xi32>
      %mask_1 = arith.constant 1024 : i32
      %mask_1_16 = arith.constant dense<1024> : tensor<512xi32>
      %mask_1_17 = arith.cmpi slt, %indices_1_15, %mask_1_16 : tensor<512xi32>
      %tile_end = arith.constant 1 : i32
      %tile_end_18 = arith.constant 1 : i32
      %tile_end_19 = arith.addi %pid_0_10, %tile_end_18 : i32
      %c1_i32 = arith.constant 1 : i32
      %5 = arith.bitcast %pid_0_10 : i32 to i32
      %6 = arith.bitcast %tile_end_19 : i32 to i32
      %7 = arith.bitcast %c1_i32 : i32 to i32
      %8 = ub.poison : i32
      scf.for %offset_5 = %5 to %6 step %7  : i32 {
        %m_i = arith.constant 0xFF800000 : f32
        %m_i_20 = arith.constant dense<0xFF800000> : tensor<512xf32>
        %l_i = arith.constant 1.000000e+00 : f32
        %l_i_21 = arith.constant dense<1.000000e+00> : tensor<512xf32>
        %acc = arith.constant 0.000000e+00 : f32
        %acc_22 = arith.constant dense<0.000000e+00> : tensor<512x64xf32>
        %q = arith.constant 512 : i64
        %q_23 = arith.constant 1024 : i64
        %q_24 = arith.constant 64 : i64
        %q_25 = arith.constant 65536 : i64
        %q_26 = arith.constant 64 : i64
        %q_27 = arith.constant 1 : i64
        %q_28 = arith.constant 0 : i32
        %q_29 = tt.make_tensor_ptr %q_view, [%q, %q_23, %q_24], [%q_25, %q_26, %q_27], [%offset_5, %offset_1_13, %q_28] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
        %q_30 = tt.load %q_29 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
        %q_31 = tt.reshape %q_30 : tensor<1x512x64xf16> -> tensor<512x64xf16>
        %c0_i32 = arith.constant 0 : i32
        %c1024_i32 = arith.constant 1024 : i32
        %c64_i32 = arith.constant 64 : i32
        %9 = arith.bitcast %c0_i32 : i32 to i32
        %10 = arith.bitcast %c1024_i32 : i32 to i32
        %11 = arith.bitcast %c64_i32 : i32 to i32
        %12 = ub.poison : i32
        %acc_32:3 = scf.for %offset_3 = %9 to %10 step %11 iter_args(%m_i_36 = %m_i_20, %l_i_37 = %l_i_21, %acc_38 = %acc_22) -> (tensor<512xf32>, tensor<512xf32>, tensor<512x64xf32>)  : i32 {
          %indices_3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
          %indices_3_39 = tt.splat %offset_3 : i32 -> tensor<64xi32>
          %indices_3_40 = arith.addi %indices_3_39, %indices_3 : tensor<64xi32>
          %mask_4 = arith.constant 1024 : i32
          %mask_4_41 = arith.constant dense<1024> : tensor<64xi32>
          %mask_4_42 = arith.cmpi slt, %indices_3_40, %mask_4_41 : tensor<64xi32>
          %k = arith.constant 512 : i64
          %k_43 = arith.constant 64 : i64
          %k_44 = arith.constant 1024 : i64
          %k_45 = arith.constant 65536 : i64
          %k_46 = arith.constant 1 : i64
          %k_47 = arith.constant 64 : i64
          %k_48 = arith.constant 0 : i32
          %k_49 = tt.make_tensor_ptr %k_view, [%k, %k_43, %k_44], [%k_45, %k_46, %k_47], [%offset_5, %k_48, %offset_3] {order = array<i32: 2, 0, 1>} : <tensor<1x64x64xf16>>
          %k_50 = tt.load %k_49 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x64x64xf16>>
          %k_51 = tt.reshape %k_50 : tensor<1x64x64xf16> -> tensor<64x64xf16>
          %qk = arith.constant 0.000000e+00 : f32
          %qk_52 = arith.constant dense<0.000000e+00> : tensor<512x64xf32>
          %qk_53 = tt.dot %q_31, %k_51, %qk_52, inputPrecision = tf32 : tensor<512x64xf16> * tensor<64x64xf16> -> tensor<512x64xf32>
          %_mask_to_2 = tt.expand_dims %mask_1_17 {axis = 1 : i32} : tensor<512xi1> -> tensor<512x1xi1>
          %_mask_to_2_54 = tt.expand_dims %mask_4_42 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
          %_mask_to_2_55 = tt.broadcast %_mask_to_2 : tensor<512x1xi1> -> tensor<512x64xi1>
          %_mask_to_2_56 = tt.broadcast %_mask_to_2_54 : tensor<1x64xi1> -> tensor<512x64xi1>
          %_mask_to_2_57 = arith.andi %_mask_to_2_55, %_mask_to_2_56 : tensor<512x64xi1>
          %_mask_to_2_58 = arith.constant 0xFC00 : f16
          %_mask_to_2_59 = arith.extf %_mask_to_2_58 : f16 to f32
          %_mask_to_2_60 = tt.splat %_mask_to_2_59 : f32 -> tensor<512x64xf32>
          %_mask_to_2_61 = arith.select %_mask_to_2_57, %qk_53, %_mask_to_2_60 : tensor<512x64xi1>, tensor<512x64xf32>
          %amax = tt.call @"triton.language.standard.max__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%_mask_to_2_61) : (tensor<512x64xf32>) -> tensor<512xf32>
          %amax_62 = arith.truncf %amax : tensor<512xf32> to tensor<512xf16>
          %v_0 = arith.constant 0.180336878 : f32
          %v_1 = arith.extf %amax_62 : tensor<512xf16> to tensor<512xf32>
          %v_1_63 = arith.constant dense<0.180336878> : tensor<512xf32>
          %v_1_64 = arith.mulf %v_1, %v_1_63 : tensor<512xf32>
          %v_3 = tt.call @torch._inductor.runtime.triton_helpers.maximum__fp32S512S_fp32S512S__(%m_i_36, %v_1_64) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
          %v_4 = arith.constant 0.180336878 : f32
          %v_5 = arith.constant dense<0.180336878> : tensor<512x64xf32>
          %v_5_65 = arith.mulf %qk_53, %v_5 : tensor<512x64xf32>
          %subscript = tt.expand_dims %v_3 {axis = 1 : i32} : tensor<512xf32> -> tensor<512x1xf32>
          %v_7 = tt.broadcast %subscript : tensor<512x1xf32> -> tensor<512x64xf32>
          %v_7_66 = arith.subf %v_5_65, %v_7 : tensor<512x64xf32>
          %v_8 = tt.extern_elementwise %v_7_66 {libname = "", libpath = "", pure = true, symbol = "__imf_exp2f"} : (tensor<512x64xf32>) -> tensor<512x64xf32>
          %_mask_to_3 = tt.expand_dims %mask_1_17 {axis = 1 : i32} : tensor<512xi1> -> tensor<512x1xi1>
          %_mask_to_3_67 = tt.expand_dims %mask_4_42 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
          %_mask_to_3_68 = tt.broadcast %_mask_to_3 : tensor<512x1xi1> -> tensor<512x64xi1>
          %_mask_to_3_69 = tt.broadcast %_mask_to_3_67 : tensor<1x64xi1> -> tensor<512x64xi1>
          %_mask_to_3_70 = arith.andi %_mask_to_3_68, %_mask_to_3_69 : tensor<512x64xi1>
          %_mask_to_3_71 = arith.constant 0.000000e+00 : f32
          %_mask_to_3_72 = arith.constant dense<0.000000e+00> : tensor<512x64xf32>
          %_mask_to_3_73 = arith.select %_mask_to_3_70, %v_8, %_mask_to_3_72 : tensor<512x64xi1>, tensor<512x64xf32>
          %l_ij = tt.call @"triton.language.standard.sum__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_mask_to_3_73) : (tensor<512x64xf32>) -> tensor<512xf32>
          %v_9 = arith.subf %m_i_36, %v_3 : tensor<512xf32>
          %v_10 = tt.extern_elementwise %v_9 {libname = "", libpath = "", pure = true, symbol = "__imf_exp2f"} : (tensor<512xf32>) -> tensor<512xf32>
          %v_11 = arith.mulf %l_i_37, %v_10 : tensor<512xf32>
          %l_i_74 = arith.addf %v_11, %l_ij : tensor<512xf32>
          %subscript_1 = tt.expand_dims %v_10 {axis = 1 : i32} : tensor<512xf32> -> tensor<512x1xf32>
          %v_13 = tt.broadcast %subscript_1 : tensor<512x1xf32> -> tensor<512x64xf32>
          %v_13_75 = arith.mulf %acc_38, %v_13 : tensor<512x64xf32>
          %v = arith.constant 512 : i64
          %v_76 = arith.constant 1024 : i64
          %v_77 = arith.constant 64 : i64
          %v_78 = arith.constant 65536 : i64
          %v_79 = arith.constant 64 : i64
          %v_80 = arith.constant 1 : i64
          %v_81 = arith.constant 0 : i32
          %v_82 = tt.make_tensor_ptr %v_view, [%v, %v_76, %v_77], [%v_78, %v_79, %v_80], [%offset_5, %offset_3, %v_81] {order = array<i32: 2, 1, 0>} : <tensor<1x64x64xf16>>
          %v_83 = tt.load %v_82 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x64x64xf16>>
          %v_84 = tt.reshape %v_83 : tensor<1x64x64xf16> -> tensor<64x64xf16>
          %v_14 = arith.truncf %v_8 : tensor<512x64xf32> to tensor<512x64xf16>
          %_mask_to_4 = tt.expand_dims %mask_1_17 {axis = 1 : i32} : tensor<512xi1> -> tensor<512x1xi1>
          %_mask_to_4_85 = tt.expand_dims %mask_4_42 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
          %_mask_to_4_86 = tt.broadcast %_mask_to_4 : tensor<512x1xi1> -> tensor<512x64xi1>
          %_mask_to_4_87 = tt.broadcast %_mask_to_4_85 : tensor<1x64xi1> -> tensor<512x64xi1>
          %_mask_to_4_88 = arith.andi %_mask_to_4_86, %_mask_to_4_87 : tensor<512x64xi1>
          %_mask_to_4_89 = arith.constant 0.000000e+00 : f16
          %_mask_to_4_90 = arith.constant dense<0.000000e+00> : tensor<512x64xf16>
          %_mask_to_4_91 = arith.select %_mask_to_4_88, %v_14, %_mask_to_4_90 : tensor<512x64xi1>, tensor<512x64xf16>
          %acc_92 = arith.constant 0.000000e+00 : f32
          %acc_93 = tt.dot %_mask_to_4_91, %v_84, %v_13_75, inputPrecision = tf32 : tensor<512x64xf16> * tensor<64x64xf16> -> tensor<512x64xf32>
          scf.yield %v_3, %l_i_74, %acc_93 : tensor<512xf32>, tensor<512xf32>, tensor<512x64xf32>
        } {tt.disallow_acc_multi_buffer, tt.loop_unroll_factor = 1 : i32}
        %subscript_2 = tt.expand_dims %acc_32#1 {axis = 1 : i32} : tensor<512xf32> -> tensor<512x1xf32>
        %v_15 = tt.broadcast %subscript_2 : tensor<512x1xf32> -> tensor<512x64xf32>
        %v_15_33 = arith.divf %acc_32#2, %v_15 : tensor<512x64xf32>
        %v_16 = arith.truncf %v_15_33 : tensor<512x64xf32> to tensor<512x64xf16>
        %c512_i64 = arith.constant 512 : i64
        %c1024_i64 = arith.constant 1024 : i64
        %c64_i64 = arith.constant 64 : i64
        %c65536_i64 = arith.constant 65536 : i64
        %c64_i64_34 = arith.constant 64 : i64
        %c1_i64 = arith.constant 1 : i64
        %c0_i32_35 = arith.constant 0 : i32
        %13 = tt.make_tensor_ptr %out, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64_34, %c1_i64], [%offset_5, %offset_1_13, %c0_i32_35] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
        %14 = tt.reshape %v_16 : tensor<512x64xf16> -> tensor<1x512x64xf16>
        tt.store %13, %14 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x512x64xf16>>
      } {tt.disallow_acc_multi_buffer, tt.loop_unroll_factor = 3 : i32, tt.num_stages = 3 : i32}
    } {tt.loop_unroll_factor = 1 : i32, tt.num_stages = 4 : i32}
    tt.return
  }
  tt.func private @"triton.language.standard.cdiv____(0,)cconstexpr_1024__(1,)cconstexpr_512_"() -> i32 attributes {noinline = false} {
    %c2_i32 = arith.constant 2 : i32
    tt.return %c2_i32 : i32
  ^bb1:  // no predecessors
    %0 = ub.poison : i32
    tt.return %0 : i32
  }
  tt.func private @"triton.language.standard.max__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%input: tensor<512x64xf32>) -> tensor<512xf32> attributes {noinline = false} {
    %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({
    ^bb0(%arg1: f32, %arg2: f32):
      %2 = tt.call @triton.language.standard._elementwise_max__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %2 : f32
    }) : (tensor<512x64xf32>) -> tensor<512xf32>
    tt.return %0 : tensor<512xf32>
  ^bb1:  // no predecessors
    %1 = ub.poison : tensor<512xf32>
    tt.return %1 : tensor<512xf32>
  }
  tt.func private @triton.language.standard._elementwise_max__fp32_fp32__(%a: f32, %b: f32) -> f32 attributes {noinline = false} {
    %0 = arith.maxnumf %a, %b : f32
    tt.return %0 : f32
  ^bb1:  // no predecessors
    %1 = ub.poison : f32
    tt.return %1 : f32
  }
  tt.func private @torch._inductor.runtime.triton_helpers.maximum__fp32S512S_fp32S512S__(%a: tensor<512xf32>, %b: tensor<512xf32>) -> tensor<512xf32> attributes {noinline = false} {
    %mask = arith.cmpf ogt, %a, %b : tensor<512xf32>
    %0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__fp32S512S__(%a) : (tensor<512xf32>) -> i1
    %1 = scf.if %0 -> (tensor<512xi1>) {
      %mask_0 = arith.cmpf une, %a, %a : tensor<512xf32>
      %mask_1 = arith.ori %mask, %mask_0 : tensor<512xi1>
      scf.yield %mask_1 : tensor<512xi1>
    } else {
      scf.yield %mask : tensor<512xi1>
    }
    %2 = arith.select %1, %a, %b : tensor<512xi1>, tensor<512xf32>
    tt.return %2 : tensor<512xf32>
  ^bb1:  // no predecessors
    %3 = ub.poison : tensor<512xf32>
    tt.return %3 : tensor<512xf32>
  }
  tt.func private @torch._inductor.runtime.triton_helpers.is_floating__fp32S512S__(%x: tensor<512xf32>) -> i1 attributes {noinline = false} {
    %0 = tt.call @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32S512S__(%x) : (tensor<512xf32>) -> tensor<512xf32>
    %true = arith.constant true
    tt.return %true : i1
  ^bb1:  // no predecessors
    %1 = ub.poison : i1
    tt.return %1 : i1
  }
  tt.func private @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32S512S__(%x: tensor<512xf32>) -> tensor<512xf32> attributes {noinline = false} {
    %0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() : () -> tensor<1xi1>
    %1 = arith.uitofp %0 : tensor<1xi1> to tensor<1xf32>
    %2 = tt.broadcast %1 : tensor<1xf32> -> tensor<512xf32>
    %3 = arith.addf %x, %2 : tensor<512xf32>
    tt.return %3 : tensor<512xf32>
  ^bb1:  // no predecessors
    %4 = ub.poison : tensor<512xf32>
    tt.return %4 : tensor<512xf32>
  }
  tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() -> tensor<1xi1> attributes {noinline = false} {
    %false = arith.constant false
    %cst = arith.constant dense<false> : tensor<1xi1>
    tt.return %cst : tensor<1xi1>
  ^bb1:  // no predecessors
    %0 = ub.poison : tensor<1xi1>
    tt.return %0 : tensor<1xi1>
  }
  tt.func private @"triton.language.standard.sum__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<512x64xf32>) -> tensor<512xf32> attributes {noinline = false} {
    %0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({
    ^bb0(%arg1: f32, %arg2: f32):
      %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
      tt.reduce.return %2 : f32
    }) : (tensor<512x64xf32>) -> tensor<512xf32>
    tt.return %0 : tensor<512xf32>
  ^bb1:  // no predecessors
    %1 = ub.poison : tensor<512xf32>
    tt.return %1 : tensor<512xf32>
  }
  tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32, %b: f32) -> f32 attributes {noinline = false} {
    %0 = arith.addf %a, %b : f32
    tt.return %0 : f32
  ^bb1:  // no predecessors
    %1 = ub.poison : f32
    tt.return %1 : f32
  }
}
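
In the IR above, the candidates for this canonicalization are the three loads through %q_29, %k_49, and %v_82, each followed by a tt.reshape that squeezes the leading unit dimension, and, symmetrically, the store through %13, where a tt.reshape expands %v_16 to tensor<1x512x64xf16> right before the tt.store. In all four cases the leading dimension has block extent 1 and is not boundary-checked (boundaryCheck = array<i32: 1, 2>), so its offset %offset_5 times its stride (65536) can be folded into the base pointer as sketched above.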
