diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h index af8edbbd9b..4078316aca 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h @@ -18,6 +18,11 @@ LinearLayout DPAStoLinearLayout(ArrayRef shape, Attribute layout, LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout, ArrayRef shape); +LinearLayout +subgroup2DBlockToLinearLayout(ArrayRef shape, + intel::Subgroup2DBlockEncodingAttr layout, + unsigned kWidth); + } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONINTELGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index 9b0c650448..53cad94064 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -280,4 +280,47 @@ def WarpEncodingAttr : DistributedEncoding<"WarpEncoding", "intel_warp_encoding" let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// Intel Subgroup2DBlock Encoding +//===----------------------------------------------------------------------===// + +def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", "subgroup_2d_block_encoding", [MmaEncodingTrait], TritonIntelGPU_Dialect> { + let mnemonic = "subgroup_2d_block"; + + let description = [{ + An encoding for tensors produced via Intel Subgroup 2D Block IO operations. + + The subgroup 2D block IO operations read or write two-dimensional blocks of data from a two-dimensional region of memory. 
The Subgroup 2D Block Encoding layout is parameterized by the block width, block height, and block count for the individual load instructions and the distribution and replication of loads across warps. + + The SPV_INTEL_2d_block_io extension documentation provides more information on the subgroup 2D block IO operations and parameters: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html + + For the layout, the following parameters are required: + - `instrShape` : contains the (height, width) block parameters for the block io operation + - `numBlocks` : the block count parameter allows a single load to load multiple blocks in row-major order (useful for increasing cache line utilization) + - `threadsPerWarp` : currently a scalar, this parameter allows us to support different subgroup / warp configurations. Because the 2d block io operation is a subgroup operation, the size of the subgroup is important in determining the ordering of the loaded tensor. + - `warpsPerCTA` : the number of warps per block / subgroups per workgroup and their distribution + - `order` : The order within the block, used to determine along which dimension to broadcast. + - `kWidth` : Currently unused, but keeping because we will likely need it for layout conversions. + - `CTALayout` : Describes how blocks are distributed among work-groups/thread blocks. 
+ }]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape, + "unsigned":$numBlocks, + ArrayRefParameter<"unsigned">:$order, + "unsigned":$kWidth, + "unsigned":$threadsPerWarp + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getRepOrderForOperand(int opIdx) const; + }]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + #endif diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index df30996bcf..1ac47cc81b 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -495,6 +495,173 @@ void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const { << "}>"; } +//===----------------------------------------------------------------------===// +// Subgroup2DBlockEncodingAttr +//===----------------------------------------------------------------------===// + +namespace { +std::optional getCTALayoutOrError( + AsmParser &parser, std::optional> CTAsPerCGA, + std::optional> CTASplitNum, + std::optional> CTAOrder, unsigned rank) { + if (CTAsPerCGA && CTASplitNum && CTAOrder) { + return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, + *CTAOrder); + } + if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { + return CTALayoutAttr::getDefault(parser.getContext(), rank); + } + parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " + "must all be present or all be absent"); + return std::nullopt; +} + +// Print the CTALayout if it's not equal to the default. 
+void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
+                         CTALayoutAttr layout, unsigned rank) {
+  if (layout != CTALayoutAttr::getDefault(context, rank)) {
+    printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]"
+            << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]"
+            << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]";
+  }
+}
+
+} // namespace
+
+LogicalResult Subgroup2DBlockEncodingAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<unsigned> warpsPerCTA, CTALayoutAttr CTALayout,
+    ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order,
+    unsigned kWidth, unsigned threadsPerWarp) {
+  if (instrShape.size() != 2) {
+    return emitError() << "instrShape must be rank 2 but was: "
+                       << instrShape.size();
+  }
+  if (order.size() != 2) {
+    return emitError() << "order must be rank 2 but was " << order.size();
+  }
+  if (warpsPerCTA.size() != 2) {
+    return emitError() << "warpsPerCTA must be rank 2 but was "
+                       << warpsPerCTA.size();
+  }
+  if (!(kWidth == 1 || kWidth == 2 || kWidth == 4)) {
+    return emitError() << "kWidth must be 1, 2 or 4, but was: " << kWidth;
+  }
+  if (threadsPerWarp != 16) { // fix: '!threadsPerWarp == 16' was always false
+    return emitError() << "threadsPerWarp must be 16, but was: "
+                       << threadsPerWarp;
+  }
+  return success();
+}
+
+Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) {
+  if (parser.parseLess().failed())
+    return {};
+  DictionaryAttr dict;
+  if (parser.parseAttribute(dict).failed())
+    return {};
+  if (parser.parseGreater().failed())
+    return {};
+
+  SmallVector<unsigned> warpsPerCTA;
+  std::optional<SmallVector<unsigned>> CTAsPerCGA;
+  std::optional<SmallVector<unsigned>> CTASplitNum;
+  std::optional<SmallVector<unsigned>> CTAOrder;
+  SmallVector<unsigned> instrShape;
+  unsigned numBlocks = 0;
+  SmallVector<unsigned> order;
+  unsigned kWidth = 0;
+  unsigned threadsPerWarp = 0;
+
+  for (const NamedAttribute &attr : dict) {
+    if (attr.getName() == "warpsPerCTA") {
+      if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
+        return {};
+    }
+    if (attr.getName() == "CTAsPerCGA") {
+      if
(parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) + return {}; + } + if (attr.getName() == "numBlocks") { + if (parseUInt(parser, attr, numBlocks, "numBlocks").failed()) + return {}; + } + if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } + if (attr.getName() == "kWidth") { + if (parseUInt(parser, attr, kWidth, "kWidth").failed()) + return {}; + } + if (attr.getName() == "threadsPerWarp") { + if (parseUInt(parser, attr, threadsPerWarp, "threadsPerWarp").failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), warpsPerCTA, *CTALayout, instrShape, numBlocks, + order, kWidth, threadsPerWarp); +} + +SmallVector Subgroup2DBlockEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector Subgroup2DBlockEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} + +SmallVector Subgroup2DBlockEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} + +SmallVector Subgroup2DBlockEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +SmallVector +Subgroup2DBlockEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +void 
Subgroup2DBlockEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" << "warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank()); + + printer << ", instrShape = [" << getInstrShape() + << "], numBlocks=" << getNumBlocks() << ", order=[" << getOrder() + << "], kWidth=" << getKWidth() + << ", threadsPerWarp=" << getThreadsPerWarp() << "}>"; +} + +LinearLayout +Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef shape) const { + return subgroup2DBlockToLinearLayout(shape, *this, getKWidth()); +} + //===----------------------------------------------------------------------===// // Dialect Interface //===----------------------------------------------------------------------===// diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp index 0cb8dd540b..64cc423629 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp @@ -523,4 +523,119 @@ LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout, return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx()); } +namespace { + +static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx, + ArrayRef shape, + ArrayRef order, + unsigned broadcastDim, + StringAttr inDimName) { + int rank = shape.size(); + auto dimNames = standardOutDimNames(ctx, rank); + LinearLayout layout = LinearLayout::empty(); + + for (auto d : order) { + if (d == broadcastDim) { + layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]); + } else { + layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]); + } + } + return layout; +} + +using basisT = std::vector>; + +// Creates a row major tile layout with register/lane input dimensions according +// to the provided height, width, and 
threadsPerWarp. The relationship between +// the width and threadsPerWarp determines the packing of rows across lanes: +// - if width == threadsPerWarp: +// block row elements are mapped to registers in row major order, i.e. one +// column per lane +// - if width < threadsPerWarp: +// multiple rows are mapped to the first register to fill the warp, i.e. +// width * rowsPerWarp = threadsPerWarp +// - if width > threadsPerWarp: +// multiple elements of each row are assigned to registers such that +// packedElementsPerLane row values exist in consecutive registers for each +// lane +std::pair +createRegisterLaneBases(const int height, const int width, + const unsigned threadsPerWarp) { + const int packedElementsPerLane = + mlir::ceil(width, static_cast(threadsPerWarp)); + + basisT laneBases; + for (int i = packedElementsPerLane; i < width; i = i << 1) { + laneBases.push_back({0, i}); + } + + const int rowsPerWarp = + mlir::ceil(threadsPerWarp, 1 << laneBases.size()); + // Place subsequent rows into adjacent lanes until all lanes have been filled + for (int i = 1; i < rowsPerWarp; i = i << 1) { + laneBases.push_back({i, 0}); + } + + basisT regBases; + + // Add packed row-wise elements (width > threadsPerWarp) before adding columns + for (int i = 1; i < packedElementsPerLane; i = i << 1) { + regBases.push_back({0, i}); + } + + for (int i = 1; i < height / rowsPerWarp; i = i << 1) { + regBases.push_back({i * rowsPerWarp, 0}); + } + + return std::make_pair(regBases, laneBases); +} + +} // namespace + +LinearLayout +subgroup2DBlockToLinearLayout(ArrayRef blockShape, + intel::Subgroup2DBlockEncodingAttr layout, + unsigned kWidth) { + auto ctx = layout.getContext(); + int rank = blockShape.size(); + assert(rank == layout.getRank() && "unexpected block shape rank, layout rank " + "and block shape rank must be equal"); + auto dimNames = standardOutDimNames(ctx, rank); + auto loadTileSize = layout.getInstrShape(); + StringAttr kRegister = S("register"); + StringAttr kLane = 
S("lane"); + StringAttr kWarp = S("warp"); + + // Start by creating register/lane bases corresponding to the desired load + // tile size + auto [regBases, laneBases] = createRegisterLaneBases( + loadTileSize[0], loadTileSize[1], layout.getThreadsPerWarp()); + + LinearLayout::BasesT bases; + bases[kRegister] = regBases; + bases[kLane] = laneBases; + auto ctaLayout = LinearLayout(bases, dimNames); + + assert(ctaLayout.getInDimSize(kLane) <= layout.getThreadsPerWarp() && + "number of lanes should not exceed threads per warp"); + + // Increasing the block count always increases the inner dimension for the + // register/lane layout regardless of order + ctaLayout *= + LinearLayout::identity1D(layout.getNumBlocks(), kRegister, dimNames[1]); + + // Broadcast the layout according to warpsPerCTA, then combine with the + // overall CTALayout and reshape according to the provided blockShape. + auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ true); + auto order = layout.getOrder(); + assert(order.size() == 2 && "only rank 2 order supported"); + unsigned inner = order[0]; + + ctaLayout *= broadcastedDotOperandLayout(ctx, layout.getWarpsPerCTA(), + warpOrder, inner, kWarp) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + return combineCtaCgaWithShape(ctaLayout, layout.getCTALayout(), blockShape); +} + } // namespace mlir::triton::gpu diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 8fe6672188..e80cbadcd2 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1965,6 +1965,20 @@ struct LoadOpConversion } Value elemSizeInBytes = b.i32_val(originalElemBits / 8); + LLVM_DEBUG({ + const unsigned numLoads = numRepOuter * numLoadPerOutRepCluster * + numRepInner / numOperandsInnerDimPerLoad; + llvm::dbgs() << "Preparing to dispatch " << numLoads << " loads\n"; + 
llvm::dbgs() << "Outer loads: " << numRepOuter * numLoadPerOutRepCluster + << " (" << numLoadPerOutRepCluster + << " per out rep cluster)\n"; + llvm::dbgs() << "Inner loads: " + << numRepInner / numOperandsInnerDimPerLoad << "\n"; + llvm::dbgs() << "Load dimension: " << tileHeight << ", " + << tileWidth * vBlocks << " (" << elemSizeInBits + << " bits)\n"; + }); + ValueTable loadVals; for (int outer = 0; outer < numRepOuter; ++outer) { for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) { diff --git a/third_party/intel/unittest/Dialect/TritonIntelGPU/CMakeLists.txt b/third_party/intel/unittest/Dialect/TritonIntelGPU/CMakeLists.txt index 38cc176113..1a2b312f1f 100644 --- a/third_party/intel/unittest/Dialect/TritonIntelGPU/CMakeLists.txt +++ b/third_party/intel/unittest/Dialect/TritonIntelGPU/CMakeLists.txt @@ -8,3 +8,13 @@ add_triton_ut( TritonIntelGPUTransforms TritonNvidiaGPUTransforms ) +add_triton_ut( + NAME LinearLayoutConversionsIntel + SRCS LinearLayoutConversionsTest.cpp + LIBS + TritonGPUIR + TritonGPUTransforms + TritonIntelAnalysis + TritonIntelGPUTransforms + TritonNvidiaGPUTransforms +) diff --git a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp new file mode 100644 index 0000000000..caf6fe2a22 --- /dev/null +++ b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp @@ -0,0 +1,170 @@ +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h" + +#include "mlir/IR/MLIRContext.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Signals.h" +#include +#include + +namespace mlir::triton::gpu::intel { + +namespace { + +class 
LinearLayoutConversionsTest : public ::testing::Test { +public: + void SetUp() { ctx.loadDialect(); } + + // Create a Subgroup2DBlockEncoding layout based on a DPAS layout + Subgroup2DBlockEncodingAttr + sdb(ArrayRef instrShape, unsigned numBlocks, unsigned kWidth, + ArrayRef warpsPerCTA, ArrayRef repCluster, + ArrayRef blockShape, unsigned opsPerChannel, unsigned opIdx) { + auto dpasLayout = DpasEncodingAttr::get( + &ctx, /*repeatCount=*/8, /*systolicDepth=*/8, /*executionSize=*/16, + opsPerChannel, warpsPerCTA, repCluster, + /*threadsPerWarp=*/16); + + // TODO: could put the getOrderForDotOperand in the builder? + auto layout = Subgroup2DBlockEncodingAttr::get( + &ctx, warpsPerCTA, + CTALayoutAttr::get( + &ctx, dpasLayout.getCTAsPerCGA(), // TODO: add to DpasLayout? + dpasLayout.getCTASplitNum(), dpasLayout.getCTAOrder()), + instrShape, numBlocks, + getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth, + dpasLayout.getThreadsPerWarp()); + return layout; + } + + StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LinearLayoutConversionsTest, FP32_32x8x2_M256_N128_K32_A) { + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*blockShape*/ {256, 32}, + sdb(/*instrShape*/ {32, 8}, /*numBlocks*/ 2, /*kWidth*/ 4, + /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1}, + /*blockShape*/ {256, 32}, /*opsPerChannel*/ 1, /*opIdx*/ 0), + /*kWidth*/ 4), + LinearLayout( + {{S("register"), {{2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}, {0, 16}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {1, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}, {64, 0}, {128, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, FP32_32x16x1_M256_N128_K32_B) { + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*blockShape*/ {32, 128}, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 4, + /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1}, + /*blockShape*/ {32, 128}, /*opsPerChannel*/ 1, /*opIdx*/ 1), 
+ /*kWidth*/ 4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, FP16_32x32x1_M256_N32_K32_A) { + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*blockShape*/ {256, 32}, + sdb(/*instrShape*/ {32, 32}, /*numBlocks*/ 1, /*kWidth*/ 2, + /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, + /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0), + /*kWidth*/ 2), + LinearLayout( + {{S("register"), {{0, 1}, {1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}, {64, 0}, {128, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_A) { + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*blockShape*/ {256, 32}, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, + /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0), + /*kWidth*/ 2), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 16}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}, {64, 0}, {128, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_B) { + EXPECT_EQ(subgroup2DBlockToLinearLayout( + /*shape*/ {32, 256}, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, + /*blockShape*/ {32, 256}, /*opsPerChannel*/ 2, + /*opIdx*/ 1), + /*kWidth*/ 2), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 128}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + 
{S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, I8_16x32x1_M64_N128_K32_A) { + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*shape*/ {64, 32}, + sdb(/*instrShape*/ {16, 32}, /*numBlocks*/ 1, /*kWidth*/ 1, + /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1}, + /*blockShape*/ {64, 32}, /*opsPerChannel*/ 4, + /*opIdx*/ 0), + /*kWidth*/ 1), + LinearLayout({{S("register"), {{0, 1}, {1, 0}, {2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}, {16, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, I8_32x32x1_M64_N128_K32_B) { + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*shape*/ {32, 128}, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 1, + /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1}, + /*blockShape*/ {32, 128}, /*opsPerChannel*/ 4, + /*opIdx*/ 1), + /*kWidth*/ 1), + LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 64}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +} // anonymous namespace +} // namespace mlir::triton::gpu::intel + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}