Skip to content

Commit d9fffb7

Browse files
authored
Separate TTGPU to LLVM and TritonGEN to LLVM Lowering (#4276)
Changes the Intel lowering pipeline during `make_llir` to first convert TritonIntelGPU to LLVM Dialect, then convert TritonGEN to LLVM Dialect. This is analogous to the NVIDIA backend, allows us to run the GEN verifier on all GEN ops, and allows us to write lit tests against TTGPUIR looking for specific GEN ops - e.g. 2D block loads with specific tile sizes. To lower GEN to LLVM Dialect we need to add the GEN-to-SPIRV and SPIRV-to-LLVM patterns to the lowering step. I chose to do that in the existing TritonGENToLLVM pass, but we could introduce a new pass too (`TritonGENToLLVMViaSPIRV`?). Finally, I had to disable the tensor-of-pointers -> 2D block IO lit test. The test is failing GEN validation because the pitch passed to the TritonGEN op is 0. This likely wouldn't work in practice, but it is now being caught during lit testing because we are running the verifier. I have opened a separate issue to fix this: #4275. I also discovered several places where the DPAS layout was using the incorrect `opsPerChannel` parameter for the data type - I had initially opened #4270 to track this, but fixed the issue in this PR. However, we should consider layout validation to ensure DPAS type-specific parameters match the data type. Closes #4269.
1 parent 2a9e284 commit d9fffb7

File tree

16 files changed

+53
-34
lines changed

16 files changed

+53
-34
lines changed

test/Conversion/intel/sub-group-transpose.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s
22

33
// Basic 16x16 transpose test
44

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
22

33
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
44
// CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>)
@@ -1026,11 +1026,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10261026

10271027
// -----
10281028

1029-
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
1029+
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
10301030
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
10311031
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
10321032
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
1033-
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
1033+
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=1}>
10341034
#smem = #ttg.shared_memory
10351035
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10361036
// CHECK-LABEL: matmul_tf32dot
@@ -1040,7 +1040,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10401040
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
10411041
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
10421042

1043-
// CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
1043+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_fDv8_fS0_i(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<4xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
10441044
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
10451045
%38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
10461046

@@ -1485,11 +1485,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
14851485

14861486
// -----
14871487

1488-
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
1488+
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
14891489
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
14901490
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
14911491
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
1492-
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
1492+
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=1}>
14931493
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
14941494
// CHECK-LABEL: matmul_tf32_cst_b
14951495
tt.func @matmul_tf32_cst_b(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},

test/Conversion/intel/tritongpu_to_gen_dot.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,NO-AGGRESSIVE-REUSE
2-
// RUN: env TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1 triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,AGGRESSIVE-REUSE
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,NO-AGGRESSIVE-REUSE
2+
// RUN: env TRITON_INTEL_AGGRESSIVE_DPAS_REUSE=1 triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --cse -canonicalize | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,AGGRESSIVE-REUSE
33

44
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1]}>
55
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: env TRITON_INTEL_ADVANCED_PATH=1 triton-opt %s --convert-triton-intel-gpu-to-llvm --split-input-file | FileCheck %s
1+
// RUN: env TRITON_INTEL_ADVANCED_PATH=1 triton-opt %s --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --split-input-file | FileCheck %s
22

33
module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 1 : i32} {
44
// CHECK-DAG: llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
2-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1 | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
2+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm=one_matrix_per_load_for_bt=1 --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
33

44
// CHECK-DAG: llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
55
// CHECK-DAG: llvm.func spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv(i32, i32, i32, i32, !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
1+
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
22

33
// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return}
44
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>

test/TritonIntelGPU/prefetch-to-llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --cse -canonicalize | FileCheck %s
1+
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --cse -canonicalize | FileCheck %s
22

33
// CHECK: llvm.func spir_funccc @_Z36__spirv_Subgroup2DBlockPrefetchINTELiiiiPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
44
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --cse | FileCheck %s --implicit-check-not=llvm.inline_asm
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --cse | FileCheck %s --implicit-check-not=llvm.inline_asm
22

33
// CHECK: llvm.func spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv
44
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def make_llir(src, metadata, options):
327327
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
328328
intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt,
329329
options.enable_tile_load_linear_layout)
330+
intel.passes.ttgpuir.add_gen_to_llvm(pm)
330331
intel.passes.ttgpuir.add_rewrite_stack_ptr(pm)
331332
passes.common.add_canonicalizer(pm)
332333
passes.common.add_cse(pm)

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def TritonIntelGPU_Dialect : Dialect {
1616
"triton::TritonDialect",
1717
"triton::gpu::TritonGPUDialect",
1818
"mlir::gpu::GPUDialect",
19+
"mlir::triton::TritonGEN::TritonGENDialect",
1920
];
2021

