Introduce Subgroup 2D Block Encoding #4193
Conversation
TritonGPUTransforms
TritonIntelAnalysis
TritonIntelGPUTransforms
TritonNvidiaGPUTransforms
Is this coming from the ttg dialect somehow? Or is it just a copy paste error?
The NVIDIA lib is a dependency for TritonGPUIR and TritonGPUTransforms.
/usr/bin/ld: lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/LowerLoops.cpp.o: in function `mlir::triton::gpu::(anonymous namespace)::mustLoadToRegisters(mlir::Operation*)':
/home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp:162: undefined reference to `mlir::triton::nvidia_gpu::getEncodingFromDescriptor(mlir::Operation*, mlir::RankedTensorType, mlir::Value)'
/usr/bin/ld: /home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp:165: undefined reference to `mlir::triton::nvidia_gpu::getEncodingFromDescriptor(mlir::Operation*, mlir::RankedTensorType, mlir::Value)'
/usr/bin/ld: lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/LowerLoops.cpp.o: in function `llvm::LogicalResult mlir::triton::nvidia_gpu::createTMADesc<mlir::triton::gpu::(anonymous namespace)::OpBuilderForStage>(mlir::Value, mlir::triton::MakeTensorDescOp, mlir::triton::gpu::(anonymous namespace)::OpBuilderForStage&)':
/home/jovyan/intel-xpu-backend-for-triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h:77: undefined reference to `mlir::triton::nvidia_gpu::getTMAContigDim(mlir::Attribute, llvm::ArrayRef<long>)'
/usr/bin/ld: /home/jovyan/intel-xpu-backend-for-triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h:99: undefined reference to `mlir::triton::nvidia_gpu::getTMASwizzleMode(mlir::Operation*, mlir::triton::TensorDescType)'
/usr/bin/ld: /home/jovyan/intel-xpu-backend-for-triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h:125: undefined reference to `mlir::triton::nvidia_gpu::getTMAElementType(mlir::Operation*, mlir::triton::TensorDescType)'
/usr/bin/ld: lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/PipeliningUtility.cpp.o: in function `mlir::triton::getSharedEncoding(mlir::Operation*)':
/home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp:513: undefined reference to `mlir::triton::nvidia_gpu::getEncodingFromDescriptor(mlir::Operation*, mlir::RankedTensorType, mlir::Value)'
collect2: error: ld returned 1 exit status
// * if width == threadsPerWarp, registers are distributed to lanes in row major
// order, i.e. one column per lane
// * if width < threadsPerWarp, rows are distributed to lanes according to
// rowToLaneration, i.e. width * rowToLaneRatio = threadsPerWarp
// rowToLaneration, i.e. width * rowToLaneRatio = threadsPerWarp
// rowToLaneRatio, i.e. width * rowToLaneRatio = threadsPerWarp
?
Fixed and cleaned up the variables and comments in this function.
    mlir::ceil<int>(threadsPerWarp, 1 << laneBases.size());
// Place subsequent rows into adjacent lanes until all lanes have been filled
for (int i = 1; i < rowsToLaneRatio; i = i << 1) {
  laneBases.push_back({i, 0});
So the idea is that this describes the columns, and if the mapping is one-to-one (width == threadsPerWarp), then these values are just absent (or err, broadcasted into a single slot)? Is it fair to say in the case of width == threadsPerWarp you have a lower dimensionality here? Or in other words, if I were to set the value anyway, would it just be the one in {0,0}?
I am not sure I understand the question - if rowsToLaneRatio (now renamed rowsPerWarp) is 1, then we only have one row per set of registers per lane. This means that each lane holds one or more columns of elements (depending on whether the width of the block is equal to or larger than the number of threads in the warp / number of lanes). Regardless, because each lane is processing a column, there is never a situation where the row index is non-zero when the register index is zero - i.e. the first row of registers always has row index zero.
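For concreteness, here is a small standalone sketch of the construction being discussed (my own illustration under the assumed values width = 8 and threadsPerWarp = 16, not the PR's code): lane bases fill the columns of one row first, and the remaining lanes take subsequent rows, so lanes 0-7 land in row 0 and lanes 8-15 in row 1.

```cpp
// Standalone sketch (assumed values, not the PR's code) of the lane basis
// construction for the width < threadsPerWarp case.
#include <array>
#include <cstdio>
#include <vector>

int main() {
  const int threadsPerWarp = 16;
  const int width = 8;
  // Each lane holds ceil(width / threadsPerWarp) packed elements per row.
  const int packedElementsPerLane =
      (width + threadsPerWarp - 1) / threadsPerWarp; // = 1 here

  std::vector<std::array<int, 2>> laneBases; // bases are {row, col}
  // Columns first: spread the columns of one row across lanes.
  for (int i = packedElementsPerLane; i < width; i <<= 1)
    laneBases.push_back({0, i}); // {0,1}, {0,2}, {0,4}
  // Remaining lanes take subsequent rows (rowsPerWarp = 2 here).
  const int lanesFilled = 1 << laneBases.size();
  const int rowsPerWarp = (threadsPerWarp + lanesFilled - 1) / lanesFilled;
  for (int i = 1; i < rowsPerWarp; i <<= 1)
    laneBases.push_back({i, 0}); // {1,0}

  // Linear-layout evaluation: XOR the bases selected by the set bits of the
  // lane id to get that lane's (row, col) offset for register 0.
  for (int lane = 0; lane < threadsPerWarp; ++lane) {
    int row = 0, col = 0;
    for (size_t b = 0; b < laneBases.size(); ++b)
      if (lane & (1 << b)) {
        row ^= laneBases[b][0];
        col ^= laneBases[b][1];
      }
    std::printf("lane %2d -> (%d, %d)\n", lane, row, col);
  }
  return 0;
}
```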
// overall CTALayout and reshape according to the provided blockShape.
auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ true);
auto order = layout.getOrder();
assert(order.size() == 2 && "only rank 2 order supported");
Doesn't this essentially restrict the layout rank to be 2 as well? My question should probably be "why does the first assertion not check the layout rank to be 2 in the first place?". Or even, is the attribute verification not enough to catch this?
Yes - it is essentially being conservative. I don't want to accidentally hit this codepath in any of the dot3d code paths, and it is unclear to me when those dot3d code paths are active. For now, I think the extra asserts are fine; they should help catch issues early in the tests.
const unsigned packedElementsPerLane =
    mlir::ceil<unsigned>(width, threadsPerWarp);

basisT laneBases;
for (int i = packedElementsPerLane; i < width; i = i << 1) {
  laneBases.push_back({0, i});
}

const int rowsToLaneRatio =
    mlir::ceil<int>(threadsPerWarp, 1 << laneBases.size());
nit: this makes me think the types matter somehow since they are different, whereas it seems they should all be unsigned
oh, I just realized the bases are int, nvm
Yes, the bases are int, but I think that means we should uniformly use ints (or at least as much as possible) - so I changed everything except the input threadsPerWarp param to int.
// Increasing the block count always increases the inner dimension for the
// register/lane layout regardless of order
ctaLayout *=
    LinearLayout::identity1D(layout.getNumBlocks(), kRegister, dimNames[1]);
I don't quite get the inner dimension part. How would the layout be affected if we had, say, two blocks? I see in the tests that the block bases are always empty.
The term Blocks is used twice - here we refer to the block count, which is a parameter of the block IO operation: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html#_overview . Block count is essentially used to load consecutive blocks to better utilize the cache lines. Tile blocks are loaded consecutively in row-major order, so we simply multiply by the identity layout on the inner dimension, which effectively "copies" the ctaLayout numBlocks times w/r/t the inner dimension.
Ah, ok, makes sense now. Here's the correct link btw https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html#_overview, those seem to be broken
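To make the block-count effect concrete, here is a minimal sketch (my own illustration with assumed values width = 16 and numBlocks = 2, not the PR's code): in the product layout, each doubling of the block count contributes one extra register basis whose column offset is a multiple of the block width, so block k starts k * width columns to the right of block 0.

```cpp
// Minimal sketch (assumed values, not the PR's code): effect of the block
// count on the combined register layout along the inner (column) dimension.
#include <cstdio>

int main() {
  const int width = 16;    // instrShape[1], assumed
  const int numBlocks = 2; // block count parameter, assumed

  // In the product with the identity layout over numBlocks, each power of
  // two of the block count adds a register basis offsetting the column by
  // a multiple of the block width.
  for (int b = 1; b < numBlocks; b <<= 1)
    std::printf("extra register basis: {0, %d}\n", b * width);

  // Consequence: element (0, 0) of block 1 maps to tensor coordinate
  // (0, 16), which is exactly where the second half of the 32x16x2 dump
  // shown later in this thread begins.
  return 0;
}
```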
For the layout, the following parameters are required:
- `instrShape` : contains the (height, width) block parameters for the block io operation
- `numBlocks` : the block count parameter allows a single load to load multiple blocks in row-major order (useful for increasing cache line utilization)
If the instrShape=[16, 16] and numBlocks=2, then tensor is going to be split into [16, 32] blocks, right? Can I specify [32, 16] or [32, 32] block using the same instrShape? Does order affect this?
If the instrShape=[16, 16] and numBlocks=2, then tensor is going to be split into [16, 32] blocks, right?
Better to think of it as two consecutive [16, 16] blocks in row-major order.
Can I specify [32, 16] or [32, 32] block using the same instrShape?
You would get different blocks.
32x16x2:
Print layout attribute: #triton_intel_gpu.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks=2, order=[1, 0], kWidth=2, threadsPerWarp=16}>
Warp0:
( 0, 0), ( 0, 1), ( 0, 2), ( 0, 3), ( 0, 4), ( 0, 5), ( 0, 6), ( 0, 7), ( 0, 8), ( 0, 9), ( 0,10), ( 0,11), ( 0,12), ( 0,13), ( 0,14), ( 0,15)
( 1, 0), ( 1, 1), ( 1, 2), ( 1, 3), ( 1, 4), ( 1, 5), ( 1, 6), ( 1, 7), ( 1, 8), ( 1, 9), ( 1,10), ( 1,11), ( 1,12), ( 1,13), ( 1,14), ( 1,15)
( 2, 0), ( 2, 1), ( 2, 2), ( 2, 3), ( 2, 4), ( 2, 5), ( 2, 6), ( 2, 7), ( 2, 8), ( 2, 9), ( 2,10), ( 2,11), ( 2,12), ( 2,13), ( 2,14), ( 2,15)
( 3, 0), ( 3, 1), ( 3, 2), ( 3, 3), ( 3, 4), ( 3, 5), ( 3, 6), ( 3, 7), ( 3, 8), ( 3, 9), ( 3,10), ( 3,11), ( 3,12), ( 3,13), ( 3,14), ( 3,15)
( 4, 0), ( 4, 1), ( 4, 2), ( 4, 3), ( 4, 4), ( 4, 5), ( 4, 6), ( 4, 7), ( 4, 8), ( 4, 9), ( 4,10), ( 4,11), ( 4,12), ( 4,13), ( 4,14), ( 4,15)
( 5, 0), ( 5, 1), ( 5, 2), ( 5, 3), ( 5, 4), ( 5, 5), ( 5, 6), ( 5, 7), ( 5, 8), ( 5, 9), ( 5,10), ( 5,11), ( 5,12), ( 5,13), ( 5,14), ( 5,15)
( 6, 0), ( 6, 1), ( 6, 2), ( 6, 3), ( 6, 4), ( 6, 5), ( 6, 6), ( 6, 7), ( 6, 8), ( 6, 9), ( 6,10), ( 6,11), ( 6,12), ( 6,13), ( 6,14), ( 6,15)
( 7, 0), ( 7, 1), ( 7, 2), ( 7, 3), ( 7, 4), ( 7, 5), ( 7, 6), ( 7, 7), ( 7, 8), ( 7, 9), ( 7,10), ( 7,11), ( 7,12), ( 7,13), ( 7,14), ( 7,15)
( 8, 0), ( 8, 1), ( 8, 2), ( 8, 3), ( 8, 4), ( 8, 5), ( 8, 6), ( 8, 7), ( 8, 8), ( 8, 9), ( 8,10), ( 8,11), ( 8,12), ( 8,13), ( 8,14), ( 8,15)
( 9, 0), ( 9, 1), ( 9, 2), ( 9, 3), ( 9, 4), ( 9, 5), ( 9, 6), ( 9, 7), ( 9, 8), ( 9, 9), ( 9,10), ( 9,11), ( 9,12), ( 9,13), ( 9,14), ( 9,15)
( 10, 0), ( 10, 1), ( 10, 2), ( 10, 3), ( 10, 4), ( 10, 5), ( 10, 6), ( 10, 7), ( 10, 8), ( 10, 9), ( 10,10), ( 10,11), ( 10,12), ( 10,13), ( 10,14), ( 10,15)
( 11, 0), ( 11, 1), ( 11, 2), ( 11, 3), ( 11, 4), ( 11, 5), ( 11, 6), ( 11, 7), ( 11, 8), ( 11, 9), ( 11,10), ( 11,11), ( 11,12), ( 11,13), ( 11,14), ( 11,15)
( 12, 0), ( 12, 1), ( 12, 2), ( 12, 3), ( 12, 4), ( 12, 5), ( 12, 6), ( 12, 7), ( 12, 8), ( 12, 9), ( 12,10), ( 12,11), ( 12,12), ( 12,13), ( 12,14), ( 12,15)
( 13, 0), ( 13, 1), ( 13, 2), ( 13, 3), ( 13, 4), ( 13, 5), ( 13, 6), ( 13, 7), ( 13, 8), ( 13, 9), ( 13,10), ( 13,11), ( 13,12), ( 13,13), ( 13,14), ( 13,15)
( 14, 0), ( 14, 1), ( 14, 2), ( 14, 3), ( 14, 4), ( 14, 5), ( 14, 6), ( 14, 7), ( 14, 8), ( 14, 9), ( 14,10), ( 14,11), ( 14,12), ( 14,13), ( 14,14), ( 14,15)
( 15, 0), ( 15, 1), ( 15, 2), ( 15, 3), ( 15, 4), ( 15, 5), ( 15, 6), ( 15, 7), ( 15, 8), ( 15, 9), ( 15,10), ( 15,11), ( 15,12), ( 15,13), ( 15,14), ( 15,15)
( 16, 0), ( 16, 1), ( 16, 2), ( 16, 3), ( 16, 4), ( 16, 5), ( 16, 6), ( 16, 7), ( 16, 8), ( 16, 9), ( 16,10), ( 16,11), ( 16,12), ( 16,13), ( 16,14), ( 16,15)
( 17, 0), ( 17, 1), ( 17, 2), ( 17, 3), ( 17, 4), ( 17, 5), ( 17, 6), ( 17, 7), ( 17, 8), ( 17, 9), ( 17,10), ( 17,11), ( 17,12), ( 17,13), ( 17,14), ( 17,15)
( 18, 0), ( 18, 1), ( 18, 2), ( 18, 3), ( 18, 4), ( 18, 5), ( 18, 6), ( 18, 7), ( 18, 8), ( 18, 9), ( 18,10), ( 18,11), ( 18,12), ( 18,13), ( 18,14), ( 18,15)
( 19, 0), ( 19, 1), ( 19, 2), ( 19, 3), ( 19, 4), ( 19, 5), ( 19, 6), ( 19, 7), ( 19, 8), ( 19, 9), ( 19,10), ( 19,11), ( 19,12), ( 19,13), ( 19,14), ( 19,15)
( 20, 0), ( 20, 1), ( 20, 2), ( 20, 3), ( 20, 4), ( 20, 5), ( 20, 6), ( 20, 7), ( 20, 8), ( 20, 9), ( 20,10), ( 20,11), ( 20,12), ( 20,13), ( 20,14), ( 20,15)
( 21, 0), ( 21, 1), ( 21, 2), ( 21, 3), ( 21, 4), ( 21, 5), ( 21, 6), ( 21, 7), ( 21, 8), ( 21, 9), ( 21,10), ( 21,11), ( 21,12), ( 21,13), ( 21,14), ( 21,15)
( 22, 0), ( 22, 1), ( 22, 2), ( 22, 3), ( 22, 4), ( 22, 5), ( 22, 6), ( 22, 7), ( 22, 8), ( 22, 9), ( 22,10), ( 22,11), ( 22,12), ( 22,13), ( 22,14), ( 22,15)
( 23, 0), ( 23, 1), ( 23, 2), ( 23, 3), ( 23, 4), ( 23, 5), ( 23, 6), ( 23, 7), ( 23, 8), ( 23, 9), ( 23,10), ( 23,11), ( 23,12), ( 23,13), ( 23,14), ( 23,15)
( 24, 0), ( 24, 1), ( 24, 2), ( 24, 3), ( 24, 4), ( 24, 5), ( 24, 6), ( 24, 7), ( 24, 8), ( 24, 9), ( 24,10), ( 24,11), ( 24,12), ( 24,13), ( 24,14), ( 24,15)
( 25, 0), ( 25, 1), ( 25, 2), ( 25, 3), ( 25, 4), ( 25, 5), ( 25, 6), ( 25, 7), ( 25, 8), ( 25, 9), ( 25,10), ( 25,11), ( 25,12), ( 25,13), ( 25,14), ( 25,15)
( 26, 0), ( 26, 1), ( 26, 2), ( 26, 3), ( 26, 4), ( 26, 5), ( 26, 6), ( 26, 7), ( 26, 8), ( 26, 9), ( 26,10), ( 26,11), ( 26,12), ( 26,13), ( 26,14), ( 26,15)
( 27, 0), ( 27, 1), ( 27, 2), ( 27, 3), ( 27, 4), ( 27, 5), ( 27, 6), ( 27, 7), ( 27, 8), ( 27, 9), ( 27,10), ( 27,11), ( 27,12), ( 27,13), ( 27,14), ( 27,15)
( 28, 0), ( 28, 1), ( 28, 2), ( 28, 3), ( 28, 4), ( 28, 5), ( 28, 6), ( 28, 7), ( 28, 8), ( 28, 9), ( 28,10), ( 28,11), ( 28,12), ( 28,13), ( 28,14), ( 28,15)
( 29, 0), ( 29, 1), ( 29, 2), ( 29, 3), ( 29, 4), ( 29, 5), ( 29, 6), ( 29, 7), ( 29, 8), ( 29, 9), ( 29,10), ( 29,11), ( 29,12), ( 29,13), ( 29,14), ( 29,15)
( 30, 0), ( 30, 1), ( 30, 2), ( 30, 3), ( 30, 4), ( 30, 5), ( 30, 6), ( 30, 7), ( 30, 8), ( 30, 9), ( 30,10), ( 30,11), ( 30,12), ( 30,13), ( 30,14), ( 30,15)
( 31, 0), ( 31, 1), ( 31, 2), ( 31, 3), ( 31, 4), ( 31, 5), ( 31, 6), ( 31, 7), ( 31, 8), ( 31, 9), ( 31,10), ( 31,11), ( 31,12), ( 31,13), ( 31,14), ( 31,15)
( 0,16), ( 0,17), ( 0,18), ( 0,19), ( 0,20), ( 0,21), ( 0,22), ( 0,23), ( 0,24), ( 0,25), ( 0,26), ( 0,27), ( 0,28), ( 0,29), ( 0,30), ( 0,31)
( 1,16), ( 1,17), ( 1,18), ( 1,19), ( 1,20), ( 1,21), ( 1,22), ( 1,23), ( 1,24), ( 1,25), ( 1,26), ( 1,27), ( 1,28), ( 1,29), ( 1,30), ( 1,31)
( 2,16), ( 2,17), ( 2,18), ( 2,19), ( 2,20), ( 2,21), ( 2,22), ( 2,23), ( 2,24), ( 2,25), ( 2,26), ( 2,27), ( 2,28), ( 2,29), ( 2,30), ( 2,31)
( 3,16), ( 3,17), ( 3,18), ( 3,19), ( 3,20), ( 3,21), ( 3,22), ( 3,23), ( 3,24), ( 3,25), ( 3,26), ( 3,27), ( 3,28), ( 3,29), ( 3,30), ( 3,31)
( 4,16), ( 4,17), ( 4,18), ( 4,19), ( 4,20), ( 4,21), ( 4,22), ( 4,23), ( 4,24), ( 4,25), ( 4,26), ( 4,27), ( 4,28), ( 4,29), ( 4,30), ( 4,31)
( 5,16), ( 5,17), ( 5,18), ( 5,19), ( 5,20), ( 5,21), ( 5,22), ( 5,23), ( 5,24), ( 5,25), ( 5,26), ( 5,27), ( 5,28), ( 5,29), ( 5,30), ( 5,31)
( 6,16), ( 6,17), ( 6,18), ( 6,19), ( 6,20), ( 6,21), ( 6,22), ( 6,23), ( 6,24), ( 6,25), ( 6,26), ( 6,27), ( 6,28), ( 6,29), ( 6,30), ( 6,31)
( 7,16), ( 7,17), ( 7,18), ( 7,19), ( 7,20), ( 7,21), ( 7,22), ( 7,23), ( 7,24), ( 7,25), ( 7,26), ( 7,27), ( 7,28), ( 7,29), ( 7,30), ( 7,31)
( 8,16), ( 8,17), ( 8,18), ( 8,19), ( 8,20), ( 8,21), ( 8,22), ( 8,23), ( 8,24), ( 8,25), ( 8,26), ( 8,27), ( 8,28), ( 8,29), ( 8,30), ( 8,31)
( 9,16), ( 9,17), ( 9,18), ( 9,19), ( 9,20), ( 9,21), ( 9,22), ( 9,23), ( 9,24), ( 9,25), ( 9,26), ( 9,27), ( 9,28), ( 9,29), ( 9,30), ( 9,31)
( 10,16), ( 10,17), ( 10,18), ( 10,19), ( 10,20), ( 10,21), ( 10,22), ( 10,23), ( 10,24), ( 10,25), ( 10,26), ( 10,27), ( 10,28), ( 10,29), ( 10,30), ( 10,31)
( 11,16), ( 11,17), ( 11,18), ( 11,19), ( 11,20), ( 11,21), ( 11,22), ( 11,23), ( 11,24), ( 11,25), ( 11,26), ( 11,27), ( 11,28), ( 11,29), ( 11,30), ( 11,31)
( 12,16), ( 12,17), ( 12,18), ( 12,19), ( 12,20), ( 12,21), ( 12,22), ( 12,23), ( 12,24), ( 12,25), ( 12,26), ( 12,27), ( 12,28), ( 12,29), ( 12,30), ( 12,31)
( 13,16), ( 13,17), ( 13,18), ( 13,19), ( 13,20), ( 13,21), ( 13,22), ( 13,23), ( 13,24), ( 13,25), ( 13,26), ( 13,27), ( 13,28), ( 13,29), ( 13,30), ( 13,31)
( 14,16), ( 14,17), ( 14,18), ( 14,19), ( 14,20), ( 14,21), ( 14,22), ( 14,23), ( 14,24), ( 14,25), ( 14,26), ( 14,27), ( 14,28), ( 14,29), ( 14,30), ( 14,31)
( 15,16), ( 15,17), ( 15,18), ( 15,19), ( 15,20), ( 15,21), ( 15,22), ( 15,23), ( 15,24), ( 15,25), ( 15,26), ( 15,27), ( 15,28), ( 15,29), ( 15,30), ( 15,31)
( 16,16), ( 16,17), ( 16,18), ( 16,19), ( 16,20), ( 16,21), ( 16,22), ( 16,23), ( 16,24), ( 16,25), ( 16,26), ( 16,27), ( 16,28), ( 16,29), ( 16,30), ( 16,31)
( 17,16), ( 17,17), ( 17,18), ( 17,19), ( 17,20), ( 17,21), ( 17,22), ( 17,23), ( 17,24), ( 17,25), ( 17,26), ( 17,27), ( 17,28), ( 17,29), ( 17,30), ( 17,31)
( 18,16), ( 18,17), ( 18,18), ( 18,19), ( 18,20), ( 18,21), ( 18,22), ( 18,23), ( 18,24), ( 18,25), ( 18,26), ( 18,27), ( 18,28), ( 18,29), ( 18,30), ( 18,31)
( 19,16), ( 19,17), ( 19,18), ( 19,19), ( 19,20), ( 19,21), ( 19,22), ( 19,23), ( 19,24), ( 19,25), ( 19,26), ( 19,27), ( 19,28), ( 19,29), ( 19,30), ( 19,31)
( 20,16), ( 20,17), ( 20,18), ( 20,19), ( 20,20), ( 20,21), ( 20,22), ( 20,23), ( 20,24), ( 20,25), ( 20,26), ( 20,27), ( 20,28), ( 20,29), ( 20,30), ( 20,31)
( 21,16), ( 21,17), ( 21,18), ( 21,19), ( 21,20), ( 21,21), ( 21,22), ( 21,23), ( 21,24), ( 21,25), ( 21,26), ( 21,27), ( 21,28), ( 21,29), ( 21,30), ( 21,31)
( 22,16), ( 22,17), ( 22,18), ( 22,19), ( 22,20), ( 22,21), ( 22,22), ( 22,23), ( 22,24), ( 22,25), ( 22,26), ( 22,27), ( 22,28), ( 22,29), ( 22,30), ( 22,31)
( 23,16), ( 23,17), ( 23,18), ( 23,19), ( 23,20), ( 23,21), ( 23,22), ( 23,23), ( 23,24), ( 23,25), ( 23,26), ( 23,27), ( 23,28), ( 23,29), ( 23,30), ( 23,31)
( 24,16), ( 24,17), ( 24,18), ( 24,19), ( 24,20), ( 24,21), ( 24,22), ( 24,23), ( 24,24), ( 24,25), ( 24,26), ( 24,27), ( 24,28), ( 24,29), ( 24,30), ( 24,31)
( 25,16), ( 25,17), ( 25,18), ( 25,19), ( 25,20), ( 25,21), ( 25,22), ( 25,23), ( 25,24), ( 25,25), ( 25,26), ( 25,27), ( 25,28), ( 25,29), ( 25,30), ( 25,31)
( 26,16), ( 26,17), ( 26,18), ( 26,19), ( 26,20), ( 26,21), ( 26,22), ( 26,23), ( 26,24), ( 26,25), ( 26,26), ( 26,27), ( 26,28), ( 26,29), ( 26,30), ( 26,31)
( 27,16), ( 27,17), ( 27,18), ( 27,19), ( 27,20), ( 27,21), ( 27,22), ( 27,23), ( 27,24), ( 27,25), ( 27,26), ( 27,27), ( 27,28), ( 27,29), ( 27,30), ( 27,31)
( 28,16), ( 28,17), ( 28,18), ( 28,19), ( 28,20), ( 28,21), ( 28,22), ( 28,23), ( 28,24), ( 28,25), ( 28,26), ( 28,27), ( 28,28), ( 28,29), ( 28,30), ( 28,31)
( 29,16), ( 29,17), ( 29,18), ( 29,19), ( 29,20), ( 29,21), ( 29,22), ( 29,23), ( 29,24), ( 29,25), ( 29,26), ( 29,27), ( 29,28), ( 29,29), ( 29,30), ( 29,31)
( 30,16), ( 30,17), ( 30,18), ( 30,19), ( 30,20), ( 30,21), ( 30,22), ( 30,23), ( 30,24), ( 30,25), ( 30,26), ( 30,27), ( 30,28), ( 30,29), ( 30,30), ( 30,31)
( 31,16), ( 31,17), ( 31,18), ( 31,19), ( 31,20), ( 31,21), ( 31,22), ( 31,23), ( 31,24), ( 31,25), ( 31,26), ( 31,27), ( 31,28), ( 31,29), ( 31,30), ( 31,31)
vs 32x32x1
Print layout attribute: #triton_intel_gpu.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 32], numBlocks=1, order=[1, 0], kWidth=2, threadsPerWarp=16}>
Warp0:
( 0, 0), ( 0, 2), ( 0, 4), ( 0, 6), ( 0, 8), ( 0,10), ( 0,12), ( 0,14), ( 0,16), ( 0,18), ( 0,20), ( 0,22), ( 0,24), ( 0,26), ( 0,28), ( 0,30)
( 0, 1), ( 0, 3), ( 0, 5), ( 0, 7), ( 0, 9), ( 0,11), ( 0,13), ( 0,15), ( 0,17), ( 0,19), ( 0,21), ( 0,23), ( 0,25), ( 0,27), ( 0,29), ( 0,31)
( 1, 0), ( 1, 2), ( 1, 4), ( 1, 6), ( 1, 8), ( 1,10), ( 1,12), ( 1,14), ( 1,16), ( 1,18), ( 1,20), ( 1,22), ( 1,24), ( 1,26), ( 1,28), ( 1,30)
( 1, 1), ( 1, 3), ( 1, 5), ( 1, 7), ( 1, 9), ( 1,11), ( 1,13), ( 1,15), ( 1,17), ( 1,19), ( 1,21), ( 1,23), ( 1,25), ( 1,27), ( 1,29), ( 1,31)
( 2, 0), ( 2, 2), ( 2, 4), ( 2, 6), ( 2, 8), ( 2,10), ( 2,12), ( 2,14), ( 2,16), ( 2,18), ( 2,20), ( 2,22), ( 2,24), ( 2,26), ( 2,28), ( 2,30)
( 2, 1), ( 2, 3), ( 2, 5), ( 2, 7), ( 2, 9), ( 2,11), ( 2,13), ( 2,15), ( 2,17), ( 2,19), ( 2,21), ( 2,23), ( 2,25), ( 2,27), ( 2,29), ( 2,31)
( 3, 0), ( 3, 2), ( 3, 4), ( 3, 6), ( 3, 8), ( 3,10), ( 3,12), ( 3,14), ( 3,16), ( 3,18), ( 3,20), ( 3,22), ( 3,24), ( 3,26), ( 3,28), ( 3,30)
( 3, 1), ( 3, 3), ( 3, 5), ( 3, 7), ( 3, 9), ( 3,11), ( 3,13), ( 3,15), ( 3,17), ( 3,19), ( 3,21), ( 3,23), ( 3,25), ( 3,27), ( 3,29), ( 3,31)
( 4, 0), ( 4, 2), ( 4, 4), ( 4, 6), ( 4, 8), ( 4,10), ( 4,12), ( 4,14), ( 4,16), ( 4,18), ( 4,20), ( 4,22), ( 4,24), ( 4,26), ( 4,28), ( 4,30)
( 4, 1), ( 4, 3), ( 4, 5), ( 4, 7), ( 4, 9), ( 4,11), ( 4,13), ( 4,15), ( 4,17), ( 4,19), ( 4,21), ( 4,23), ( 4,25), ( 4,27), ( 4,29), ( 4,31)
( 5, 0), ( 5, 2), ( 5, 4), ( 5, 6), ( 5, 8), ( 5,10), ( 5,12), ( 5,14), ( 5,16), ( 5,18), ( 5,20), ( 5,22), ( 5,24), ( 5,26), ( 5,28), ( 5,30)
( 5, 1), ( 5, 3), ( 5, 5), ( 5, 7), ( 5, 9), ( 5,11), ( 5,13), ( 5,15), ( 5,17), ( 5,19), ( 5,21), ( 5,23), ( 5,25), ( 5,27), ( 5,29), ( 5,31)
( 6, 0), ( 6, 2), ( 6, 4), ( 6, 6), ( 6, 8), ( 6,10), ( 6,12), ( 6,14), ( 6,16), ( 6,18), ( 6,20), ( 6,22), ( 6,24), ( 6,26), ( 6,28), ( 6,30)
( 6, 1), ( 6, 3), ( 6, 5), ( 6, 7), ( 6, 9), ( 6,11), ( 6,13), ( 6,15), ( 6,17), ( 6,19), ( 6,21), ( 6,23), ( 6,25), ( 6,27), ( 6,29), ( 6,31)
( 7, 0), ( 7, 2), ( 7, 4), ( 7, 6), ( 7, 8), ( 7,10), ( 7,12), ( 7,14), ( 7,16), ( 7,18), ( 7,20), ( 7,22), ( 7,24), ( 7,26), ( 7,28), ( 7,30)
( 7, 1), ( 7, 3), ( 7, 5), ( 7, 7), ( 7, 9), ( 7,11), ( 7,13), ( 7,15), ( 7,17), ( 7,19), ( 7,21), ( 7,23), ( 7,25), ( 7,27), ( 7,29), ( 7,31)
( 8, 0), ( 8, 2), ( 8, 4), ( 8, 6), ( 8, 8), ( 8,10), ( 8,12), ( 8,14), ( 8,16), ( 8,18), ( 8,20), ( 8,22), ( 8,24), ( 8,26), ( 8,28), ( 8,30)
( 8, 1), ( 8, 3), ( 8, 5), ( 8, 7), ( 8, 9), ( 8,11), ( 8,13), ( 8,15), ( 8,17), ( 8,19), ( 8,21), ( 8,23), ( 8,25), ( 8,27), ( 8,29), ( 8,31)
( 9, 0), ( 9, 2), ( 9, 4), ( 9, 6), ( 9, 8), ( 9,10), ( 9,12), ( 9,14), ( 9,16), ( 9,18), ( 9,20), ( 9,22), ( 9,24), ( 9,26), ( 9,28), ( 9,30)
( 9, 1), ( 9, 3), ( 9, 5), ( 9, 7), ( 9, 9), ( 9,11), ( 9,13), ( 9,15), ( 9,17), ( 9,19), ( 9,21), ( 9,23), ( 9,25), ( 9,27), ( 9,29), ( 9,31)
( 10, 0), ( 10, 2), ( 10, 4), ( 10, 6), ( 10, 8), ( 10,10), ( 10,12), ( 10,14), ( 10,16), ( 10,18), ( 10,20), ( 10,22), ( 10,24), ( 10,26), ( 10,28), ( 10,30)
( 10, 1), ( 10, 3), ( 10, 5), ( 10, 7), ( 10, 9), ( 10,11), ( 10,13), ( 10,15), ( 10,17), ( 10,19), ( 10,21), ( 10,23), ( 10,25), ( 10,27), ( 10,29), ( 10,31)
( 11, 0), ( 11, 2), ( 11, 4), ( 11, 6), ( 11, 8), ( 11,10), ( 11,12), ( 11,14), ( 11,16), ( 11,18), ( 11,20), ( 11,22), ( 11,24), ( 11,26), ( 11,28), ( 11,30)
( 11, 1), ( 11, 3), ( 11, 5), ( 11, 7), ( 11, 9), ( 11,11), ( 11,13), ( 11,15), ( 11,17), ( 11,19), ( 11,21), ( 11,23), ( 11,25), ( 11,27), ( 11,29), ( 11,31)
( 12, 0), ( 12, 2), ( 12, 4), ( 12, 6), ( 12, 8), ( 12,10), ( 12,12), ( 12,14), ( 12,16), ( 12,18), ( 12,20), ( 12,22), ( 12,24), ( 12,26), ( 12,28), ( 12,30)
( 12, 1), ( 12, 3), ( 12, 5), ( 12, 7), ( 12, 9), ( 12,11), ( 12,13), ( 12,15), ( 12,17), ( 12,19), ( 12,21), ( 12,23), ( 12,25), ( 12,27), ( 12,29), ( 12,31)
( 13, 0), ( 13, 2), ( 13, 4), ( 13, 6), ( 13, 8), ( 13,10), ( 13,12), ( 13,14), ( 13,16), ( 13,18), ( 13,20), ( 13,22), ( 13,24), ( 13,26), ( 13,28), ( 13,30)
( 13, 1), ( 13, 3), ( 13, 5), ( 13, 7), ( 13, 9), ( 13,11), ( 13,13), ( 13,15), ( 13,17), ( 13,19), ( 13,21), ( 13,23), ( 13,25), ( 13,27), ( 13,29), ( 13,31)
( 14, 0), ( 14, 2), ( 14, 4), ( 14, 6), ( 14, 8), ( 14,10), ( 14,12), ( 14,14), ( 14,16), ( 14,18), ( 14,20), ( 14,22), ( 14,24), ( 14,26), ( 14,28), ( 14,30)
( 14, 1), ( 14, 3), ( 14, 5), ( 14, 7), ( 14, 9), ( 14,11), ( 14,13), ( 14,15), ( 14,17), ( 14,19), ( 14,21), ( 14,23), ( 14,25), ( 14,27), ( 14,29), ( 14,31)
( 15, 0), ( 15, 2), ( 15, 4), ( 15, 6), ( 15, 8), ( 15,10), ( 15,12), ( 15,14), ( 15,16), ( 15,18), ( 15,20), ( 15,22), ( 15,24), ( 15,26), ( 15,28), ( 15,30)
( 15, 1), ( 15, 3), ( 15, 5), ( 15, 7), ( 15, 9), ( 15,11), ( 15,13), ( 15,15), ( 15,17), ( 15,19), ( 15,21), ( 15,23), ( 15,25), ( 15,27), ( 15,29), ( 15,31)
( 16, 0), ( 16, 2), ( 16, 4), ( 16, 6), ( 16, 8), ( 16,10), ( 16,12), ( 16,14), ( 16,16), ( 16,18), ( 16,20), ( 16,22), ( 16,24), ( 16,26), ( 16,28), ( 16,30)
( 16, 1), ( 16, 3), ( 16, 5), ( 16, 7), ( 16, 9), ( 16,11), ( 16,13), ( 16,15), ( 16,17), ( 16,19), ( 16,21), ( 16,23), ( 16,25), ( 16,27), ( 16,29), ( 16,31)
( 17, 0), ( 17, 2), ( 17, 4), ( 17, 6), ( 17, 8), ( 17,10), ( 17,12), ( 17,14), ( 17,16), ( 17,18), ( 17,20), ( 17,22), ( 17,24), ( 17,26), ( 17,28), ( 17,30)
( 17, 1), ( 17, 3), ( 17, 5), ( 17, 7), ( 17, 9), ( 17,11), ( 17,13), ( 17,15), ( 17,17), ( 17,19), ( 17,21), ( 17,23), ( 17,25), ( 17,27), ( 17,29), ( 17,31)
( 18, 0), ( 18, 2), ( 18, 4), ( 18, 6), ( 18, 8), ( 18,10), ( 18,12), ( 18,14), ( 18,16), ( 18,18), ( 18,20), ( 18,22), ( 18,24), ( 18,26), ( 18,28), ( 18,30)
( 18, 1), ( 18, 3), ( 18, 5), ( 18, 7), ( 18, 9), ( 18,11), ( 18,13), ( 18,15), ( 18,17), ( 18,19), ( 18,21), ( 18,23), ( 18,25), ( 18,27), ( 18,29), ( 18,31)
( 19, 0), ( 19, 2), ( 19, 4), ( 19, 6), ( 19, 8), ( 19,10), ( 19,12), ( 19,14), ( 19,16), ( 19,18), ( 19,20), ( 19,22), ( 19,24), ( 19,26), ( 19,28), ( 19,30)
( 19, 1), ( 19, 3), ( 19, 5), ( 19, 7), ( 19, 9), ( 19,11), ( 19,13), ( 19,15), ( 19,17), ( 19,19), ( 19,21), ( 19,23), ( 19,25), ( 19,27), ( 19,29), ( 19,31)
( 20, 0), ( 20, 2), ( 20, 4), ( 20, 6), ( 20, 8), ( 20,10), ( 20,12), ( 20,14), ( 20,16), ( 20,18), ( 20,20), ( 20,22), ( 20,24), ( 20,26), ( 20,28), ( 20,30)
( 20, 1), ( 20, 3), ( 20, 5), ( 20, 7), ( 20, 9), ( 20,11), ( 20,13), ( 20,15), ( 20,17), ( 20,19), ( 20,21), ( 20,23), ( 20,25), ( 20,27), ( 20,29), ( 20,31)
( 21, 0), ( 21, 2), ( 21, 4), ( 21, 6), ( 21, 8), ( 21,10), ( 21,12), ( 21,14), ( 21,16), ( 21,18), ( 21,20), ( 21,22), ( 21,24), ( 21,26), ( 21,28), ( 21,30)
( 21, 1), ( 21, 3), ( 21, 5), ( 21, 7), ( 21, 9), ( 21,11), ( 21,13), ( 21,15), ( 21,17), ( 21,19), ( 21,21), ( 21,23), ( 21,25), ( 21,27), ( 21,29), ( 21,31)
( 22, 0), ( 22, 2), ( 22, 4), ( 22, 6), ( 22, 8), ( 22,10), ( 22,12), ( 22,14), ( 22,16), ( 22,18), ( 22,20), ( 22,22), ( 22,24), ( 22,26), ( 22,28), ( 22,30)
( 22, 1), ( 22, 3), ( 22, 5), ( 22, 7), ( 22, 9), ( 22,11), ( 22,13), ( 22,15), ( 22,17), ( 22,19), ( 22,21), ( 22,23), ( 22,25), ( 22,27), ( 22,29), ( 22,31)
( 23, 0), ( 23, 2), ( 23, 4), ( 23, 6), ( 23, 8), ( 23,10), ( 23,12), ( 23,14), ( 23,16), ( 23,18), ( 23,20), ( 23,22), ( 23,24), ( 23,26), ( 23,28), ( 23,30)
( 23, 1), ( 23, 3), ( 23, 5), ( 23, 7), ( 23, 9), ( 23,11), ( 23,13), ( 23,15), ( 23,17), ( 23,19), ( 23,21), ( 23,23), ( 23,25), ( 23,27), ( 23,29), ( 23,31)
( 24, 0), ( 24, 2), ( 24, 4), ( 24, 6), ( 24, 8), ( 24,10), ( 24,12), ( 24,14), ( 24,16), ( 24,18), ( 24,20), ( 24,22), ( 24,24), ( 24,26), ( 24,28), ( 24,30)
( 24, 1), ( 24, 3), ( 24, 5), ( 24, 7), ( 24, 9), ( 24,11), ( 24,13), ( 24,15), ( 24,17), ( 24,19), ( 24,21), ( 24,23), ( 24,25), ( 24,27), ( 24,29), ( 24,31)
( 25, 0), ( 25, 2), ( 25, 4), ( 25, 6), ( 25, 8), ( 25,10), ( 25,12), ( 25,14), ( 25,16), ( 25,18), ( 25,20), ( 25,22), ( 25,24), ( 25,26), ( 25,28), ( 25,30)
( 25, 1), ( 25, 3), ( 25, 5), ( 25, 7), ( 25, 9), ( 25,11), ( 25,13), ( 25,15), ( 25,17), ( 25,19), ( 25,21), ( 25,23), ( 25,25), ( 25,27), ( 25,29), ( 25,31)
( 26, 0), ( 26, 2), ( 26, 4), ( 26, 6), ( 26, 8), ( 26,10), ( 26,12), ( 26,14), ( 26,16), ( 26,18), ( 26,20), ( 26,22), ( 26,24), ( 26,26), ( 26,28), ( 26,30)
( 26, 1), ( 26, 3), ( 26, 5), ( 26, 7), ( 26, 9), ( 26,11), ( 26,13), ( 26,15), ( 26,17), ( 26,19), ( 26,21), ( 26,23), ( 26,25), ( 26,27), ( 26,29), ( 26,31)
( 27, 0), ( 27, 2), ( 27, 4), ( 27, 6), ( 27, 8), ( 27,10), ( 27,12), ( 27,14), ( 27,16), ( 27,18), ( 27,20), ( 27,22), ( 27,24), ( 27,26), ( 27,28), ( 27,30)
( 27, 1), ( 27, 3), ( 27, 5), ( 27, 7), ( 27, 9), ( 27,11), ( 27,13), ( 27,15), ( 27,17), ( 27,19), ( 27,21), ( 27,23), ( 27,25), ( 27,27), ( 27,29), ( 27,31)
( 28, 0), ( 28, 2), ( 28, 4), ( 28, 6), ( 28, 8), ( 28,10), ( 28,12), ( 28,14), ( 28,16), ( 28,18), ( 28,20), ( 28,22), ( 28,24), ( 28,26), ( 28,28), ( 28,30)
( 28, 1), ( 28, 3), ( 28, 5), ( 28, 7), ( 28, 9), ( 28,11), ( 28,13), ( 28,15), ( 28,17), ( 28,19), ( 28,21), ( 28,23), ( 28,25), ( 28,27), ( 28,29), ( 28,31)
( 29, 0), ( 29, 2), ( 29, 4), ( 29, 6), ( 29, 8), ( 29,10), ( 29,12), ( 29,14), ( 29,16), ( 29,18), ( 29,20), ( 29,22), ( 29,24), ( 29,26), ( 29,28), ( 29,30)
( 29, 1), ( 29, 3), ( 29, 5), ( 29, 7), ( 29, 9), ( 29,11), ( 29,13), ( 29,15), ( 29,17), ( 29,19), ( 29,21), ( 29,23), ( 29,25), ( 29,27), ( 29,29), ( 29,31)
( 30, 0), ( 30, 2), ( 30, 4), ( 30, 6), ( 30, 8), ( 30,10), ( 30,12), ( 30,14), ( 30,16), ( 30,18), ( 30,20), ( 30,22), ( 30,24), ( 30,26), ( 30,28), ( 30,30)
( 30, 1), ( 30, 3), ( 30, 5), ( 30, 7), ( 30, 9), ( 30,11), ( 30,13), ( 30,15), ( 30,17), ( 30,19), ( 30,21), ( 30,23), ( 30,25), ( 30,27), ( 30,29), ( 30,31)
( 31, 0), ( 31, 2), ( 31, 4), ( 31, 6), ( 31, 8), ( 31,10), ( 31,12), ( 31,14), ( 31,16), ( 31,18), ( 31,20), ( 31,22), ( 31,24), ( 31,26), ( 31,28), ( 31,30)
( 31, 1), ( 31, 3), ( 31, 5), ( 31, 7), ( 31, 9), ( 31,11), ( 31,13), ( 31,15), ( 31,17), ( 31,19), ( 31,21), ( 31,23), ( 31,25), ( 31,27), ( 31,29), ( 31,31)
Note the interleaving with 32x32x1 because width = 2*subgroupSize.
Does order affect this?
No, order is only used when determining the broadcast dimension for the warp/block layout. The tile params always refer to the row-major ordering expected in the subgroup 2d block io instructions.
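To illustrate the interleaving, here is a standalone sketch (my own reconstruction from the 32x32x1 dump above with threadsPerWarp = 16, not the PR's code) of the register and lane bases that produce it: with width = 2*subgroupSize, register 0 of lane j covers column 2j of row 0 and register 1 covers column 2j + 1, reproducing the first two lines of the dump.

```cpp
// Standalone sketch (reconstructed from the dump above, not the PR's code):
// register/lane bases for a 32x32 block with threadsPerWarp = 16, showing
// why adjacent columns interleave across the first two registers.
#include <array>
#include <cstdio>
#include <vector>

int main() {
  const int width = 32, height = 32, threadsPerWarp = 16;
  const int packedElementsPerLane = width / threadsPerWarp; // = 2

  std::vector<std::array<int, 2>> laneBases; // bases are {row, col}
  for (int i = packedElementsPerLane; i < width; i <<= 1)
    laneBases.push_back({0, i}); // {0,2}, {0,4}, {0,8}, {0,16}

  std::vector<std::array<int, 2>> regBases;
  for (int i = 1; i < packedElementsPerLane; i <<= 1)
    regBases.push_back({0, i}); // {0,1}: the second packed column element
  for (int i = 1; i < height; i <<= 1)
    regBases.push_back({i, 0}); // {1,0}, {2,0}, ..., {16,0}: the rows

  // Linear-layout evaluation: XOR the bases selected by the set bits of the
  // register and lane indices.
  auto eval = [&](int reg, int lane) {
    int row = 0, col = 0;
    for (size_t b = 0; b < regBases.size(); ++b)
      if (reg & (1 << b)) { row ^= regBases[b][0]; col ^= regBases[b][1]; }
    for (size_t b = 0; b < laneBases.size(); ++b)
      if (lane & (1 << b)) { row ^= laneBases[b][0]; col ^= laneBases[b][1]; }
    return std::array<int, 2>{row, col};
  };

  // First two registers across all 16 lanes: (0,0),(0,2),...,(0,30) then
  // (0,1),(0,3),...,(0,31) -- the interleaving visible in the dump above.
  for (int reg = 0; reg < 2; ++reg) {
    for (int lane = 0; lane < threadsPerWarp; ++lane) {
      auto rc = eval(reg, lane);
      std::printf("(%2d,%2d) ", rc[0], rc[1]);
    }
    std::printf("\n");
  }
  return 0;
}
```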
introduce numBlocks, swap tile/instrShape and tensorShape
support B matrix with vblocks
use vblock layout for A matrix, disabled 32x32 test
Remove opIdx dependence for conversion, add order to the layout instead
cleanup some misc code and old definitions
support packed layouts with width > threads per warp
add i8 tests
ff4bfda to 1bcf634
Looks correct to me
LinearLayout
subgroup2DBlockToLinearLayout(ArrayRef<int64_t> shape,
                              intel::Subgroup2DBlockEncodingAttr layout,
                              unsigned kWidth);
I think kWidth is unused. It's also available as the op parameter.
It is unused, but it will probably be used later so I left it for now :)
can you add the test under third_party/intel/unittest/Dialect?
add_triton_ut(
  NAME LinearLayoutConversionsIntel
  SRCS LinearLayoutConversionsTest.cpp
  LIBS
How about TritonIntelGPUIR?
We have a DPAStoLinearLayoutTest.cpp, how about calling this subgroup2DBlockToLinearLayoutTest.cpp?
Add a new layout to describe the tensor layout with respect to the GPU compute hierarchy (register, lane, warp, block). This PR introduces the layout and adds its definition and basic functions to the Triton Intel GPU Dialect. The conversion to Linear Layout function has been added and unit tested through an Intel specific LinearLayoutConversionsTest. The layouts are unpacked - each register is assumed to be the size of the tensor type. However, the layout generation follows the convention described in https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html. While there may be some bugs, the goal is for any valid operation described in the SPIRV extension to be represented correctly with this layout.

Currently the layout is unused other than for linear layout conversion testing purposes. I plan to leave this PR in draft until I have replaced the block_io attribute on the load ops with this layout - and then I plan to replace the linear layout code I added to LoadStoreOpToLLVM.cpp. That second task might prove challenging since I think the DPAS layouts do sometimes incorporate register packing schemes into the layout - but looking at the upstream layouts for NVIDIA and AMD MMA, specific packing is an implementation detail and not represented as part of the high-level layout encoding.

cc #4192