diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 2e23d35f52..f1a9838aa2 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -19,8 +19,6 @@ using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 using namespace mlir::triton::gpu::intel;
 
-#define S(v) StringAttr::get(ctx, (v))
-
 namespace {
 
 Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
@@ -1288,8 +1286,8 @@ struct LoadOpConversion
     unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];
 
     MLIRContext *ctx = rewriter.getContext();
-    const StringAttr dimOuterStr = S("dim" + std::to_string(dimOuter));
-    const StringAttr dimInnerStr = S("dim" + std::to_string(dimInner));
+    const StringAttr dimOuterStr = str_attr("dim" + std::to_string(dimOuter));
+    const StringAttr dimInnerStr = str_attr("dim" + std::to_string(dimInner));
     LLVM_DEBUG({
       llvm::dbgs() << "dimOuterStr: " << dimOuterStr << "\n";
       llvm::dbgs() << "dimInnerStr: " << dimInnerStr << "\n";
@@ -1308,9 +1306,9 @@ struct LoadOpConversion
     // the DPAS instruction across all threads/work-items in a sub-group. The
     // layout will later be expanded to cover multiple DPAS invocations
     // (iteration) and multiple loads (load).
-    StringAttr kOffset = S("offset");
-    StringAttr kIteration = S("iteration");
-    StringAttr kLoad = S("load");
+    StringAttr kOffset = str_attr("offset");
+    StringAttr kIteration = str_attr("iteration");
+    StringAttr kLoad = str_attr("load");
 
     auto createTileLayout = [&](const SmallVectorImpl<unsigned> &threadOrder,
                                 SmallVector<unsigned> tileShape) {
@@ -1328,7 +1326,7 @@ struct LoadOpConversion
 
       for (int i = 0; i < tileShape.size(); i++) {
         int dim = threadOrder[i];
-        StringAttr kOffset = S("offset" + std::to_string(dim));
+        StringAttr kOffset = str_attr("offset" + std::to_string(dim));
 
         kOffsetDims.push_back(kOffset);
 
@@ -1355,7 +1353,7 @@ struct LoadOpConversion
       llvm::dbgs() << "Block load tile layout: " << tileLayout << "\n";
       for (size_t i = 0; i < tileLayout.getOutDimSize(dimOuterStr) *
                                  tileLayout.getOutDimSize(dimInnerStr);
-           i += tileLayout.getOutDimSize(S("dim1"))) {
+           i += tileLayout.getOutDimSize(str_attr("dim1"))) {
         auto tensorVals = tileLayout.apply({{kOffset, i}});
         assert(tensorVals.size() == 2);
         llvm::dbgs() << i << " : " << tensorVals[0].second << ", "
@@ -1499,8 +1497,8 @@ struct LoadOpConversion
          llvm::zip(tileLayout.getOutDimNames(), tileLayout.getOutDimSizes())) {
       outDims.push_back(std::make_pair(name, size));
     }
-    assert(outDims[0].first == S("dim0"));
-    assert(outDims[1].first == S("dim1"));
+    assert(outDims[0].first == str_attr("dim0"));
+    assert(outDims[1].first == str_attr("dim1"));
 
     for (size_t i = 0;
          i < llvm::Log2_32(numRepInner / numOperandsInnerDimPerLoad); i++) {
@@ -1729,35 +1727,27 @@ struct LoadOpConversion
           // Save the decomposed vals to the map;
           switch (opIdx) {
          case DpasEncodingAttr::OpIdx::OperandA: {
+            const auto loadX =
+                outer * numLoadPerOutRepCluster * repOuterStride +
+                rep * packedRowNum + row;
+            const auto loadY = k + vblk * packedColNumPerVBlock + col;
             LLVM_DEBUG({
-              llvm::dbgs() << "load vals index: "
-                           << std::to_string(outer * packedRowNum *
-                                                 numLoadPerOutRepCluster +
-                                             rep * packedRowNum + row)
-                           << ", "
-                           << std::to_string(
-                                  k + vblk * packedColNumPerVBlock + col)
-                           << "\n";
+              llvm::dbgs() << "load vals index: " << loadX << ", "
+                           << loadY << "\n";
             });
-            loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
-                          rep * packedRowNum + row,
-                      k + vblk * packedColNumPerVBlock + col}] =
+            loadVals[{loadX, loadY}] =
                 b.bitcast(loadVal, unpackedDPASOperandType);
           } break;
          case DpasEncodingAttr::OpIdx::OperandB: {
+            const auto loadX = outer * repOuterStride +
+                               rep * packedColNum +
+                               vblk * packedColNumPerVBlock + col;
+            const auto loadY = k + row;
             LLVM_DEBUG({
-              llvm::dbgs()
-                  << "load vals index: "
-                  << std::to_string(outer * packedColNum *
-                                        numLoadPerOutRepCluster +
-                                    rep * packedColNum +
-                                    vblk * packedColNumPerVBlock + col)
-                  << ", " << std::to_string(k + row) << "\n";
+              llvm::dbgs() << "load vals index: " << loadX << ", "
+                           << loadY << "\n";
             });
-            loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
-                          rep * packedColNum +
-                          vblk * packedColNumPerVBlock + col,
-                      k + row}] =
+            loadVals[{loadX, loadY}] =
                 b.bitcast(loadVal, unpackedDPASOperandType);
           } break;
          case DpasEncodingAttr::OpIdx::OperandC: {
@@ -1775,16 +1765,17 @@ struct LoadOpConversion
     for (int outer = 0; outer < numRepOuter; ++outer) {
       for (int k = 0; k < numRepInner; ++k) {
         for (int rep = 0; rep < repCluster[unsigned(opIdx)]; ++rep) {
-          if (loadVals.find({outer * repCluster[unsigned(opIdx)] + rep, k}) ==
-              loadVals.end()) {
+          const auto loadValX = (outer * repOuterStride) + rep;
+          const auto loadValY = k;
+
+          if (loadVals.find({loadValX, loadValY}) == loadVals.end()) {
            // generate a nice error message before the throw below aborts our
            // pipeline
            llvm::errs() << "Failed to find key at "
                         << outer * repCluster[unsigned(opIdx)] + rep << ", "
                         << k << "\n";
          }
-          Value loadVal =
-              loadVals.at({outer * repCluster[unsigned(opIdx)] + rep, k});
+          Value loadVal = loadVals.at({loadValX, loadValY});
          VectorType loadTy = cast<VectorType>(loadVal.getType());
          for (int i = 0; i < loadTy.getNumElements(); ++i) {
            auto val = b.extract_element(loadVal, b.i32_val(i));