2122
let extraClassDeclaration = [{

third_party/intel/include/TritonGENToLLVM/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def ConvertTritonGENToLLVM : Pass<"convert-tritongen-to-llvm"> {
14+
def ConvertTritonGENToLLVM : Pass<"convert-tritongen-to-llvm", "mlir::ModuleOp"> {
1515
let summary = "Convert the Triton GEN dialect to the LLVM dialect";
1616
let description = [{
1717
This pass converts the TritonGEN dialect operations to LLVM dialect operations.
1818
}];
19+
1920
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
2021
}
2122

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1313
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1414
#include "mlir/Conversion/LLVMCommon/Pattern.h"
15+
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
1516
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1617
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1718
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -37,6 +38,7 @@
3738

3839
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
3940
#include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
41+
#include "intel/include/TritonGENToSPIRV/TritonGENToSPIRVPass.h"
4042

4143
namespace mlir::triton {
4244
#define GEN_PASS_DEF_CONVERTTRITONGENTOLLVM
@@ -818,15 +820,19 @@ struct ConvertTritonGENToLLVM
818820

819821
void runOnOperation() override {
820822
MLIRContext *ctx = &getContext();
821-
RewritePatternSet pattern(ctx);
823+
RewritePatternSet patterns(ctx);
822824
LowerToLLVMOptions options(ctx);
823-
LLVMTypeConverter converter(ctx, options);
825+
LLVMTypeConverter typeConverter(ctx, options);
824826
LLVMConversionTarget target(*ctx);
825827

826-
populateTritonGENToLLVMConversionPatterns(converter, pattern);
828+
populateTritonGENToLLVMConversionPatterns(typeConverter, patterns);
829+
830+
populateTritonGENToSPIRVConversionPatterns(patterns);
831+
populateSPIRVToLLVMConversionPatterns(typeConverter, patterns,
832+
spirv::ClientAPI::OpenCL);
827833

828-
if (failed(
829-
applyPartialConversion(getOperation(), target, std::move(pattern))))
834+
if (failed(applyPartialConversion(getOperation(), target,
835+
std::move(patterns))))
830836
signalPassFailure();
831837
}
832838
};

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,10 @@ struct PrefetchOpConversion
557557
/*v_blocks*/ vBlocks,
558558
/*cache_opt*/ TritonGEN::LoadCacheControl::L1C_L3C);
559559
if (failed(newOp.verify())) {
560-
// Explicitly invoke verifier because `triton_gen` ops are immediately
561-
// lowered further to a builtin call.
560+
// delete the op so that the verifier will not abort the pass
561+
// pipeline later, as we can fail this path and try a different
562+
// approach.
563+
rewriter.eraseOp(newOp);
562564
return failure();
563565
}
564566
}
@@ -757,8 +759,10 @@ struct PrefetchOpConversion
757759
/*v_blocks*/ vBlocks,
758760
/*cache_opt*/ TritonGEN::LoadCacheControl::L1C_L3C);
759761
if (failed(newOp.verify())) {
760-
// Explicitly invoke verifier because `triton_gen` ops are
761-
// immediately lowered further to a builtin call.
762+
// delete the op so that the verifier will not abort the pass
763+
// pipeline later, as we can fail this path and try a different
764+
// approach.
765+
rewriter.eraseOp(newOp);
762766
return failure();
763767
}
764768
}
@@ -1573,8 +1577,10 @@ struct LoadOpConversion
15731577
/*transpose*/ false,
15741578
/*vnni_transform*/ false);
15751579
if (failed(load2dOp.verify())) {
1576-
// Explicitly invoke verifier because `triton_gen` ops are
1577-
// immediately lowered further to a builtin call.
1580+
// delete the op so that the verifier will not abort the pass
1581+
// pipeline later, as we can fail this path and try a different
1582+
// approach.
1583+
rewriter.eraseOp(load2dOp);
15781584
return failure();
15791585
}
15801586

