
Commit 2c78859

Support global scratch in intel launcher
1 parent 6f851e2 commit 2c78859

20 files changed, 70 insertions(+), 52 deletions(-)
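This commit threads a global scratch buffer through the Intel (SPIR-V) launcher path: every generated spir_kernelcc kernel gains a trailing pointer argument in the global address space (!llvm.ptr<1>), the AOT template in python/triton/tools/compile.py passes &global_scratch as the final argument pointer, and the FileCheck tests below are updated to expect the extra %[[PTR_1:.*]]: !llvm.ptr<1> parameter.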

python/triton/tools/compile.py (1 addition, 1 deletion)

@@ -183,7 +183,7 @@ def constexpr(s):
     "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
     "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
     "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
-    "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1]),
+    "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]),
     "arg_types": ", ".join(ty_to_cpp(arg) for arg in arg_types_not_1),
     "num_args": len(arg_names_not_1),
     "kernel_docstring": doc_string,

test/Conversion/intel/arith_to_llvm.mlir (4 additions, 2 deletions)

@@ -6,7 +6,8 @@
 // CHECK-SCALAR-DAG: llvm.func spir_funccc @_Z27__spirv_ConvertFToBF16INTELf(f32) -> i16 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}

 // CHECK-LABEL: llvm.func spir_kernelcc @float_to_bfloat_conversion(
-// CHECK-SCALAR: %[[VAL_0:.*]]: !llvm.struct<(f32, f32, f32, f32)>) -> !llvm.struct<(bf16, bf16, bf16, bf16)>
+// CHECK-SCALAR: %[[VAL_0:.*]]: !llvm.struct<(f32, f32, f32, f32)>,
+// CHECK-SCALAR: %[[PTR_1:.*]]: !llvm.ptr<1>) -> !llvm.struct<(bf16, bf16, bf16, bf16)>
 // CHECK-VECTOR: %[[VAL_0:.*]]: vector<32xf32>) -> vector<32xbf16>
 module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.support_dpas", "triton_intel_gpu.support_bf16_conversion", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 tt.func @float_to_bfloat_conversion(%arg0 : tensor<512xf32, #blocked>) -> tensor<512xbf16, #blocked>{
@@ -35,7 +36,8 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
 }

 // CHECK-LABEL: llvm.func spir_kernelcc @bfloat_to_float_conversion(
-// CHECK-SCALAR: %[[VAL_0:.*]]: !llvm.struct<(bf16, bf16, bf16, bf16)>) -> !llvm.struct<(f32, f32, f32, f32)>
+// CHECK-SCALAR: %[[VAL_0:.*]]: !llvm.struct<(bf16, bf16, bf16, bf16)>,
+// CHECK-SCALAR: %[[PTR_1:.*]]: !llvm.ptr<1>) -> !llvm.struct<(f32, f32, f32, f32)>
 tt.func @bfloat_to_float_conversion(%arg0 : tensor<512xbf16, #blocked>) -> tensor<512xf32, #blocked>{
 // CHECK-SCALAR: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(bf16, bf16, bf16, bf16)>
 // CHECK-SCALAR: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(bf16, bf16, bf16, bf16)>

test/Conversion/intel/dot_layout_offset.mlir (2 additions, 2 deletions)

@@ -3,7 +3,7 @@
 #dpas = #triton_intel_gpu.dpas<{repeatCount=8, systolicDepth=8, executionSize = 8, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA=[1, 1], repCluster=[2, 2]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
-// CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset()
+// CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset(%arg0: !llvm.ptr<1>)
 tt.func public @dot_layout_emit_offset() {
 %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot_operand_a>
 // CHECK-COUNT-64: {{.*}} = llvm.extractvalue {{.*}}
@@ -324,7 +324,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32}
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {

-// CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset()
+// CHECK-LABEL: llvm.func spir_kernelcc @dot_layout_emit_offset(%arg0: !llvm.ptr<1>)
 tt.func public @dot_layout_emit_offset() {
 %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #dot_operand_b>
 // CHECK-COUNT-64: {{.*}} = llvm.extractvalue {{.*}}
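Note that dot_layout_emit_offset previously took no arguments at all; with this change even parameterless kernels receive the scratch pointer as their sole argument (%arg0: !llvm.ptr<1>).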

test/Conversion/intel/glue.mlir (3 additions, 3 deletions)

@@ -3,7 +3,7 @@

 module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.support_dpas", "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @test_scalar(
-// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32) -> vector<4xf32>
+// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32, %[[PTR_1:.*]]: !llvm.ptr<1>) -> vector<4xf32>
 // CHECK: %[[VAL_8:.*]] = llvm.mlir.poison : vector<4xf32>
 // CHECK: %[[VAL_9:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK: %[[VAL_10:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_8]]{{\[}}%[[VAL_9]] : i32] : vector<4xf32>
@@ -21,7 +21,7 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
 }

 // CHECK-LABEL: llvm.func spir_kernelcc @test_vec_2(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>, %[[VAL_1:.*]]: vector<4xf32>) -> vector<8xf32>
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>, %[[VAL_1:.*]]: vector<4xf32>, %[[PTR_1:.*]]: !llvm.ptr<1>) -> vector<8xf32>
 // CHECK: %[[VAL_4:.*]] = llvm.shufflevector %[[VAL_0]], %[[VAL_1]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>
 // CHECK: llvm.return %[[VAL_4]] : vector<8xf32>
 // CHECK: }
@@ -31,7 +31,7 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
 }

 // CHECK-LABEL: llvm.func spir_kernelcc @test_vec_4(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>, %[[VAL_1:.*]]: vector<4xf32>, %[[VAL_2:.*]]: vector<4xf32>, %[[VAL_3:.*]]: vector<4xf32>) -> vector<16xf32>
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>, %[[VAL_1:.*]]: vector<4xf32>, %[[VAL_2:.*]]: vector<4xf32>, %[[VAL_3:.*]]: vector<4xf32>, %[[PTR_1:.*]]: !llvm.ptr<1>) -> vector<16xf32>
 // CHECK: %[[VAL_8:.*]] = llvm.shufflevector %[[VAL_0]], %[[VAL_1]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>
 // CHECK: %[[VAL_9:.*]] = llvm.shufflevector %[[VAL_2]], %[[VAL_3]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>
 // CHECK: %[[VAL_10:.*]] = llvm.shufflevector %[[VAL_8]], %[[VAL_9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>

test/Conversion/intel/shared_to_dot_layout_convert.mlir (6 additions, 3 deletions)

@@ -8,7 +8,8 @@

 module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @convert_dot(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK-SAME: attributes {intel_reqd_sub_group_size = 16 : i32, {{.*}}} {
 tt.func @convert_dot(%A: tensor<128x64xf16, #blocked0>) {
 // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
@@ -44,7 +45,8 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32

 module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @convert_dot(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK-SAME: attributes {intel_reqd_sub_group_size = 16 : i32, {{.*}}} {
 tt.func @convert_dot(%A: tensor<128x64xf16, #blocked0>) {
 // CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
@@ -81,7 +83,8 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32

 module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @convert_dot(
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>)
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK-SAME: attributes {intel_reqd_sub_group_size = 16 : i32, {{.*}}} {
 tt.func @convert_dot(%B: tensor<64x256xf16, #blocked1>) {
 // CHECK-DAG: %[[CST_128:.*]] = llvm.mlir.constant(128 : i32) : i32

test/Conversion/intel/sub-group-shuffle.mlir (14 additions, 7 deletions)

@@ -9,7 +9,8 @@

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @test_f16(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16)>
 // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_4]])
@@ -49,7 +50,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 }

 // CHECK-LABEL: llvm.func spir_kernelcc @test_bf16(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(bf16)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(bf16)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(bf16)>
 // CHECK: %[[VAL_2:.*]] = llvm.bitcast %[[VAL_1]] : bf16 to i16
 // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -91,7 +93,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 }

 // CHECK-LABEL: llvm.func spir_kernelcc @test_i1(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i1)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i1)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(i1)>
 // CHECK: %[[VAL_2:.*]] = llvm.zext %[[VAL_1]] : i1 to i8
 // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -133,7 +136,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 }

 // CHECK-LABEL: llvm.func spir_kernelcc @test_ptr(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<1>)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<1>)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(ptr<1>)>
 // CHECK: %[[VAL_2:.*]] = llvm.ptrtoint %[[VAL_1]] : !llvm.ptr<1> to i64
 // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -186,7 +190,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @test_f32(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f32)>
 // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_4]])
@@ -269,7 +274,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
 // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
 // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
 // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -370,7 +376,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.thr

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @test_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16)>)
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16, f16)>
+// CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>
 tt.func @test_contiguous(%arg0: tensor<32xf16, #ttg.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #blocked1}>> {
 // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16, f16)>
 // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f16, f16)>

test/Conversion/intel/tritongpu_to_gen.mlir (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
-// CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>)
+// CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
 // Here the 128 comes from the 4 in module attribute multiples 32
 // CHECK-SAME: attributes {intel_reqd_sub_group_size = 32 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
 tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
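The empty-kernel test makes the same point at the other end of the pipeline: @test_empty_kernel allocates nothing, yet its lowered signature still ends with the extra !llvm.ptr<1> (%arg2), which suggests the scratch pointer is appended to every kernel signature unconditionally rather than only when a kernel actually uses global scratch.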

test/Conversion/intel/tritongpu_to_gen_dot.mlir (2 additions, 3 deletions)

@@ -73,8 +73,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {

 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: llvm.func spir_kernelcc @dot_f32_tf32_tf32_f32_1(
-// CHECK-SAME: %[[A:.*]]: !llvm.struct<(f32, f32, f32, f32)>, %[[B:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>,
-// CHECK-SAME: %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>) attributes {intel_reqd_sub_group_size = 32 : i32, triton_gen.max_work_group_size = array<i32: 32, 1, 1>} {
+// CHECK-SAME: %[[A:.*]]: !llvm.struct<(f32, f32, f32, f32)>, %[[B:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, %[[PTR_1:.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 32 : i32, triton_gen.max_work_group_size = array<i32: 32, 1, 1>} {
 tt.func @dot_f32_tf32_tf32_f32_1(%a: tensor<8x8xf32, #dot_operand_a>, %b: tensor<8x16xf32, #dot_operand_b>, %c: tensor<8x16xf32, #dpas>) {
 // COM: To simplify, only check RTNE and its usage for the last element of A, B, C
 // CHECK: %[[A_LAST_VAL:.*]] = llvm.extractvalue %[[A]][3]
@@ -117,7 +116,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 // CHECK: llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
 // CHECK-LABEL: llvm.func spir_kernelcc @dot_rep_cluster_4_2(
 // CHECK-SAME: %[[A:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>, %[[B:.*]]: !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>,
-// CHECK-SAME: %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 16, 1, 1>} {
+// CHECK-SAME: %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, %[[PTR_1:.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 16, 1, 1>} {
 tt.func @dot_rep_cluster_4_2(%a: tensor<32x32xf16, #dot_operand_a>, %b: tensor<32x32xf16, #dot_operand_b>, %c: tensor<32x32xf32, #dpas>) {
 // CHECK: %[[VAL_3:.*]] = llvm.mlir.undef : vector<8xf32>
 // CHECK: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32
