diff --git a/test/TritonIntelGPU/coalesce.mlir b/test/TritonIntelGPU/coalesce.mlir
index 3212382106..06d08577d0 100644
--- a/test/TritonIntelGPU/coalesce.mlir
+++ b/test/TritonIntelGPU/coalesce.mlir
@@ -472,6 +472,7 @@ module attributes {ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.n
 }

 // -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
@@ -522,3 +523,41 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.target_arch = "spir64", "tt
     tt.return
   }
 }
+
+// -----
+
+// COM: Ensure layout propagation works for a while loop.
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK: kernel_make_tensor_descriptor_loop_carried
+  tt.func public @kernel_make_tensor_descriptor_loop_carried(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i64 {tt.divisibility = 16 : i32}, %arg2: i64 {tt.divisibility = 16 : i32}) {
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c2_i32 = arith.constant 2 : i32
+    // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}} {order = array<i32: 1, 0>} : >
+    // CHECK: [[ADV_PTR:%.*]] = tt.advance [[PTR]], {{.*}} : >
+    %4 = tt.make_tensor_ptr %arg0, [%arg1, %arg2], [%arg2, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : >
+    %5 = tt.advance %4, [%c2_i32, %c0_i32] : >
+    %7 = arith.cmpi slt, %arg1, %arg2 : i64
+    // CHECK: scf.while ([[ARG3:%.*]] = [[PTR]], [[ARG4:%.*]] = [[ADV_PTR]]) : (!tt.ptr>, !tt.ptr>) -> (!tt.ptr>, !tt.ptr>) {
+    %6:2 = scf.while (%arg3 = %4, %arg4 = %5) : (!tt.ptr>, !tt.ptr>) -> (!tt.ptr>, !tt.ptr>) {
+      // CHECK: scf.condition({{.*}}) [[ARG3]], [[ARG4]] : !tt.ptr>, !tt.ptr>
+      scf.condition(%7) %arg3, %arg4 : !tt.ptr>, !tt.ptr>
+    } do {
+    // CHECK: ^bb0({{.*}}: !tt.ptr>, {{.*}}: !tt.ptr>):
+    ^bb0(%arg3: !tt.ptr>, %arg4: !tt.ptr>):
+      // CHECK: [[PTR1:%.*]] = arith.select {{.*}} : !tt.ptr>
+      // CHECK: [[PTR2:%.*]] = tt.advance [[PTR1]], {{.*}} : >
+      // CHECK: [[LOAD:%.*]] = tt.load [[PTR1]] : !tt.ptr>
+      // CHECK: tt.store [[PTR2]], {{.*}} : !tt.ptr>
+      // CHECK: scf.yield [[PTR1]], [[PTR2]] : !tt.ptr>, !tt.ptr>
+      %12 = arith.select %7, %arg4, %arg3 : !tt.ptr>
+      %13 = tt.advance %12, [%c0_i32, %c2_i32] : >
+      %15 = tt.load %12 : !tt.ptr>
+      tt.store %13, %15 : !tt.ptr>
+      scf.yield %12, %13 : !tt.ptr>, !tt.ptr>
+    }
+    tt.return
+  }
+}
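+
+// COM: The CHECK lines above pin the coalesced layout [[BLOCKED_LAYOUT]] at
+// COM: every point the block pointer flows through: the tt.make_tensor_ptr
+// COM: and tt.advance results, the scf.while init args and results, the
+// COM: scf.condition operands, and the "after" region block arguments.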
"triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/StrUtil.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include #define DEBUG_TYPE "tritonintelgpu-coalesce" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir::triton::gpu::intel { #define GEN_PASS_DEF_TRITONINTELGPUCOALESCE @@ -36,18 +38,18 @@ struct CoalescePass Operation *op, int numWarps, int threadsPerWarp, llvm::MapVector &layoutMap) { Value ptr = getMemAccessPtr(op); - LDBG("ptr: " << ptr); - LDBG("Considering op: " << *op); LLVM_DEBUG({ - DBGS() << "axis info of pointer: "; - axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "[" DEBUG_TYPE "]: Considering op: " << *op << "\n"; + llvm::dbgs().indent(2) << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs().indent(2)); llvm::dbgs() << "\n"; }); const auto &contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); SmallVector order = argSort(contiguity); - LDBG("order=[" << triton::join(order, ", ") << "]"); + LLVM_DEBUG(llvm::dbgs().indent(2) + << "order=[" << tt::join(order, ", ") << "]\n";); RankedTensorType refTensorType = ttgi::getRankedTensorType(ptr.getType()); auto matchesShape = [&refTensorType](const Value &val) { @@ -67,34 +69,39 @@ struct CoalescePass auto currOrder = argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); if (order == currOrder) { - LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + LLVM_DEBUG(llvm::dbgs().indent(2) + << "multi-root-slice: insert to memAccessesSameOrder " + << *use << "\n"); memAccessesSameOrder.insert(use); } } } auto shapePerCTA = ttg::getShapePerCTA(refTensorType); - LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); - int numElems = product(shapePerCTA); int numThreads = numWarps * threadsPerWarp; unsigned perThread = ttgi::getNumElementsPerThread(op, order, axisInfoAnalysis); - LDBG("perThread for op: " << perThread); + LLVM_DEBUG({ + llvm::dbgs().indent(2) + << "shapePerCTA=[" << tt::join(shapePerCTA, ", ") << "]\n"; + llvm::dbgs().indent(2) << "perThread for op: " << perThread << "\n"; + }); for (Operation *opSameOrder : memAccessesSameOrder) { if (opSameOrder == op) continue; unsigned currPerThread = ttgi::getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis); - LDBG("perThread for opSameOrder: " << currPerThread); + LLVM_DEBUG(llvm::dbgs().indent(2) + << "perThread for opSameOrder: " << currPerThread); perThread = std::max(perThread, currPerThread); } perThread = std::min(perThread, std::max(numElems / numThreads, 1)); - LDBG("perThread: " << perThread); + LLVM_DEBUG(llvm::dbgs().indent(2) << "perThread: " << perThread << "\n"); - if (!dyn_cast(op)) { + if (!dyn_cast(op)) { // For ops that can result in a global memory write, we should enforce // that each thread handles at most 128 bits, which is the widest // available vectorized store op; otherwise, the store will have "gaps" @@ -122,13 +129,23 @@ struct CoalescePass // Find the defining makeTensorPtrOp operation of the given value. 
   static std::optional<tt::MakeTensorPtrOp>
   findDefiningMakeTensorPtrOp(Value val) {
-    LDBG("Attempting to find `makeTensorPtrOp` defining: " << val);
+    LLVM_DEBUG({
+      llvm::dbgs() << "[" DEBUG_TYPE "]: \t"
+                   << "Attempting to find `makeTensorPtrOp` defining: " << val
+                   << "\n";
+    });
+
     if (auto arg = dyn_cast<BlockArgument>(val)) {
-      Operation *parentOp = val.getParentBlock()->getParentOp();
-      assert(isa<scf::ForOp>(parentOp) && "Expected a scf::ForOp");
-      auto loopArg =
-          cast<scf::ForOp>(parentOp).getInitArgs()[arg.getArgNumber() - 1];
+      Operation *parentOp = arg.getParentBlock()->getParentOp();
+
+      Value loopArg;
+      if (auto forOp = dyn_cast<scf::ForOp>(parentOp))
+        loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
+      else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp))
+        loopArg = whileOp.getInits()[arg.getArgNumber()];
+      else
+        llvm_unreachable("Unexpected parent operation");
+
       return findDefiningMakeTensorPtrOp(loopArg);
     }

@@ -142,6 +159,23 @@ struct CoalescePass
       Value val = forOp.getYieldedValues()[opRes.getResultNumber()];
       return findDefiningMakeTensorPtrOp(val);
     }
+    if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) {
+      Value val = whileOp.getYieldedValues()[opRes.getResultNumber()];
+      return findDefiningMakeTensorPtrOp(val);
+    }
+    if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
+      // Give up if the two possible definitions aren't the same.
+      Value trueVal = selectOp.getTrueValue(),
+            falseVal = selectOp.getFalseValue();
+      std::optional<tt::MakeTensorPtrOp> trueDef =
+          findDefiningMakeTensorPtrOp(trueVal);
+      std::optional<tt::MakeTensorPtrOp> falseDef =
+          findDefiningMakeTensorPtrOp(falseVal);
+      if (!trueDef || !falseDef || *trueDef != *falseDef)
+        return std::nullopt;
+      return trueDef;
+    }
+
     assert(false && "unhandled operation");
   }

@@ -154,6 +188,11 @@ struct CoalescePass
     if (isa<scf::YieldOp>(op))
       return false;

+    // Condition operations trigger updating the layout of the 'after' region
+    // in the containing while loop; don't skip them.
+    if (isa<scf::ConditionOp>(op))
+      return false;
+
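+    // Note: in scf.while, scf.condition terminates the "before" region; its
+    // forwarded operands (everything after the leading i1 condition) become
+    // both the "after" region arguments and the loop results.
+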
     // Skip operations that don't yield a result and contain no regions.
     if (op->getNumResults() == 0 && op->getNumRegions() == 0)
       return true;

@@ -177,6 +216,11 @@ struct CoalescePass
     assert(op && op->getNumResults() != 0 &&
            "Expecting operation yielding results");

+    LLVM_DEBUG({
+      llvm::dbgs() << "[" DEBUG_TYPE "]: " << "ChangeAndPropagateLayout for: ";
+      op->dumpPretty();
+    });
+
     rewriter.modifyOpInPlace(op, [&]() {
       for (Value res : op->getResults()) {
         if (!tt::isTensorPointerType(res.getType()))
           continue;

         tt::PointerType ptrType = cast<tt::PointerType>(res.getType());
         auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
         res.setType(tt::PointerType::get(getNewType(tensorType, layout),
                                          ptrType.getAddressSpace()));
       }
     });
-    LDBG("Coalesced op: " << *op);
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "[" DEBUG_TYPE "]: Coalesced op: ";
+      op->dumpPretty();
+    });

     propagateLayout(op, layout, rewriter);
   }

@@ -199,21 +247,47 @@ struct CoalescePass
     assert(root->getNumResults() != 0 &&
            "Expecting an operation yielding a result");

-    LDBG("root: " << *root);
+    auto mod = root->getParentOfType<ModuleOp>();
+
+    LLVM_DEBUG({
+      if (!root->getUsers().empty()) {
+        llvm::dbgs() << "[" DEBUG_TYPE "]: "
+                     << "Propagate layout to operations using: ";
+        root->dumpPretty();
+      }
+    });
+
     for (Operation *user : root->getUsers()) {
       if (filterUser(user))
         continue;

-      LDBG("root's user: " << *user << "\n");
+      LLVM_DEBUG({
+        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "user: ";
+        user->dumpPretty();
+      });
+
       if (auto forOp = dyn_cast<scf::ForOp>(user)) {
         propagateLayoutToArgsAndBody(forOp, root, layout, rewriter);
         continue;
       }
+      if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
+        propagateLayoutToArgsAndBody(whileOp, root, layout, rewriter);
+        continue;
+      }
+
       if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
         if (auto forOp = yieldOp->getParentOfType<scf::ForOp>())
           propagateLayoutToLoopResults(forOp, layout, rewriter);
+        if (auto whileOp = yieldOp->getParentOfType<scf::WhileOp>())
+          propagateLayoutToLoopResults(whileOp, layout, rewriter);
         continue;
       }

       changeAndPropagateLayout(user, layout, rewriter);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "[" DEBUG_TYPE "]: After propagating layout:\n";
+        mod->dumpPretty();
+      });
     }
   }

@@ -221,44 +295,93 @@ struct CoalescePass
   // Propagate the layout of the \p arg block argument to its users.
   void propagateLayout(BlockArgument arg, Attribute layout,
                        IRRewriter &rewriter) const {
-    LDBG("arg: " << arg);
     for (Operation *user : arg.getUsers()) {
       if (filterUser(user))
         continue;

-      LDBG("arg's user: " << *user << "\n");
+      LLVM_DEBUG({
+        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "arg's user: ";
+        user->dumpPretty();
+      });
+
       if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
         if (auto forOp = yieldOp->getParentOfType<scf::ForOp>())
           propagateLayoutToLoopResults(forOp, layout, rewriter);
+        if (auto whileOp = yieldOp->getParentOfType<scf::WhileOp>())
+          propagateLayoutToLoopResults(whileOp, layout, rewriter);
+        continue;
+      }
+
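+      // An scf.condition user forwards the argument into the "after" region
+      // of the enclosing scf.while, so the matching "after" block argument
+      // (and, transitively, its users) must adopt the new layout as well.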
+      if (auto condOp = dyn_cast<scf::ConditionOp>(user)) {
+        if (auto whileOp = condOp->getParentOfType<scf::WhileOp>()) {
+          // Propagate layout to "after" region arguments.
+          for (auto [condOperand, loopArg] :
+               llvm::zip(condOp->getOperands().drop_front(),
+                         whileOp.getAfterArguments())) {
+            if (condOperand != arg ||
+                !tt::isTensorPointerType(condOperand.getType()))
+              continue;
+
+            // Modify the layout of the loop argument...
+            tt::PointerType ptrType = cast<tt::PointerType>(loopArg.getType());
+            auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+            loopArg.setType(tt::PointerType::get(
+                getNewType(tensorType, layout), ptrType.getAddressSpace()));
+            LLVM_DEBUG({
+              llvm::dbgs() << "[" DEBUG_TYPE "]: " << "Propagated layout to: ";
+              loopArg.printAsOperand(llvm::dbgs(), {});
+              llvm::dbgs() << "\n";
+            });
+
+            // ... and then propagate it to the operations in the loop.
+            propagateLayout(loopArg, layout, rewriter);
+          }
+        }
         continue;
       }

       changeAndPropagateLayout(user, layout, rewriter);
     }
+
+    LLVM_DEBUG({
+      auto mod =
+          arg.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();
+      llvm::dbgs() << "[" DEBUG_TYPE "]: After propagating layout:\n";
+      mod->dumpPretty();
+    });
   }

-  // Propagate the layout of the \p root operation's result to the \p forOp loop
-  // init argument that uses it, and transitively to the operations in the loop
-  // body that use that argument.
-  void propagateLayoutToArgsAndBody(scf::ForOp forOp, Operation *root,
+  // Propagate the layout of the \p root operation's result to the \p loopOp
+  // loop init argument that uses it, and transitively to the operations in the
+  // loop body that use that argument.
+  template <typename OpType,
+            typename = std::enable_if_t<
+                llvm::is_one_of<OpType, scf::ForOp, scf::WhileOp>::value>>
+  void propagateLayoutToArgsAndBody(OpType loopOp, Operation *root,
                                     Attribute layout,
                                     IRRewriter &rewriter) const {
     assert(llvm::any_of(root->getUsers(),
-                        [&](Operation *user) { return user == forOp; }) &&
+                        [&](Operation *user) { return user == loopOp; }) &&
            "Expecting the loop to be a user of the root operation");

-    for (BlockArgument arg : forOp.getRegionIterArgs()) {
-      Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
+    for (BlockArgument arg : loopOp.getRegionIterArgs()) {
+      Value loopArg;
+      if constexpr (std::is_same<OpType, scf::ForOp>::value)
+        loopArg = loopOp.getInitArgs()[arg.getArgNumber() - 1];
+      if constexpr (std::is_same<OpType, scf::WhileOp>::value)
+        loopArg = loopOp.getInits()[arg.getArgNumber()];
+
       for (OpResult res : root->getResults()) {
         if (res != loopArg || !tt::isTensorPointerType(res.getType()))
           continue;
-
-        LDBG("loopArg: " << loopArg);
-
         // Modify the layout of the loop init argument...
         tt::PointerType ptrType = cast<tt::PointerType>(arg.getType());
         auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
         arg.setType(tt::PointerType::get(getNewType(tensorType, layout),
                                          ptrType.getAddressSpace()));
+        LLVM_DEBUG({
+          llvm::dbgs() << "[" DEBUG_TYPE "]: " << "Propagated layout to: ";
+          arg.printAsOperand(llvm::dbgs(), {});
+          llvm::dbgs() << "\n";
+        });

         // ... and then propagate it to the operations in the loop.
         propagateLayout(arg, layout, rewriter);
@@ -266,37 +389,46 @@ struct CoalescePass
     }
   }

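+  // Note: scf::ForOp and scf::WhileOp expose their init values and yield
+  // terminator through different APIs (getInitArgs() and
+  // getBody()->getTerminator() vs. getInits() and getYieldOp()), hence the
+  // `if constexpr` dispatch in these templated helpers.
+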
-  // Modify the given loop \p forOp and propagate the result of the enclosing
-  // loop.
-  void propagateLayoutToLoopResults(scf::ForOp forOp, Attribute layout,
+  // Modify the given loop \p loopOp and propagate its results to their users.
+  template <typename OpType,
+            typename = std::enable_if_t<
+                llvm::is_one_of<OpType, scf::ForOp, scf::WhileOp>::value>>
+  void propagateLayoutToLoopResults(OpType loopOp, Attribute layout,
                                     IRRewriter &rewriter) const {
-    Operation *yieldOp = forOp.getBody()->getTerminator();
-
-    rewriter.modifyOpInPlace(forOp, [&]() {
-      for (auto [opType, res] :
-           llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
-        if (opType == res.getType())
+    Operation *yieldOp = nullptr;
+    if constexpr (std::is_same<OpType, scf::ForOp>::value)
+      yieldOp = loopOp.getBody()->getTerminator();
+    if constexpr (std::is_same<OpType, scf::WhileOp>::value)
+      yieldOp = loopOp.getYieldOp();
+
+    rewriter.modifyOpInPlace(loopOp, [&]() {
+      for (auto [yieldOperandType, res] :
+           llvm::zip(yieldOp->getOperandTypes(), loopOp.getResults())) {
+        Type resType = res.getType();
+        if (yieldOperandType == resType)
           continue;

-        assert(tt::isTensorPointerType(res.getType()) &&
-               tt::isTensorPointerType(opType) && "Expecting blocked pointers");
+        assert(tt::isTensorPointerType(resType) &&
+               tt::isTensorPointerType(yieldOperandType) &&
+               "Expecting blocked pointers");
         assert(cast<RankedTensorType>(
-                   cast<tt::PointerType>(opType).getPointeeType())
+                   cast<tt::PointerType>(yieldOperandType).getPointeeType())
                        .getEncoding() == layout &&
                "Unexpected layout");

-        auto resType = cast<tt::PointerType>(res.getType());
+        auto ptrType = cast<tt::PointerType>(res.getType());
         RankedTensorType tensorType = ttgi::getRankedTensorType(resType);
         res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                         resType.getAddressSpace()));
+                                         ptrType.getAddressSpace()));
       }
     });

-    propagateLayout(forOp, layout, rewriter);
+    propagateLayout(loopOp, layout, rewriter);
   }

   void coalesceOp(Attribute encoding, Operation *op) {
-    LDBG("Coalescing op: " << *op);
+    LLVM_DEBUG({
+      llvm::dbgs() << "[" DEBUG_TYPE "]: " << "Coalescing op: " << *op << "\n";
+    });

     OpBuilder builder(op);
@@ -309,7 +441,7 @@ struct CoalescePass
     for (Value operand : op->getOperands()) {
       auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
       if (tensorType &&
-          !isa<ttg::SharedEncodingAttr>(tensorType.getEncoding())) {
+          !isa<ttg::SharedEncodingTrait>(tensorType.getEncoding())) {
         RankedTensorType newType = getNewType(tensorType, encoding);
         newArgs.push_back(builder.create<ttg::ConvertLayoutOp>(
             op->getLoc(), newType, operand));
@@ -318,7 +450,10 @@ struct CoalescePass
                "Expecting operand to have blocked pointer type");
         auto defOp = findDefiningMakeTensorPtrOp(operand);
         assert(defOp && "Expected a make_tensor_ptr operation");
-        LDBG("Found make_tensor_ptr definition: " << *defOp);
+        LLVM_DEBUG({
+          llvm::dbgs() << "[" DEBUG_TYPE "]: Found definition: " << *defOp
+                       << "\n";
+        });
         IRRewriter rewriter(builder);
         changeAndPropagateLayout(*defOp, encoding, rewriter);
         newArgs.push_back(operand);
@@ -348,10 +483,7 @@ struct CoalescePass
       op->getResult(i).replaceAllUsesWith(newResult);
     }

-    LDBG("Old op: " << *op);
-    LDBG("newOp: " << *newOp);
     op->erase();
-
     assert(succeeded(verify(newOp)) && "Operation verification failed");
   }

@@ -380,14 +512,15 @@ struct CoalescePass
     });

     LLVM_DEBUG({
-      DBGS() << "layoutMap:\n";
+      llvm::dbgs() << "[" DEBUG_TYPE "]: " << "layoutMap:\n";
       if (layoutMap.empty())
-        DBGS() << "\t";
+        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "\t";
       for (auto [op, encoding] : layoutMap) {
-        DBGS() << "\top: " << *op << "\n";
-        DBGS() << "\tencoding: " << encoding << "\n";
+        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "\top: " << *op << "\n";
+        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "\tencoding: " << encoding
+                     << "\n";
       }
-      llvm::errs() << "\n";
+      llvm::dbgs() << "\n";
     });

     // For each memory op that has a layout L1: