diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py index 1f9ecf8bd8..c865bd257d 100644 --- a/python/test/unit/intel/test_block_load.py +++ b/python/test/unit/intel/test_block_load.py @@ -1,3 +1,6 @@ +import itertools + +import numpy as np import pytest import torch import pathlib @@ -7,6 +10,27 @@ @pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32]]) +class DpasLayout: + + def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta, + rep_cluster): + self.repeatCount = repeatCount + self.systolic_depth = systolic_depth + self.execution_size = execution_size + self.ops_per_chan = ops_per_chan + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.rep_cluster = rep_cluster + + def __str__(self): + return f"#triton_intel_gpu.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>" + + +def warps_per_cta(layout): + return layout.warps_per_cta + + +@pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [64, 64], [64, 32], [32, 32]]) @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"]) @pytest.mark.parametrize("transpose", [True, False]) @pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend") @@ -15,8 +39,6 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pathlib.Path): # modify the layouts to ensure the correct OCL/SPIRV intrinsic is called for each datatype if dtype_str == "int8": - if M == 128 and N == 16 or N == 8: - pytest.skip("TODO: test fails verification") A_width = 2 B_width = 4 layouts = "#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2], A = [8, 32], B = [32, 32], C = [8, 32]}>" @@ -25,8 +47,6 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa B_width = 1 layouts = "#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>" else: - if M == 128 and N == 8: - pytest.skip("TODO: test fails verification") A_width = 1 B_width = 2 layouts = "#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>" @@ -79,3 +99,107 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa kernel[(1, 1, 1)](a, x, b, y) #import pdb; pdb.set_trace() assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y) + + +layouts = [ + # Layout for Xe2 and Xe2+ + DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16, + warps_per_cta=[1, 4], rep_cluster=[1, 2]), + DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16, + warps_per_cta=[8, 4], rep_cluster=[4, 2]), + DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=16, + warps_per_cta=[8, 4], rep_cluster=[1, 1]), + DpasLayout(repeatCount=8, systolic_depth=8, 
execution_size=16, ops_per_chan=4, threads_per_warp=32, + warps_per_cta=[1, 4], rep_cluster=[1, 2]), + DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32, + warps_per_cta=[8, 4], rep_cluster=[4, 2]), + DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=32, + warps_per_cta=[8, 4], rep_cluster=[1, 1]), + # Layout for Xe +] + + +@pytest.mark.parametrize("M, N", [[M, N] for M, N in itertools.product([32, 64, 128, 256], [32, 64, 128, 256])]) +@pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"]) +@pytest.mark.parametrize("layout", layouts) +@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend") +def test_tensor_pointer_block_load(M, N, dtype_str, layout, device, tmp_path: pathlib.Path): + + warps = warps_per_cta(layout) + num_warps = int(np.prod(warps)) + threads_per_warp = layout.threads_per_warp + ops_per_chan = layout.ops_per_chan + A_width = 1 if ops_per_chan == 1 else ops_per_chan // 2 + B_width = ops_per_chan + + ty = {"float32": "f32", "float16": "f16", "int8": "i8"}[dtype_str] + + support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io'] + + ir = f""" + #mma = {layout} + #dot_a = #ttg.dot_op<{{opIdx = 0, parent = #mma, kWidth = {A_width}}}> + #dot_b = #ttg.dot_op<{{opIdx = 1, parent = #mma, kWidth = {B_width}}}> + module attributes {{triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, {"triton_intel_gpu.support_sg_2d_block," if support_block_io else ""} triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{ + tt.func public @tensor_pointer_block_load(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg6: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg7: i32 {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + // A matrix + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> + %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_a}}>> -> tensor<{M}x1xi32, #dot_a> + %3 = tt.splat %arg6 : i32 -> tensor<{M}x1xi32, #dot_a> + %4 = arith.muli %2, %3 : tensor<{M}x1xi32, #dot_a> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_a}}>> -> tensor<1x{N}xi32, #dot_a> + %7 = tt.broadcast %4 : tensor<{M}x1xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a> + %8 = tt.broadcast %6 : tensor<1x{N}xi32, #dot_a> -> tensor<{M}x{N}xi32, #dot_a> + %9 = arith.addi %7, %8 : tensor<{M}x{N}xi32, #dot_a> + + %10 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a> + %11 = tt.addptr %10, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a> + %12 = tt.load %11 {{triton_intel_gpu.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a> + %13 = tt.splat %arg1 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a> + %14 = tt.addptr %13, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a>, tensor<{M}x{N}xi32, #dot_a> + tt.store %14, %12 {{boundaryCheck = array}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_a> + 
+ // B matrix + %22 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> + %44 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> + %46 = tt.expand_dims %44 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dot_b}}>> -> tensor<{M}x1xi32, #dot_b> + %48 = tt.splat %arg7 : i32 -> tensor<{M}x1xi32, #dot_b> + %49 = arith.muli %46, %48 : tensor<{M}x1xi32, #dot_b> + %50 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dot_b}}>> -> tensor<1x{N}xi32, #dot_b> + %51 = tt.broadcast %49 : tensor<{M}x1xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b> + %52 = tt.broadcast %50 : tensor<1x{N}xi32, #dot_b> -> tensor<{M}x{N}xi32, #dot_b> + %53 = arith.addi %51, %52 : tensor<{M}x{N}xi32, #dot_b> + + %54 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b> + %55 = tt.addptr %54, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b> + %56 = tt.load %55 {{triton_intel_gpu.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b> + %57 = tt.splat %arg3 : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b> + %58 = tt.addptr %57, %53 : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b>, tensor<{M}x{N}xi32, #dot_b> + tt.store %58, %56 {{boundaryCheck = array}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #dot_b> + + tt.return + }} + }} + """ + + torch_dtype = getattr(torch, dtype_str) + if torch_dtype.is_floating_point: + a = torch.randn((M, N), dtype=torch_dtype, device=device) + else: + a = torch.randint(low=-127, high=128, size=(M, N), dtype=torch_dtype, device=device) + + x = torch.empty_like(a) + y = torch.empty_like(a) + + temp_file = tmp_path / "test_tensor_pointer_block_load.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + if support_block_io: + # assert '2d block io' in kernel.asm['llir'] + pass + + kernel[(1, 1, 1)](a, x, a.stride(0), a, y, a.stride(0)) + + assert torch.equal(a, x) and torch.equal(a, y) diff --git a/test/Conversion/intel/tritongpu_to_gen.mlir b/test/Conversion/intel/tritongpu_to_gen.mlir index 466ff04228..3aaa82d0c9 100644 --- a/test/Conversion/intel/tritongpu_to_gen.mlir +++ b/test/Conversion/intel/tritongpu_to_gen.mlir @@ -676,6 +676,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: [[CST_0:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement [[BCAST0]], [[VEC1]][[[CST_0]] : i32] : vector<1xf32> // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to i32 + // CHECK-NEXT: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1 // CHECK-NEXT: [[AND1:%.*]] = llvm.and {{.*}}, [[ARG2_0]] : i1 // CHECK-NEXT: [[VEC2:%.*]] = llvm.mlir.undef : vector<1xi32> // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 @@ -1059,17 +1060,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_cas_f32_scalar_no_store tt.func @atomic_cas_f32_scalar_no_store(%ptr : !tt.ptr, %cmp : f32, %val : f32) { - // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1 - // CHECK: [[CMP0:%.*]] = llvm.icmp "eq" - // CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]] - // CHECK: [[CMP:%.*]] = llvm.icmp "eq" - // CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]] - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[ZERO0:%.*]] = 
llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK: [[MASKLANE:%.*]] = llvm.and + // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]] + // CHECK: [[MASKWARP:%.*]] = llvm.and + // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]] + // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]] + // CHECK: llvm.mlir.constant(-1 : i32) : i32 + // CHECK: [[MASKBLOCK:%.*]] = llvm.and + // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]] + // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]] + // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32 // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> () - // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32) + // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32) // CHECK-NEXT: ^bb1: // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32 // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32 @@ -1089,13 +1096,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp // CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return} // CHECK-LABEL: atomic_cas_f32_scalar tt.func @atomic_cas_f32_scalar(%ptr : !tt.ptr, %cmp : f32, %val : f32) { - // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1 - // CHECK: [[CMP0:%.*]] = llvm.icmp "eq" - // CHECK: [[MASK0:%.*]] = llvm.and [[TRUE]], [[CMP0]] - // CHECK: [[CMP:%.*]] = llvm.icmp "eq" - // CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]] - // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32) + // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[TRUE:%.*]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK: [[MASKLANE:%.*]] = llvm.and + // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]] + // CHECK: [[MASKWARP:%.*]] = llvm.and + // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]] + // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]] + // CHECK: llvm.mlir.constant(-1 : i32) : i32 + // CHECK: [[MASKBLOCK:%.*]] = llvm.and + // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]] + // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]] + // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO1]] : i32) // CHECK-NEXT: ^bb1: // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32 // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast %arg2 : f32 to i32 @@ -1131,14 +1144,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: [[EV1_ARG2:%.*]] = llvm.extractvalue %arg2[1] : !llvm.struct<(f32, f32)> // CHECK: [[EV0_ARG0:%.*]] = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr<1>, ptr<1>)> // CHECK-NEXT: [[EV1_ARG0:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)> - // CHECK: llvm.mlir.constant(true) : i1 - // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1 - // CHECK: [[PRED0:%.*]] = llvm.and [[CST_TRUE]], {{.*}} : i1 - // CHECK-NEXT: [[UNDEF1:%.*]] = 
llvm.mlir.undef : vector<1xf32> + // CHECK: [[EV0_ARG1:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(i1, i1)> + // CHECK-NEXT: [[EV1_ARG1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(i1, i1)> + // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32> // CHECK: [[IE1:%.*]] = llvm.insertelement [[EV0_ARG2]], [[UNDEF1]][{{.*}} : i64] : vector<1xf32> - // CHECK-NEXT: [[PRED1:%.*]] = llvm.and [[PRED0]], {{.*}} : i1 // CHECK-NEXT: [[ZERO1:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 - // CHECK: llvm.cond_br [[PRED1]], ^bb1, ^bb2([[ZERO1]] : f32) + // CHECK: llvm.cond_br [[EV0_ARG1]], ^bb1, ^bb2([[ZERO1]] : f32) // CHECK-NEXT: ^bb1: // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32 // CHECK-NEXT: [[RMW_RES1:%.*]] = llvm.atomicrmw fadd [[EV0_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32 @@ -1147,13 +1158,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI1]] : f32 to f32 // CHECK-NEXT: [[UNDEF2:%.*]] = llvm.mlir.undef : vector<1xf32> // CHECK: [[IE2:%.*]] = llvm.insertelement [[EV1_ARG2]], [[UNDEF2]][{{.*}} : i64] : vector<1xf32> - // CHECK-NEXT: [[PRED2:%.*]] = llvm.and [[PRED0]], {{.*}} : i1 // CHECK-NEXT: [[ZERO2:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32 // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> () - // CHECK-NEXT: llvm.cond_br [[PRED2]], ^bb3, ^bb4([[ZERO2]] : f32) + // CHECK-NEXT: llvm.cond_br [[EV1_ARG1]], ^bb3, ^bb4([[ZERO2]] : f32) // CHECK-NEXT: ^bb3: // CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE2]] : vector<1xf32> to f32 // CHECK-NEXT: [[RMW_RES2:%.*]] = llvm.atomicrmw fadd [[EV1_ARG0]], [[BCAST2]] monotonic : !llvm.ptr<1>, f32 @@ -1169,14 +1179,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32_scalar_no_store tt.func @atomic_add_f32_scalar_no_store(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { - // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1 - // CHECK: [[CMP:%.*]] = llvm.icmp "eq" - // CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1 - // CHECK: [[CMP1:%.*]] = llvm.icmp "eq" - // CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1 - // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32> + // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[MASKLANE:%.*]] = llvm.and + // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]] + // CHECK: [[MASKWARP:%.*]] = llvm.and + // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]] + // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]] + // CHECK: llvm.mlir.constant(-1 : i32) : i32 + // CHECK: [[MASKBLOCK:%.*]] = llvm.and + // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]] + // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]] + // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32> // CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32> - // CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1 + // CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1 
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32 @@ -1200,14 +1215,19 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp // CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return} // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { - // CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1 - // CHECK: [[CMP:%.*]] = llvm.icmp "eq" - // CHECK-NEXT: [[AND:%.*]] = llvm.and [[CST_TRUE]], [[CMP]] : i1 - // CHECK: [[CMP1:%.*]] = llvm.icmp "eq" - // CHECK-NEXT: [[AND1:%.*]] = llvm.and [[AND]], [[CMP1]] : i1 - // CHECK: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32> + // CHECK: [[ZERO0:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: [[MASKLANE:%.*]] = llvm.and + // CHECK-NEXT: [[CMPLANE:%.*]] = llvm.icmp "eq" [[MASKLANE]], [[ZERO0]] + // CHECK: [[MASKWARP:%.*]] = llvm.and + // CHECK-NEXT: [[CMPWARP:%.*]] = llvm.icmp "eq" [[MASKWARP]], [[ZERO0]] + // CHECK-NEXT: [[MASKWARPANDLANE:%.*]] = llvm.and [[CMPLANE]], [[CMPWARP]] + // CHECK: llvm.mlir.constant(-1 : i32) : i32 + // CHECK: [[MASKBLOCK:%.*]] = llvm.and + // CHECK-NEXT: [[CMPBLOCK:%.*]] = llvm.icmp "eq" [[MASKBLOCK]], [[ZERO0]] + // CHECK-NEXT: [[MASK:%.*]] = llvm.and [[MASKWARPANDLANE]], [[CMPBLOCK]] + // CHECK-NEXT: [[UNDEF1:%.*]] = llvm.mlir.undef : vector<1xf32> // CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32> - // CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1 + // CHECK: [[PRED:%.*]] = llvm.and %arg1, [[MASK]] : i1 // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK-NEXT: llvm.cond_br [[PRED]], ^bb1, ^bb2([[ZERO]] : f32) // CHECK-NEXT: ^bb1: @@ -1295,22 +1315,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: [[ARG0_1:%.*]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<1>, ptr<1>)> // CHECK-NEXT: [[ARG1_0:%.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(f32, f32)> // CHECK-NEXT: [[ARG1_1:%.*]] = llvm.extractvalue %arg1[1] : !llvm.struct<(f32, f32)> - // CHECK: llvm.mlir.constant(true) : i1 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO]]) {{.*}} : (i32) -> i64 - // CHECK: [[TRUE1:%.*]] = llvm.mlir.constant(true) : i1 - // CHECK: [[TRUE2:%.*]] = llvm.mlir.constant(true) : i1 - // CHECK: [[PRED:%.*]] = llvm.and [[TRUE1]], [[TRUE2]] : i1 + // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: llvm.call spir_funccc @_Z12get_local_idj([[ZERO1]]) {{.*}} : (i32) -> i64 + // CHECK: [[PRED:%.*]] = llvm.mlir.constant(true) : i1 // CHECK: llvm.cond_br [[PRED]], ^bb1, ^bb2 // CHECK-NEXT: ^bb1: // CHECK-NEXT: [[BCAST:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1> // CHECK-NEXT: llvm.store {{.*}}, [[BCAST]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1> // CHECK-NEXT: llvm.br ^bb2 // CHECK-NEXT: ^bb2: + // CHECK: llvm.mlir.undef : vector<1xf32> + // CHECK: [[PRED2:%.*]] = llvm.mlir.constant(true) : i1 // CHECK: [[VEC:%.*]] = llvm.mlir.undef : vector<1xi32> // CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: [[IE1:%.*]] = llvm.insertelement {{.*}}, [[VEC]][[[ZERO]] : i32] : vector<1xi32> - // CHECK: llvm.cond_br [[PRED]], ^bb3, ^bb4 + // CHECK: llvm.cond_br [[PRED2]], ^bb3, ^bb4 // CHECK-NEXT: 
^bb3: // CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1> // CHECK-NEXT: llvm.store [[IE1]], [[BCAST1]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1> diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index ae57d4d83d..0489e12a56 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -362,7 +362,9 @@ def make_llir(src, metadata, options): llvm_mod = llvm.to_module(mod, context) intel.set_spv_target_triple(llvm_mod) if os.getenv("TRITON_INTEL_FAST_MATH", "0") == "1": - intel.set_fast_math(llvm_mod) + intel.set_fast_math(llvm_mod, True) + else: + intel.set_fast_math(llvm_mod, False) if options.extern_libs: paths = [path for (name, path) in options.extern_libs] llvm.link_extern_libs(llvm_mod, paths) diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index f6a1ca913a..bd8f0d6c08 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -110,6 +110,39 @@ loadCacheControlToCacheControls(Builder &builder, return builder.getAttr(decorations); } +[[maybe_unused]] static bool +isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) { + VectorType resTy = op.getRes().getType(); + unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth(); + bool needsResElemSizeEqualTo32 = + op.getElemSizeInBits() == 32 || op.getVnniTransform(); + assert((!needsResElemSizeEqualTo32 || resElemTySize == 32) && + "Expecting 32-bit element type"); + if (!needsResElemSizeEqualTo32 && resElemTySize != 16) + return false; + + if (op.getVnniTransform()) + return true; + + if (op.getTranspose() && op.getTileHeight() != 16) + return false; + + uint32_t tileWidth = op.getTileWidth(); + uint32_t tileHeight = op.getTileHeight(); + switch (op.getElemSizeInBits()) { + case 8: + return (tileWidth == 32); + case 16: + return (tileWidth == 16) && (tileHeight != 32); + case 32: + return (tileWidth == 8 || tileWidth == 16) && (tileHeight != 32); + default: + llvm_unreachable("unexpected element size"); + } + + return false; +} + [[maybe_unused]] static Value createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op, ConversionPatternRewriter &rewriter) { @@ -119,12 +152,20 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op, auto b = TritonLLVMOpBuilder(loc, rewriter); Value ptr = op.getPtr(); - Value baseWidth = op.getBaseWidth(); Value baseHeight = op.getBaseHeight(); Value basePitch = op.getBasePitch(); - Value x = op.getX(); Value y = op.getY(); + // compensate the non-64 byte aligned base. + Value offset = + b.trunc(i32_ty, b.and_(b.ptrtoint(i64_ty, ptr), b.i64_val(0x3f))); + // In number of bytes. + Value baseWidth = b.add(op.getBaseWidth(), offset); + // In number of scalar elements. + Value offsetX = + b.add(op.getX(), + b.lshr(offset, b.i32_val(std::log2(op.getElemSizeInBits() / 8)))); + std::string funcName = "llvm.genx.GenISA.LSC2DBlockRead." 
+ getGenISATypeMangling(resType); IntegerType int1Ty = rewriter.getIntegerType(1); @@ -139,7 +180,7 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op, baseWidth.getType(), baseHeight.getType(), basePitch.getType(), - x.getType(), + offsetX.getType(), y.getType(), int32Ty, int32Ty, @@ -153,7 +194,7 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op, b.sub(baseWidth, one), b.sub(baseHeight, one), b.sub(basePitch, one), - x, + offsetX, y, b.i32_val(op.getElemSizeInBits()), b.i32_val(op.getTileWidth()), @@ -421,8 +462,9 @@ struct TritonMatrix2DBlockLoadLowering LogicalResult matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (op.getElemSizeInBits() == 8 && op.getTileWidth() == 16 && - op.getVBlocks() != 4 && !op.getVnniTransform()) { + if (!isOCLBuiltinAvailable(op) || + op.getElemSizeInBits() == 8 && op.getTileWidth() == 16 && + op.getVBlocks() != 4 && !op.getVnniTransform()) { // TODO: add ocl builtin/spirv intrinsics for 8b 16 column 1 vBlock & 2 // vBlock reads rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter)); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 601d34f0d7..3a51103d60 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -23,67 +23,44 @@ using namespace mlir::triton::gpu::intel; namespace { -// Return the mask for the unique data accessed by given tensor type. -// Used to mask out the redundant data accessed by threads. -Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, - Location loc, - const triton::intel::TargetInfo &targetInfo) { +Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { + auto tb = TritonLLVMOpBuilder(loc, rewriter); + if (a && b) { + return tb.and_(a, b); + } + return a ? a : b; +} + +// Return a predicate that is true only if the current thread holds unique data, +// according to freeVarsMask. The predicate may be null to indicate no +// predication is required. 
+Value emitRedundantThreadPredicate( + const llvm::MapVector &freeVarMasks, + ConversionPatternRewriter &rewriter, Location loc, + const triton::intel::TargetInfo &targetInfo) { auto b = TritonLLVMOpBuilder(loc, rewriter); - auto tensorTy = dyn_cast(valueTy); - Value mask = b.true_val(); - auto tid = getThreadId(rewriter, loc); - auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc); - if (tensorTy) { - // To remove this use, port https://github.com/triton-lang/triton/pull/5432 - // to the INTELGPU dialect - auto layout = cast(tensorTy.getEncoding()); - auto shape = tensorTy.getShape(); - auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); - auto kLane = StringAttr::get(rewriter.getContext(), "lane"); - auto kWarp = StringAttr::get(rewriter.getContext(), "warp"); - auto maskLane = - std::get<1>(delinearize(rewriter, loc, layout, shape, kLane, laneId)); - auto maskWarp = - std::get<1>(delinearize(rewriter, loc, layout, shape, kWarp, warpId)); - mask = b.and_(maskLane, maskWarp); - - // Do not write duplicated data when multicast is enabled - if (triton::gpu::getNumCTAs(layout) > 1) { - auto _0 = b.i32_val(0); - auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); - auto CTASplitNum = triton::gpu::getCTASplitNum(layout); - auto CTAOrder = triton::gpu::getCTAOrder(layout); - - auto multiDimClusterCTAId = - delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); - - auto rank = tensorTy.getRank(); - for (unsigned dim = 0; dim < rank; ++dim) { - // Skip when multicast is not enabled in this dimension - if (CTAsPerCGA[dim] == CTASplitNum[dim]) - continue; - // This wrapping rule must be consistent with emitCTAOffsetForLayout - unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); - Value repId = b.udiv(multiDimClusterCTAId[dim], b.i32_val(splitNum)); - // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]: - // CTA0 and CTA2 holds data of block0, - // CTA1 and CTA3 holds data of block1. - // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should - // be masked. We add the following mask: - // multiDimClusterCTAId[dim] / splitNum == 0 - // Actually in all existing cases of multicast, splitNum is always 1. - // The mask is equivalent to: - // multiDimClusterCTAId[dim] == 0 - mask = b.and_(mask, b.icmp_eq(repId, _0)); - } + auto ctx = rewriter.getContext(); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value zero = b.i32_val(0); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = freeVarMasks.lookup(kBlock) == 0 + ? zero + : targetInfo.getClusterCTAId(rewriter, loc); + + Value pred; + auto dimNames = {kLane, kWarp, kBlock}; + auto dimIds = {laneId, warpId, blockId}; + for (auto [dimName, dimId] : llvm::zip(dimNames, dimIds)) { + int32_t mask = freeVarMasks.lookup(dimName); + if (mask != 0) { + auto dimPred = b.icmp_eq(b.and_(dimId, b.i32_val(mask)), zero); + pred = maybeAnd(rewriter, loc, pred, dimPred); } - } else { - // If the tensor is not ranked, then it is a scalar and only thread 0 of - // CTA0 can write - mask = b.and_(mask, b.icmp_eq(clusterCTAId, b.i32_val(0))); - mask = b.and_(mask, b.icmp_eq(tid, b.i32_val(0))); } - return mask; + return pred; } /// Holds the values related to a block pointer. 
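For illustration only (not applied by this patch): a minimal Python sketch of the predicate that the new emitRedundantThreadPredicate helper builds from the per-dimension free-variable masks. A thread is treated as the unique holder of its data only when its lane/warp/block id has no bits set in the corresponding mask; a null (None) result means no predication is required. The helper name and mask values below are assumptions for the sketch.

    # Illustrative sketch of emitRedundantThreadPredicate, not part of this patch.
    def redundant_thread_predicate(free_var_masks, lane_id, warp_id, block_id):
        # Returns True/False when predication is needed, or None when it is not.
        pred = None
        for dim, dim_id in (("lane", lane_id), ("warp", warp_id), ("block", block_id)):
            mask = free_var_masks.get(dim, 0)
            if mask != 0:
                # The thread holds unique data along this dimension iff the masked
                # id bits are all zero (mirrors icmp_eq(and_(dimId, mask), 0)).
                dim_pred = (dim_id & mask) == 0
                pred = dim_pred if pred is None else (pred and dim_pred)
        return pred

    # Example (hypothetical masks): {"lane": 0, "warp": 0x3, "block": 0} predicates
    # only on the two low warp-id bits, so warps 1-3 skip the redundant access.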
@@ -494,6 +471,11 @@ struct LoadOpToBlockIOConversion LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { + ModuleOp mod = op->getParentOfType(); + if (!mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect:: + getSupportSG2DBlockAttrName())) + return failure(); + Attribute blockIOAttr = op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); if (!blockIOAttr) @@ -727,6 +709,10 @@ struct LoadOpToBlockIOConversion if (otherElems.size()) others[offset] = otherElems[i]; } + // ptrs[{0, 0}] and ptrs[{1, 0}] are currently used to calculate the + // pitch. + if (ptrs.count({0, 0}) < 1 || ptrs.count({1, 0}) < 1) + return failure(); } unsigned numOperandsPer2DLoadM, numOperandsPer2DloadN; @@ -769,6 +755,8 @@ struct LoadOpToBlockIOConversion // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands // by enlarging the vBlocks. unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8; + if (totalBytesPerRowPerDPASOp > 64) + return failure(); numOperandsPer2DloadN = std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp); @@ -815,6 +803,8 @@ struct LoadOpToBlockIOConversion StringAttr kWarp = str_attr("warp"); StringAttr kBlock = str_attr("block"); + const unsigned originalElemBits = elemSizeInBits; + ValueTable loadVals; for (int inner = 0; inner < numRepInner; inner += numOperandsInnerDimPerLoad) { @@ -884,9 +874,9 @@ struct LoadOpToBlockIOConversion /*tile_height*/ tileHeight, /*v_blocks*/ vBlocks, /*transpose*/ false, - /*vnni_transform*/ opIdx == - DpasEncodingAttr::OpIdx::OperandB && - usePackedType); + /*vnni_transform*/ + (usePackedType && !isOperandA && !isTransposeRequired && + originalElemBits != 32)); return SmallVector{load2dOp}; }); Value ret = *endBlock.args_begin(); @@ -2115,7 +2105,6 @@ struct StoreOpConversion auto *typeConverter = getTypeConverter(); MLIRContext *ctx = rewriter.getContext(); Value ptr = op.getPtr(); - Value mask = op.getMask(); Value llMask = adaptor.getMask(); // Determine the vectorization size @@ -2125,7 +2114,7 @@ struct StoreOpConversion SmallVector ptrElems, maskElems; unsigned vec = getVectorSize(ptr); if (llMask) - vec = std::min(vec, getMaskAlignment(mask)); + vec = std::min(vec, getMaskAlignment(op.getMask())); if (isTensorPointerType(ptr.getType())) { // fallback to scatter store. @@ -2147,7 +2136,11 @@ struct StoreOpConversion assert(!maskElems.size() || valueElems.size() == maskElems.size() && "Mask size mismatch"); - mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + auto freeVarMasks = getFreeVariableMasks(valueTy); + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); + uint32_t regMask = freeVarMasks[str_attr("register")]; + const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNBits = dtsize * 8; @@ -2155,6 +2148,10 @@ struct StoreOpConversion unsigned elemsPerThread = getTotalElemsPerThread(valueTy); const int numVecs = elemsPerThread / vec; for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { + if (!isCanonicalIndex(vecStart, regMask)) { + // Don't emit store ops for redundant elements within a thread + continue; + } // TODO: optimization when ptr is AddPtr with constant offset size_t in_off = 0; @@ -2192,8 +2189,11 @@ struct StoreOpConversion asmArgs.emplace_back(llWord, constraint); } - Value maskVal = - maskElems.size() ? b.and_(mask, maskElems[vecStart]) : mask; + Value maskVal = threadPred ? 
threadPred : b.true_val(); + if (llMask) { + auto mask = maskElems[vecStart]; + maskVal = maybeAnd(rewriter, loc, maskVal, mask); + } auto vecTy = vec_ty(valArgTy, nWords); Value vecWord = b.undef(vecTy); @@ -2286,7 +2286,9 @@ struct AtomicCASOpConversion vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } - Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + auto freeVarMasks = getFreeVariableMasks(valueTy); + Value mask = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -2314,20 +2316,33 @@ struct AtomicCASOpConversion loc, spirv::Scope::Workgroup, spirv::Scope::Workgroup, spirv::MemorySemantics::SequentiallyConsistent | spirv::MemorySemantics::CrossWorkgroupMemory); - Block &endBlock = - LLVM::intel::createPredicatedBlock(rewriter, loc, mask, {zero}, [&] { - // casPtr = b.bitcast(casPtr, ptr_ty(ctx, 1)); - casCmp = b.bitcast(casCmp, zero.getType()); - casVal = b.bitcast(casVal, zero.getType()); - - auto cmpxchg = rewriter.create( - loc, casPtr, casCmp, casVal, successOrdering, failureOrdering); - Value newLoaded = - rewriter.create(loc, cmpxchg, 0); - return SmallVector{newLoaded}; - }); + Value ret; + // TODO: de-duplicate + if (mask) { + Block &endBlock = LLVM::intel::createPredicatedBlock( + rewriter, loc, mask, {zero}, [&] { + // casPtr = b.bitcast(casPtr, ptr_ty(ctx, 1)); + casCmp = b.bitcast(casCmp, zero.getType()); + casVal = b.bitcast(casVal, zero.getType()); + + auto cmpxchg = rewriter.create( + loc, casPtr, casCmp, casVal, successOrdering, + failureOrdering); + Value newLoaded = + rewriter.create(loc, cmpxchg, 0); + return SmallVector{newLoaded}; + }); + + ret = endBlock.getArgument(0); + } else { + // casPtr = b.bitcast(casPtr, ptr_ty(ctx, 1)); + casCmp = b.bitcast(casCmp, zero.getType()); + casVal = b.bitcast(casVal, zero.getType()); - Value ret = endBlock.getArgument(0); + auto cmpxchg = rewriter.create( + loc, casPtr, casCmp, casVal, successOrdering, failureOrdering); + ret = rewriter.create(loc, cmpxchg, 0); + } Type retType = (!tensorTy || vec == 1) ? valueElemTy : vecTy; ret = b.bitcast(ret, retType); @@ -2424,7 +2439,9 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } - Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + auto freeVarMasks = getFreeVariableMasks(valueTy); + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -2437,7 +2454,9 @@ struct AtomicRMWOpConversion } Value rmwPtr = ptrElements[i]; - Value rmwMask = llMask ? b.and_(mask, maskElements[i]) : mask; + Value rmwMask = llMask + ? maybeAnd(rewriter, loc, maskElements[i], threadPred) + : threadPred; assert((valueElemNBits == 16 || valueElemNBits == 32 || valueElemNBits == 64) && @@ -2454,64 +2473,110 @@ struct AtomicRMWOpConversion Block *endBlock = nullptr; // TODO: check device capabilities to avoid unnecessary emulation or // emit unsupported feature error. 
+ Value ret; + if (valueElemNBits == 16) { op.emitWarning( "'tt.atomic_rmw' op fp16 datatype is not supported in the target " "HW, software emulation is an experimental feature (use at own " "risk)"); - endBlock = - emulateFp16AtomicRmw(rewriter, loc, atomicRmwAttr, valueElemTy, - rmwPtr, rmwVal, rmwMask, {zero}); + endBlock = emulateFp16AtomicRmw( + rewriter, loc, atomicRmwAttr, valueElemTy, rmwPtr, rmwVal, + maybeAnd(rewriter, loc, b.true_val(), rmwMask), {zero}); } else { if (!atomicNeedsSharedMemory(op.getResult())) rewriter.create( loc, spirv::Scope::Workgroup, spirv::Scope::Workgroup, spirv::MemorySemantics::SequentiallyConsistent | spirv::MemorySemantics::CrossWorkgroupMemory); - endBlock = &LLVM::intel::createPredicatedBlock( - rewriter, loc, rmwMask, {zero}, [&] { - mlir::LLVM::AtomicBinOp rmwKind; - switch (atomicRmwAttr) { - case RMWOp::AND: - rmwKind = LLVM::AtomicBinOp::_and; - break; - case RMWOp::OR: - rmwKind = LLVM::AtomicBinOp::_or; - break; - case RMWOp::XOR: - rmwKind = LLVM::AtomicBinOp::_xor; - break; - case RMWOp::ADD: - rmwKind = LLVM::AtomicBinOp::add; - break; - case RMWOp::FADD: - rmwKind = LLVM::AtomicBinOp::fadd; - break; - case RMWOp::MAX: - rmwKind = LLVM::AtomicBinOp::max; - break; - case RMWOp::UMAX: - rmwKind = LLVM::AtomicBinOp::umax; - break; - case RMWOp::MIN: - rmwKind = LLVM::AtomicBinOp::min; - break; - case RMWOp::UMIN: - rmwKind = LLVM::AtomicBinOp::umin; - break; - case RMWOp::XCHG: - rmwKind = LLVM::AtomicBinOp::xchg; - break; - } - rmwVal = b.bitcast(rmwVal, valueElemTy); - auto atomRMW = rewriter.create( - loc, rmwKind, rmwPtr, rmwVal, llvmMemOrdering); - return SmallVector{atomRMW.getRes()}; - }); + // TODO: de-duplicate + if (rmwMask) { + endBlock = &LLVM::intel::createPredicatedBlock( + rewriter, loc, rmwMask, {zero}, [&] { + mlir::LLVM::AtomicBinOp rmwKind; + switch (atomicRmwAttr) { + case RMWOp::AND: + rmwKind = LLVM::AtomicBinOp::_and; + break; + case RMWOp::OR: + rmwKind = LLVM::AtomicBinOp::_or; + break; + case RMWOp::XOR: + rmwKind = LLVM::AtomicBinOp::_xor; + break; + case RMWOp::ADD: + rmwKind = LLVM::AtomicBinOp::add; + break; + case RMWOp::FADD: + rmwKind = LLVM::AtomicBinOp::fadd; + break; + case RMWOp::MAX: + rmwKind = LLVM::AtomicBinOp::max; + break; + case RMWOp::UMAX: + rmwKind = LLVM::AtomicBinOp::umax; + break; + case RMWOp::MIN: + rmwKind = LLVM::AtomicBinOp::min; + break; + case RMWOp::UMIN: + rmwKind = LLVM::AtomicBinOp::umin; + break; + case RMWOp::XCHG: + rmwKind = LLVM::AtomicBinOp::xchg; + break; + } + + rmwVal = b.bitcast(rmwVal, valueElemTy); + auto atomRMW = rewriter.create( + loc, rmwKind, rmwPtr, rmwVal, llvmMemOrdering); + return SmallVector{atomRMW.getRes()}; + }); + } else { + mlir::LLVM::AtomicBinOp rmwKind; + switch (atomicRmwAttr) { + case RMWOp::AND: + rmwKind = LLVM::AtomicBinOp::_and; + break; + case RMWOp::OR: + rmwKind = LLVM::AtomicBinOp::_or; + break; + case RMWOp::XOR: + rmwKind = LLVM::AtomicBinOp::_xor; + break; + case RMWOp::ADD: + rmwKind = LLVM::AtomicBinOp::add; + break; + case RMWOp::FADD: + rmwKind = LLVM::AtomicBinOp::fadd; + break; + case RMWOp::MAX: + rmwKind = LLVM::AtomicBinOp::max; + break; + case RMWOp::UMAX: + rmwKind = LLVM::AtomicBinOp::umax; + break; + case RMWOp::MIN: + rmwKind = LLVM::AtomicBinOp::min; + break; + case RMWOp::UMIN: + rmwKind = LLVM::AtomicBinOp::umin; + break; + case RMWOp::XCHG: + rmwKind = LLVM::AtomicBinOp::xchg; + break; + } + + rmwVal = b.bitcast(rmwVal, valueElemTy); + auto atomRMW = rewriter.create( + loc, rmwKind, rmwPtr, rmwVal, llvmMemOrdering); + ret 
= atomRMW.getResult(); + } } - Value ret = endBlock->getArgument(0); + ret = endBlock ? endBlock->getArgument(0) : ret; + assert(ret); Type retType = (!tensorTy || vec == 1) ? valueElemTy : vecTy; ret = b.bitcast(ret, retType); diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index b73357548c..7b331af8e8 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -265,13 +265,13 @@ void init_triton_intel(py::module &&m) { // producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the // fast math semantics on all arithmetic operations. // https://github.com/intel/intel-xpu-backend-for-triton/issues/3862 - m.def("set_fast_math", [](llvm::Module *mod) { + m.def("set_fast_math", [](llvm::Module *mod, bool flag) { using namespace llvm; for (Function &func : *mod) { for (Instruction &inst : instructions(func)) { if (auto *op = dyn_cast(&inst)) { FastMathFlags FMF; - FMF.setFast(true); + FMF.setFast(flag); inst.setFastMathFlags(FMF); } }
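For reference, a minimal Python sketch (illustrative only, not applied by this patch) of the base-alignment compensation that createGenISA2DBlockRead now performs before emitting the GenISA 2D block read: the low six bits of the base pointer are folded into the base width (in bytes) and into the block X offset (in elements), per the patch comment "compensate the non-64 byte aligned base". The function name and example values are assumptions for the sketch.

    import math

    def compensate_unaligned_base(base_addr, base_width_bytes, x_elems, elem_size_bits):
        # Byte distance of the base pointer from the previous 64-byte boundary
        # (mirrors b.and_(b.ptrtoint(i64_ty, ptr), 0x3f) in the patch).
        offset_bytes = base_addr & 0x3F
        elem_size_bytes = elem_size_bits // 8
        # The base width is expressed in bytes, so the byte offset is added directly.
        widened_width = base_width_bytes + offset_bytes
        # X is expressed in elements, so the byte offset is converted with a right
        # shift by log2(bytes per element), as in b.lshr(offset, log2(elemSize/8)).
        shifted_x = x_elems + (offset_bytes >> int(math.log2(elem_size_bytes)))
        return widened_width, shifted_x

    # Example: an fp16 tile whose base sits 32 bytes past a 64-byte boundary:
    # compensate_unaligned_base(0x1020, 128, 0, 16) == (160, 16)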