Description
Helion currently produces the following kernel containing 3-dim block pointers for its FlashAttention example. The XPU Triton compiler can optimize 2-dim block pointer loads by lowering them to efficient 2D block operations (on Xe2 targets). Because the 3-dim block pointers have an outermost dim extent equal to one, the Triton compiler can "canonicalize" them to 2-dim block pointers by "fusing" the reshape operation with the associated load operation.
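For illustration, here is a minimal before/after sketch of the intended canonicalization, using the shapes of the q load from the kernel below. The "after" form is an assumption about what the canonicalized IR could look like (in particular, folding the outer batch offset into the base pointer via tt.addptr); constant and value names are illustrative and this is not output the compiler produces today.

// Before (as generated for the q load below): 3-dim block pointer whose outermost block extent is 1, loaded and then reshaped to 2-dim.
%q_ptr_3d = tt.make_tensor_ptr %q_view, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64, %c1_i64], [%offset_5, %offset_1_13, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%q_3d = tt.load %q_ptr_3d {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
%q_2d = tt.reshape %q_3d : tensor<1x512x64xf16> -> tensor<512x64xf16>

// After (hypothetical canonicalized form): the unit outer dim is dropped, the reshape is fused into the load, and the outer offset is assumed to be folded into the base pointer.
%q_batch_off = arith.muli %offset_5, %c65536_i32 : i32
%q_base = tt.addptr %q_view, %q_batch_off : !tt.ptr<f16>, i32
%q_ptr_2d = tt.make_tensor_ptr %q_base, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%offset_1_13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<512x64xf16>>
%q_2d_fused = tt.load %q_ptr_2d {boundaryCheck = array<i32: 0, 1>, padding = 1 : i32} : !tt.ptr<tensor<512x64xf16>>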
Motivating example:
module {
tt.func public @_helion_attention(%q_view: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %k_view: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %v_view: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %out: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%total_pids = tt.call @"triton.language.standard.cdiv____(0,)cconstexpr_1024__(1,)cconstexpr_512_"() : () -> i32
%total_pids_0 = arith.constant 512 : i32
%total_pids_1 = arith.constant 512 : i32
%total_pids_2 = arith.muli %total_pids_1, %total_pids : i32
%0 = tt.get_program_id x : i32
%c20_i32 = arith.constant 20 : i32
%1 = arith.bitcast %0 : i32 to i32
%2 = arith.bitcast %total_pids_2 : i32 to i32
%3 = arith.bitcast %c20_i32 : i32 to i32
%4 = ub.poison : i32
scf.for %virtual_pid = %1 to %2 step %3 : i32 {
%num_pid_m = arith.constant 512 : i32
%num_pid_n = tt.call @"triton.language.standard.cdiv____(0,)cconstexpr_1024__(1,)cconstexpr_512_"() : () -> i32
%num_pid_in_group = arith.constant 32 : i32
%num_pid_in_group_3 = arith.constant 32 : i32
%num_pid_in_group_4 = arith.muli %num_pid_in_group_3, %num_pid_n : i32
%group_id = arith.divsi %virtual_pid, %num_pid_in_group_4 : i32
%first_pid_m = arith.constant 32 : i32
%first_pid_m_5 = arith.constant 32 : i32
%first_pid_m_6 = arith.muli %group_id, %first_pid_m_5 : i32
%group_size_m = arith.subi %num_pid_m, %first_pid_m_6 : i32
%group_size_m_7 = arith.constant 32 : i32
%group_size_m_8 = arith.minsi %group_size_m, %group_size_m_7 : i32
%pid_0 = arith.remsi %virtual_pid, %num_pid_in_group_4 : i32
%pid_0_9 = arith.remsi %pid_0, %group_size_m_8 : i32
%pid_0_10 = arith.addi %first_pid_m_6, %pid_0_9 : i32
%pid_1 = arith.remsi %virtual_pid, %num_pid_in_group_4 : i32
%pid_1_11 = arith.divsi %pid_1, %group_size_m_8 : i32
%offset_1 = arith.constant 512 : i32
%offset_1_12 = arith.constant 512 : i32
%offset_1_13 = arith.muli %pid_1_11, %offset_1_12 : i32
%indices_1 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
%indices_1_14 = tt.splat %offset_1_13 : i32 -> tensor<512xi32>
%indices_1_15 = arith.addi %indices_1_14, %indices_1 : tensor<512xi32>
%mask_1 = arith.constant 1024 : i32
%mask_1_16 = arith.constant dense<1024> : tensor<512xi32>
%mask_1_17 = arith.cmpi slt, %indices_1_15, %mask_1_16 : tensor<512xi32>
%tile_end = arith.constant 1 : i32
%tile_end_18 = arith.constant 1 : i32
%tile_end_19 = arith.addi %pid_0_10, %tile_end_18 : i32
%c1_i32 = arith.constant 1 : i32
%5 = arith.bitcast %pid_0_10 : i32 to i32
%6 = arith.bitcast %tile_end_19 : i32 to i32
%7 = arith.bitcast %c1_i32 : i32 to i32
%8 = ub.poison : i32
scf.for %offset_5 = %5 to %6 step %7 : i32 {
%m_i = arith.constant 0xFF800000 : f32
%m_i_20 = arith.constant dense<0xFF800000> : tensor<512xf32>
%l_i = arith.constant 1.000000e+00 : f32
%l_i_21 = arith.constant dense<1.000000e+00> : tensor<512xf32>
%acc = arith.constant 0.000000e+00 : f32
%acc_22 = arith.constant dense<0.000000e+00> : tensor<512x64xf32>
%q = arith.constant 512 : i64
%q_23 = arith.constant 1024 : i64
%q_24 = arith.constant 64 : i64
%q_25 = arith.constant 65536 : i64
%q_26 = arith.constant 64 : i64
%q_27 = arith.constant 1 : i64
%q_28 = arith.constant 0 : i32
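// q: 3-dim block pointer with outermost block extent 1; the load below is followed by a reshape from 1x512x64 down to 512x64.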
%q_29 = tt.make_tensor_ptr %q_view, [%q, %q_23, %q_24], [%q_25, %q_26, %q_27], [%offset_5, %offset_1_13, %q_28] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%q_30 = tt.load %q_29 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x512x64xf16>>
%q_31 = tt.reshape %q_30 : tensor<1x512x64xf16> -> tensor<512x64xf16>
%c0_i32 = arith.constant 0 : i32
%c1024_i32 = arith.constant 1024 : i32
%c64_i32 = arith.constant 64 : i32
%9 = arith.bitcast %c0_i32 : i32 to i32
%10 = arith.bitcast %c1024_i32 : i32 to i32
%11 = arith.bitcast %c64_i32 : i32 to i32
%12 = ub.poison : i32
%acc_32:3 = scf.for %offset_3 = %9 to %10 step %11 iter_args(%m_i_36 = %m_i_20, %l_i_37 = %l_i_21, %acc_38 = %acc_22) -> (tensor<512xf32>, tensor<512xf32>, tensor<512x64xf32>) : i32 {
%indices_3 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%indices_3_39 = tt.splat %offset_3 : i32 -> tensor<64xi32>
%indices_3_40 = arith.addi %indices_3_39, %indices_3 : tensor<64xi32>
%mask_4 = arith.constant 1024 : i32
%mask_4_41 = arith.constant dense<1024> : tensor<64xi32>
%mask_4_42 = arith.cmpi slt, %indices_3_40, %mask_4_41 : tensor<64xi32>
%k = arith.constant 512 : i64
%k_43 = arith.constant 64 : i64
%k_44 = arith.constant 1024 : i64
%k_45 = arith.constant 65536 : i64
%k_46 = arith.constant 1 : i64
%k_47 = arith.constant 64 : i64
%k_48 = arith.constant 0 : i32
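// k: same pattern; a 1x64x64 block pointer is loaded and then reshaped to 64x64.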
%k_49 = tt.make_tensor_ptr %k_view, [%k, %k_43, %k_44], [%k_45, %k_46, %k_47], [%offset_5, %k_48, %offset_3] {order = array<i32: 2, 0, 1>} : <tensor<1x64x64xf16>>
%k_50 = tt.load %k_49 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x64x64xf16>>
%k_51 = tt.reshape %k_50 : tensor<1x64x64xf16> -> tensor<64x64xf16>
%qk = arith.constant 0.000000e+00 : f32
%qk_52 = arith.constant dense<0.000000e+00> : tensor<512x64xf32>
%qk_53 = tt.dot %q_31, %k_51, %qk_52, inputPrecision = tf32 : tensor<512x64xf16> * tensor<64x64xf16> -> tensor<512x64xf32>
%_mask_to_2 = tt.expand_dims %mask_1_17 {axis = 1 : i32} : tensor<512xi1> -> tensor<512x1xi1>
%_mask_to_2_54 = tt.expand_dims %mask_4_42 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
%_mask_to_2_55 = tt.broadcast %_mask_to_2 : tensor<512x1xi1> -> tensor<512x64xi1>
%_mask_to_2_56 = tt.broadcast %_mask_to_2_54 : tensor<1x64xi1> -> tensor<512x64xi1>
%_mask_to_2_57 = arith.andi %_mask_to_2_55, %_mask_to_2_56 : tensor<512x64xi1>
%_mask_to_2_58 = arith.constant 0xFC00 : f16
%_mask_to_2_59 = arith.extf %_mask_to_2_58 : f16 to f32
%_mask_to_2_60 = tt.splat %_mask_to_2_59 : f32 -> tensor<512x64xf32>
%_mask_to_2_61 = arith.select %_mask_to_2_57, %qk_53, %_mask_to_2_60 : tensor<512x64xi1>, tensor<512x64xf32>
%amax = tt.call @"triton.language.standard.max__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%_mask_to_2_61) : (tensor<512x64xf32>) -> tensor<512xf32>
%amax_62 = arith.truncf %amax : tensor<512xf32> to tensor<512xf16>
%v_0 = arith.constant 0.180336878 : f32
%v_1 = arith.extf %amax_62 : tensor<512xf16> to tensor<512xf32>
%v_1_63 = arith.constant dense<0.180336878> : tensor<512xf32>
%v_1_64 = arith.mulf %v_1, %v_1_63 : tensor<512xf32>
%v_3 = tt.call @torch._inductor.runtime.triton_helpers.maximum__fp32S512S_fp32S512S__(%m_i_36, %v_1_64) : (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
%v_4 = arith.constant 0.180336878 : f32
%v_5 = arith.constant dense<0.180336878> : tensor<512x64xf32>
%v_5_65 = arith.mulf %qk_53, %v_5 : tensor<512x64xf32>
%subscript = tt.expand_dims %v_3 {axis = 1 : i32} : tensor<512xf32> -> tensor<512x1xf32>
%v_7 = tt.broadcast %subscript : tensor<512x1xf32> -> tensor<512x64xf32>
%v_7_66 = arith.subf %v_5_65, %v_7 : tensor<512x64xf32>
%v_8 = tt.extern_elementwise %v_7_66 {libname = "", libpath = "", pure = true, symbol = "__imf_exp2f"} : (tensor<512x64xf32>) -> tensor<512x64xf32>
%_mask_to_3 = tt.expand_dims %mask_1_17 {axis = 1 : i32} : tensor<512xi1> -> tensor<512x1xi1>
%_mask_to_3_67 = tt.expand_dims %mask_4_42 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
%_mask_to_3_68 = tt.broadcast %_mask_to_3 : tensor<512x1xi1> -> tensor<512x64xi1>
%_mask_to_3_69 = tt.broadcast %_mask_to_3_67 : tensor<1x64xi1> -> tensor<512x64xi1>
%_mask_to_3_70 = arith.andi %_mask_to_3_68, %_mask_to_3_69 : tensor<512x64xi1>
%_mask_to_3_71 = arith.constant 0.000000e+00 : f32
%_mask_to_3_72 = arith.constant dense<0.000000e+00> : tensor<512x64xf32>
%_mask_to_3_73 = arith.select %_mask_to_3_70, %v_8, %_mask_to_3_72 : tensor<512x64xi1>, tensor<512x64xf32>
%l_ij = tt.call @"triton.language.standard.sum__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%_mask_to_3_73) : (tensor<512x64xf32>) -> tensor<512xf32>
%v_9 = arith.subf %m_i_36, %v_3 : tensor<512xf32>
%v_10 = tt.extern_elementwise %v_9 {libname = "", libpath = "", pure = true, symbol = "__imf_exp2f"} : (tensor<512xf32>) -> tensor<512xf32>
%v_11 = arith.mulf %l_i_37, %v_10 : tensor<512xf32>
%l_i_74 = arith.addf %v_11, %l_ij : tensor<512xf32>
%subscript_1 = tt.expand_dims %v_10 {axis = 1 : i32} : tensor<512xf32> -> tensor<512x1xf32>
%v_13 = tt.broadcast %subscript_1 : tensor<512x1xf32> -> tensor<512x64xf32>
%v_13_75 = arith.mulf %acc_38, %v_13 : tensor<512x64xf32>
%v = arith.constant 512 : i64
%v_76 = arith.constant 1024 : i64
%v_77 = arith.constant 64 : i64
%v_78 = arith.constant 65536 : i64
%v_79 = arith.constant 64 : i64
%v_80 = arith.constant 1 : i64
%v_81 = arith.constant 0 : i32
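// v: same pattern; a 1x64x64 block pointer is loaded and then reshaped to 64x64.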
%v_82 = tt.make_tensor_ptr %v_view, [%v, %v_76, %v_77], [%v_78, %v_79, %v_80], [%offset_5, %offset_3, %v_81] {order = array<i32: 2, 1, 0>} : <tensor<1x64x64xf16>>
%v_83 = tt.load %v_82 {boundaryCheck = array<i32: 1, 2>, padding = 1 : i32} : !tt.ptr<tensor<1x64x64xf16>>
%v_84 = tt.reshape %v_83 : tensor<1x64x64xf16> -> tensor<64x64xf16>
%v_14 = arith.truncf %v_8 : tensor<512x64xf32> to tensor<512x64xf16>
%_mask_to_4 = tt.expand_dims %mask_1_17 {axis = 1 : i32} : tensor<512xi1> -> tensor<512x1xi1>
%_mask_to_4_85 = tt.expand_dims %mask_4_42 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
%_mask_to_4_86 = tt.broadcast %_mask_to_4 : tensor<512x1xi1> -> tensor<512x64xi1>
%_mask_to_4_87 = tt.broadcast %_mask_to_4_85 : tensor<1x64xi1> -> tensor<512x64xi1>
%_mask_to_4_88 = arith.andi %_mask_to_4_86, %_mask_to_4_87 : tensor<512x64xi1>
%_mask_to_4_89 = arith.constant 0.000000e+00 : f16
%_mask_to_4_90 = arith.constant dense<0.000000e+00> : tensor<512x64xf16>
%_mask_to_4_91 = arith.select %_mask_to_4_88, %v_14, %_mask_to_4_90 : tensor<512x64xi1>, tensor<512x64xf16>
%acc_92 = arith.constant 0.000000e+00 : f32
%acc_93 = tt.dot %_mask_to_4_91, %v_84, %v_13_75, inputPrecision = tf32 : tensor<512x64xf16> * tensor<64x64xf16> -> tensor<512x64xf32>
scf.yield %v_3, %l_i_74, %acc_93 : tensor<512xf32>, tensor<512xf32>, tensor<512x64xf32>
} {tt.disallow_acc_multi_buffer, tt.loop_unroll_factor = 1 : i32}
%subscript_2 = tt.expand_dims %acc_32#1 {axis = 1 : i32} : tensor<512xf32> -> tensor<512x1xf32>
%v_15 = tt.broadcast %subscript_2 : tensor<512x1xf32> -> tensor<512x64xf32>
%v_15_33 = arith.divf %acc_32#2, %v_15 : tensor<512x64xf32>
%v_16 = arith.truncf %v_15_33 : tensor<512x64xf32> to tensor<512x64xf16>
%c512_i64 = arith.constant 512 : i64
%c1024_i64 = arith.constant 1024 : i64
%c64_i64 = arith.constant 64 : i64
%c65536_i64 = arith.constant 65536 : i64
%c64_i64_34 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32_35 = arith.constant 0 : i32
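// out: the store side of the same pattern; the 512x64 result tile is reshaped to 1x512x64 and stored through a 3-dim block pointer.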
%13 = tt.make_tensor_ptr %out, [%c512_i64, %c1024_i64, %c64_i64], [%c65536_i64, %c64_i64_34, %c1_i64], [%offset_5, %offset_1_13, %c0_i32_35] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%14 = tt.reshape %v_16 : tensor<512x64xf16> -> tensor<1x512x64xf16>
tt.store %13, %14 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x512x64xf16>>
} {tt.disallow_acc_multi_buffer, tt.loop_unroll_factor = 3 : i32, tt.num_stages = 3 : i32}
} {tt.loop_unroll_factor = 1 : i32, tt.num_stages = 4 : i32}
tt.return
}
tt.func private @"triton.language.standard.cdiv____(0,)cconstexpr_1024__(1,)cconstexpr_512_"() -> i32 attributes {noinline = false} {
%c2_i32 = arith.constant 2 : i32
tt.return %c2_i32 : i32
^bb1: // no predecessors
%0 = ub.poison : i32
tt.return %0 : i32
}
tt.func private @"triton.language.standard.max__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%input: tensor<512x64xf32>) -> tensor<512xf32> attributes {noinline = false} {
%0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
%2 = tt.call @triton.language.standard._elementwise_max__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
tt.reduce.return %2 : f32
}) : (tensor<512x64xf32>) -> tensor<512xf32>
tt.return %0 : tensor<512xf32>
^bb1: // no predecessors
%1 = ub.poison : tensor<512xf32>
tt.return %1 : tensor<512xf32>
}
tt.func private @triton.language.standard._elementwise_max__fp32_fp32__(%a: f32, %b: f32) -> f32 attributes {noinline = false} {
%0 = arith.maxnumf %a, %b : f32
tt.return %0 : f32
^bb1: // no predecessors
%1 = ub.poison : f32
tt.return %1 : f32
}
tt.func private @torch._inductor.runtime.triton_helpers.maximum__fp32S512S_fp32S512S__(%a: tensor<512xf32>, %b: tensor<512xf32>) -> tensor<512xf32> attributes {noinline = false} {
%mask = arith.cmpf ogt, %a, %b : tensor<512xf32>
%0 = tt.call @torch._inductor.runtime.triton_helpers.is_floating__fp32S512S__(%a) : (tensor<512xf32>) -> i1
%1 = scf.if %0 -> (tensor<512xi1>) {
%mask_0 = arith.cmpf une, %a, %a : tensor<512xf32>
%mask_1 = arith.ori %mask, %mask_0 : tensor<512xi1>
scf.yield %mask_1 : tensor<512xi1>
} else {
scf.yield %mask : tensor<512xi1>
}
%2 = arith.select %1, %a, %b : tensor<512xi1>, tensor<512xf32>
tt.return %2 : tensor<512xf32>
^bb1: // no predecessors
%3 = ub.poison : tensor<512xf32>
tt.return %3 : tensor<512xf32>
}
tt.func private @torch._inductor.runtime.triton_helpers.is_floating__fp32S512S__(%x: tensor<512xf32>) -> i1 attributes {noinline = false} {
%0 = tt.call @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32S512S__(%x) : (tensor<512xf32>) -> tensor<512xf32>
%true = arith.constant true
tt.return %true : i1
^bb1: // no predecessors
%1 = ub.poison : i1
tt.return %1 : i1
}
tt.func private @torch._inductor.runtime.triton_helpers.promote_to_tensor__fp32S512S__(%x: tensor<512xf32>) -> tensor<512xf32> attributes {noinline = false} {
%0 = tt.call @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() : () -> tensor<1xi1>
%1 = arith.uitofp %0 : tensor<1xi1> to tensor<1xf32>
%2 = tt.broadcast %1 : tensor<1xf32> -> tensor<512xf32>
%3 = arith.addf %x, %2 : tensor<512xf32>
tt.return %3 : tensor<512xf32>
^bb1: // no predecessors
%4 = ub.poison : tensor<512xf32>
tt.return %4 : tensor<512xf32>
}
tt.func private @"triton.language.standard.zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int1_"() -> tensor<1xi1> attributes {noinline = false} {
%false = arith.constant false
%cst = arith.constant dense<false> : tensor<1xi1>
tt.return %cst : tensor<1xi1>
^bb1: // no predecessors
%0 = ub.poison : tensor<1xi1>
tt.return %0 : tensor<1xi1>
}
tt.func private @"triton.language.standard.sum__fp32S512_64S__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%input: tensor<512x64xf32>) -> tensor<512xf32> attributes {noinline = false} {
%0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
%2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32
tt.reduce.return %2 : f32
}) : (tensor<512x64xf32>) -> tensor<512xf32>
tt.return %0 : tensor<512xf32>
^bb1: // no predecessors
%1 = ub.poison : tensor<512xf32>
tt.return %1 : tensor<512xf32>
}
tt.func private @triton.language.standard._sum_combine__fp32_fp32__(%a: f32, %b: f32) -> f32 attributes {noinline = false} {
%0 = arith.addf %a, %b : f32
tt.return %0 : f32
^bb1: // no predecessors
%1 = ub.poison : f32
tt.return %1 : f32
}
}