@@ -2086,8 +2092,10 @@ struct LoadOpConversion
20862092
(usePackedType && !isOperandA && !isTransposeRequired &&
20872093
originalElemBits != 32));
20882094
if (failed(load2dOp.verify())) {
2089-
// Explicitly invoke verifier because `triton_gen` ops are
2090-
// immediately lowered further to a builtin call.
2095+
// delete the op so that the verifier will not abort the pass
2096+
// pipeline later, as we can fail this path and try a different
2097+
// approach.
2098+
rewriter.eraseOp(load2dOp);
20912099
return failure();
20922100
}
20932101
LLVM_DEBUG(llvm::dbgs() << "Generated load op: " << load2dOp << "\n");
@@ -2520,8 +2528,10 @@ struct StoreOpConversion
25202528
/*stored_val*/ b.bitcast(storeVal, store2DGenXType));
25212529

25222530
if (failed(newOp.verify())) {
2523-
// Explicitly invoke verifier because `triton_gen` ops are
2524-
// immediately lowered further to a builtin call.
2531+
// delete the op so that the verifier will not abort the pass
2532+
// pipeline later, as we can fail this path and try a different
2533+
// approach.
2534+
rewriter.eraseOp(newOp);
25252535
return failure();
25262536
}
25272537
}

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,8 @@ class TritonGPUToLLVMPipelineManager {
254254
// to help convert scalar expression to LLVM.
255255
arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
256256
populateMathToLLVMConversionPatterns(typeConverter, patterns);
257-
triton::populateTritonGENToLLVMConversionPatterns(typeConverter, patterns);
258257
triton::populateGPUToTritonGENConversionPatterns(typeConverter, patterns);
259258
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
260-
populateTritonGENToSPIRVConversionPatterns(patterns);
261259
populateGpuToLLVMSPVConversionPatterns(typeConverter, patterns);
262260
populateSPIRVToLLVMConversionPatterns(typeConverter, patterns,
263261
spirv::ClientAPI::OpenCL);

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
5151
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
5252
: ConversionTarget(ctx) {
5353
addLegalDialect<LLVM::LLVMDialect>();
54-
addIllegalDialect<triton::TritonGEN::TritonGENDialect>();
54+
addLegalDialect<triton::TritonGEN::TritonGENDialect>();
5555
addIllegalDialect<triton::TritonDialect>();
5656
addIllegalDialect<triton::gpu::TritonGPUDialect>();
5757
addIllegalDialect<triton::gpu::intel::TritonIntelGPUDialect>();

third_party/intel/triton_xpu.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "intel/include/Target/LLVMIR/Dialect/TritonGEN/TritonGENToLLVMIRTranslation.h"
1717
#include "intel/include/Target/LLVMIR/PostProcess.h"
1818
#include "intel/include/TritonAnnotateModule/Passes.h"
19+
#include "intel/include/TritonGENToLLVM/Passes.h"
1920
#include "intel/include/TritonIntelGPUToLLVM/Passes.h"
2021
#include "intel/include/TritonRaiseBlockPointer/Passes.h"
2122
#include "intel/include/TritonToTritonGPUWarp/Passes.h"
@@ -67,6 +68,7 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
6768
ADD_PASS_OPTION_WRAPPER_3("add_to_llvmir",
6869
gpu::intel::createConvertTritonIntelGPUToLLVM, bool,
6970
bool, bool);
71+
ADD_PASS_WRAPPER_0("add_gen_to_llvm", createConvertTritonGENToLLVM);
7072
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
7173
gpu::intel::createTritonIntelGPUAccelerateMatmul);
7274
ADD_PASS_WRAPPER_0("add_rewrite_stack_ptr",

0 commit comments

Comments
 (0)