-
Notifications
You must be signed in to change notification settings - Fork 62
Introduce Subgroup 2D Block Encoding #4193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5a90853
7bbdb01
283e7b2
b381145
e8be3bf
eb1469a
efdbe46
a7c2b5c
697024a
9bb4104
bfef386
55e3e03
2555e37
00957ae
1a20cdc
e69294b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,4 +280,47 @@ def WarpEncodingAttr : DistributedEncoding<"WarpEncoding", "intel_warp_encoding" | |
let hasCustomAssemblyFormat = 1; | ||
} | ||
|
||
//===----------------------------------------------------------------------===//
// Intel Subgroup2DBlock Encoding
//===----------------------------------------------------------------------===//

def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", "subgroup_2d_block_encoding", [MmaEncodingTrait], TritonIntelGPU_Dialect> {
  let mnemonic = "subgroup_2d_block";

  let description = [{
    An encoding for tensors produced via Intel Subgroup 2D Block IO operations.

    The subgroup 2D block IO operations read or write two-dimensional blocks of
    data from a two-dimensional region of memory. The Subgroup 2D Block
    Encoding layout is parameterized by the block width, block height, and
    block count for the individual load instructions and the distribution and
    replication of loads across warps.

    The SPV_INTEL_2d_block_io extension documentation provides more information
    on the subgroup 2D block IO operations and parameters:
    https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html

    For the layout, the following parameters are required:
    - `instrShape` : contains the (height, width) block parameters for the
      block io operation. The tile parameters always refer to the row-major
      ordering expected by the subgroup 2d block io instructions.
    - `numBlocks` : the block count parameter allows a single load to load
      multiple blocks in row-major order (useful for increasing cache line
      utilization)
    - `threadsPerWarp` : currently a scalar, this parameter allows us to
      support different subgroup / warp configurations. Because the 2d block
      io operation is a subgroup operation, the size of the subgroup is
      important in determining the ordering of the loaded tensor.
    - `warpsPerCTA` : the number of warps per block / subgroups per workgroup
      and their distribution
    - `order` : The order within the block, used to determine along which
      dimension to broadcast.
    - `kWidth` : Currently unused, but keeping because we will likely need it
      for layout conversions.
    - `CTALayout` : Describes how blocks are distributed among
      work-groups/thread blocks.
  }];

  let parameters = (
    ins
    ArrayRefParameter<"unsigned">:$warpsPerCTA,
    "CTALayoutAttr":$CTALayout,
    ArrayRefParameter<"unsigned">:$instrShape,
    "unsigned":$numBlocks,
    ArrayRefParameter<"unsigned">:$order,
    "unsigned":$kWidth,
    "unsigned":$threadsPerWarp
  );

  let extraClassDeclaration = extraDistributedDeclaration # [{
    SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
  }];

  let hasCustomAssemblyFormat = 1;
  let genVerifyDecl = 1;
}
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -523,4 +523,119 @@ LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout, | |
return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx()); | ||
} | ||
|
||
namespace { | ||
|
||
// Builds a layout over `shape` whose bases are all-zero (i.e. replicated /
// broadcast) along `broadcastDim` and an identity mapping along every other
// dimension, visiting dimensions in the given `order`. `inDimName` names the
// single input dimension (e.g. "warp") being distributed.
static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
                                                ArrayRef<unsigned> shape,
                                                ArrayRef<unsigned> order,
                                                unsigned broadcastDim,
                                                StringAttr inDimName) {
  const int rank = shape.size();
  auto dimNames = standardOutDimNames(ctx, rank);

  LinearLayout layout = LinearLayout::empty();
  for (unsigned dim : order) {
    // zeros1D replicates along the broadcast dimension; identity1D keeps a
    // one-to-one mapping elsewhere.
    layout *= (dim == broadcastDim)
                  ? LinearLayout::zeros1D(shape[dim], inDimName, dimNames[dim])
                  : LinearLayout::identity1D(shape[dim], inDimName,
                                             dimNames[dim]);
  }
  return layout;
}
|
||
using basisT = std::vector<std::vector<int32_t>>; | ||
|
||
// Creates a row major tile layout with register/lane input dimensions according | ||
// to the provided height, width, and threadsPerWarp. The relationship between | ||
// the width and threadsPerWarp determines the packing of rows across lanes: | ||
// - if width == threadsPerWarp: | ||
// block row elements are mapped to registers in row major order, i.e. one | ||
// column per lane | ||
// - if width < threadsPerWarp: | ||
// multiple rows are mapped to the first register to fill the warp, i.e. | ||
// width * rowsPerWarp = threadsPerWarp | ||
// - if width > threadsPerWarp: | ||
// multiple elements of each row are assigned to registers such that | ||
// packedElementsPerLane row values exist in consecutive registers for each | ||
// lane | ||
std::pair<basisT, basisT> | ||
createRegisterLaneBases(const int height, const int width, | ||
const unsigned threadsPerWarp) { | ||
const int packedElementsPerLane = | ||
mlir::ceil<int>(width, static_cast<int>(threadsPerWarp)); | ||
|
||
basisT laneBases; | ||
for (int i = packedElementsPerLane; i < width; i = i << 1) { | ||
laneBases.push_back({0, i}); | ||
} | ||
|
||
const int rowsPerWarp = | ||
mlir::ceil<int>(threadsPerWarp, 1 << laneBases.size()); | ||
// Place subsequent rows into adjacent lanes until all lanes have been filled | ||
for (int i = 1; i < rowsPerWarp; i = i << 1) { | ||
laneBases.push_back({i, 0}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the idea is that this describes the columns, and if the mapping is one-to-one ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure I understand the question - if |
||
} | ||
|
||
basisT regBases; | ||
|
||
// Add packed row-wise elements (width > threadsPerWarp) before adding columns | ||
for (int i = 1; i < packedElementsPerLane; i = i << 1) { | ||
regBases.push_back({0, i}); | ||
} | ||
|
||
for (int i = 1; i < height / rowsPerWarp; i = i << 1) { | ||
regBases.push_back({i * rowsPerWarp, 0}); | ||
} | ||
|
||
return std::make_pair(regBases, laneBases); | ||
} | ||
|
||
} // namespace | ||
|
||
// Converts a Subgroup2DBlockEncodingAttr into a LinearLayout over the given
// tensor `blockShape`, composing the per-instruction register/lane tile, the
// block-count replication, the warp distribution, and the CTA layout.
// NOTE(review): the `kWidth` parameter is currently unused; it is kept for
// anticipated layout conversions.
LinearLayout
subgroup2DBlockToLinearLayout(ArrayRef<int64_t> blockShape,
                              intel::Subgroup2DBlockEncodingAttr layout,
                              unsigned kWidth) {
  MLIRContext *ctx = layout.getContext();
  const int rank = blockShape.size();
  assert(rank == layout.getRank() && "unexpected block shape rank, layout rank "
                                     "and block shape rank must be equal");
  auto dimNames = standardOutDimNames(ctx, rank);
  auto loadTileSize = layout.getInstrShape();
  StringAttr kRegister = S("register");
  StringAttr kLane = S("lane");
  StringAttr kWarp = S("warp");

  // Start by creating register/lane bases corresponding to the desired load
  // tile size (instrShape is (height, width)).
  auto [regBases, laneBases] = createRegisterLaneBases(
      loadTileSize[0], loadTileSize[1], layout.getThreadsPerWarp());

  LinearLayout::BasesT bases;
  bases[kRegister] = regBases;
  bases[kLane] = laneBases;
  auto ctaLayout = LinearLayout(bases, dimNames);

  assert(ctaLayout.getInDimSize(kLane) <= layout.getThreadsPerWarp() &&
         "number of lanes should not exceed threads per warp");

  // Increasing the block count always increases the inner dimension for the
  // register/lane layout regardless of order: the block-count parameter of the
  // 2d block io instruction replicates tiles along dim 1 in row-major order.
  ctaLayout *=
      LinearLayout::identity1D(layout.getNumBlocks(), kRegister, dimNames[1]);

  // Broadcast the layout according to warpsPerCTA, then combine with the
  // overall CTALayout and reshape according to the provided blockShape.
  auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ true);
  auto order = layout.getOrder();
  // Deliberately conservative: guard against accidentally reaching this code
  // path from dot3d lowering, where rank would exceed 2.
  assert(order.size() == 2 && "only rank 2 order supported");
  const unsigned inner = order[0];

  ctaLayout *= broadcastedDotOperandLayout(ctx, layout.getWarpsPerCTA(),
                                           warpOrder, inner, kWarp)
                   .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
  return combineCtaCgaWithShape(ctaLayout, layout.getCTALayout(), blockShape);
}
|
||
} // namespace mlir::triton::gpu |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
add_subdirectory(TritonGPU) | ||
add_subdirectory(TritonIntelGPU) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
add_triton_ut(
  NAME LinearLayoutConversionsIntel
  SRCS LinearLayoutConversionsTest.cpp
  LIBS
  TritonGPUIR
  TritonGPUTransforms
  TritonIntelAnalysis
  TritonIntelGPUTransforms
  # NOTE(review): the NVIDIA lib appears to be a required link dependency of
  # the Intel linear-layout conversions -- confirm before removing.
  TritonNvidiaGPUTransforms
)
Reviewer: I think `kWidth` is unused. It's also available as the op parameter.

Author: It is unused, but it will probably be used later so I left it for now. :)