Skip to content

Use strides when computing load shuffle vector indices in the loadVals map #4065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 28 additions & 37 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ using namespace mlir::triton;
using namespace mlir::triton::gpu;
using namespace mlir::triton::gpu::intel;

#define S(v) StringAttr::get(ctx, (v))

namespace {

Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
Expand Down Expand Up @@ -1288,8 +1286,8 @@ struct LoadOpConversion
unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];

MLIRContext *ctx = rewriter.getContext();
const StringAttr dimOuterStr = S("dim" + std::to_string(dimOuter));
const StringAttr dimInnerStr = S("dim" + std::to_string(dimInner));
const StringAttr dimOuterStr = str_attr("dim" + std::to_string(dimOuter));
const StringAttr dimInnerStr = str_attr("dim" + std::to_string(dimInner));
LLVM_DEBUG({
llvm::dbgs() << "dimOuterStr: " << dimOuterStr << "\n";
llvm::dbgs() << "dimInnerStr: " << dimInnerStr << "\n";
Expand All @@ -1308,9 +1306,9 @@ struct LoadOpConversion
// the DPAS instruction across all threads/work-items in a sub-group. The
// layout will later be expanded to cover multiple DPAS invocations
// (iteration) and multiple loads (load).
StringAttr kOffset = S("offset");
StringAttr kIteration = S("iteration");
StringAttr kLoad = S("load");
StringAttr kOffset = str_attr("offset");
StringAttr kIteration = str_attr("iteration");
StringAttr kLoad = str_attr("load");

auto createTileLayout = [&](const SmallVectorImpl<unsigned> &threadOrder,
SmallVector<unsigned> tileShape) {
Expand All @@ -1328,7 +1326,7 @@ struct LoadOpConversion

for (int i = 0; i < tileShape.size(); i++) {
int dim = threadOrder[i];
StringAttr kOffset = S("offset" + std::to_string(dim));
StringAttr kOffset = str_attr("offset" + std::to_string(dim));

kOffsetDims.push_back(kOffset);

Expand All @@ -1355,7 +1353,7 @@ struct LoadOpConversion
llvm::dbgs() << "Block load tile layout: " << tileLayout << "\n";
for (size_t i = 0; i < tileLayout.getOutDimSize(dimOuterStr) *
tileLayout.getOutDimSize(dimInnerStr);
i += tileLayout.getOutDimSize(S("dim1"))) {
i += tileLayout.getOutDimSize(str_attr("dim1"))) {
auto tensorVals = tileLayout.apply({{kOffset, i}});
assert(tensorVals.size() == 2);
llvm::dbgs() << i << " : " << tensorVals[0].second << ", "
Expand Down Expand Up @@ -1499,8 +1497,8 @@ struct LoadOpConversion
llvm::zip(tileLayout.getOutDimNames(), tileLayout.getOutDimSizes())) {
outDims.push_back(std::make_pair(name, size));
}
assert(outDims[0].first == S("dim0"));
assert(outDims[1].first == S("dim1"));
assert(outDims[0].first == str_attr("dim0"));
assert(outDims[1].first == str_attr("dim1"));

for (size_t i = 0;
i < llvm::Log2_32(numRepInner / numOperandsInnerDimPerLoad); i++) {
Expand Down Expand Up @@ -1729,35 +1727,27 @@ struct LoadOpConversion
// Save the decomposed vals to the map.
switch (opIdx) {
case DpasEncodingAttr::OpIdx::OperandA: {
const auto loadX =
outer * numLoadPerOutRepCluster * repOuterStride +
rep * packedRowNum + row;
const auto loadY = k + vblk * packedColNumPerVBlock + col;
LLVM_DEBUG({
llvm::dbgs() << "load vals index: "
<< std::to_string(outer * packedRowNum *
numLoadPerOutRepCluster +
rep * packedRowNum + row)
<< ", "
<< std::to_string(
k + vblk * packedColNumPerVBlock + col)
<< "\n";
llvm::dbgs() << "load vals index: " << loadX << ", "
<< loadY << "\n";
});
loadVals[{outer * packedRowNum * numLoadPerOutRepCluster +
rep * packedRowNum + row,
k + vblk * packedColNumPerVBlock + col}] =
loadVals[{loadX, loadY}] =
b.bitcast(loadVal, unpackedDPASOperandType);
} break;
case DpasEncodingAttr::OpIdx::OperandB: {
const auto loadX = outer * repOuterStride +
rep * packedColNum +
vblk * packedColNumPerVBlock + col;
const auto loadY = k + row;
LLVM_DEBUG({
llvm::dbgs()
<< "load vals index: "
<< std::to_string(outer * packedColNum *
numLoadPerOutRepCluster +
rep * packedColNum +
vblk * packedColNumPerVBlock + col)
<< ", " << std::to_string(k + row) << "\n";
llvm::dbgs() << "load vals index: " << loadX << ", "
<< loadY << "\n";
});
loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
rep * packedColNum +
vblk * packedColNumPerVBlock + col,
k + row}] =
loadVals[{loadX, loadY}] =
b.bitcast(loadVal, unpackedDPASOperandType);
} break;
case DpasEncodingAttr::OpIdx::OperandC: {
Expand All @@ -1775,16 +1765,17 @@ struct LoadOpConversion
for (int outer = 0; outer < numRepOuter; ++outer) {
for (int k = 0; k < numRepInner; ++k) {
for (int rep = 0; rep < repCluster[unsigned(opIdx)]; ++rep) {
if (loadVals.find({outer * repCluster[unsigned(opIdx)] + rep, k}) ==
loadVals.end()) {
const auto loadValX = (outer * repOuterStride) + rep;
const auto loadValY = k;

if (loadVals.find({loadValX, loadValY}) == loadVals.end()) {
// generate a nice error message before the throw below aborts our
// pipeline
llvm::errs() << "Failed to find key at "
<< outer * repCluster[unsigned(opIdx)] + rep << ", "
<< k << "\n";
}
Value loadVal =
loadVals.at({outer * repCluster[unsigned(opIdx)] + rep, k});
Value loadVal = loadVals.at({loadValX, loadValY});
VectorType loadTy = cast<VectorType>(loadVal.getType());
for (int i = 0; i < loadTy.getNumElements(); ++i) {
auto val = b.extract_element(loadVal, b.i32_val(i));
Expand Down