Introduce Subgroup 2D Block Encoding #4193


Open · wants to merge 16 commits into main
@@ -18,6 +18,11 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
ArrayRef<int64_t> shape);

LinearLayout
subgroup2DBlockToLinearLayout(ArrayRef<int64_t> shape,
intel::Subgroup2DBlockEncodingAttr layout,
unsigned kWidth);
Contributor:

I think kWidth is unused. It's also available as the op parameter.

Contributor Author:

It is unused, but it will probably be used later, so I left it for now :)


} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONINTELGPU_IR_LINEARLAYOUTCONVERSIONS_H
@@ -280,4 +280,47 @@ def WarpEncodingAttr : DistributedEncoding<"WarpEncoding", "intel_warp_encoding"
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Intel Subgroup2DBlock Encoding
//===----------------------------------------------------------------------===//

def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", "subgroup_2d_block_encoding", [MmaEncodingTrait], TritonIntelGPU_Dialect> {
let mnemonic = "subgroup_2d_block";

let description = [{
An encoding for tensors produced via Intel Subgroup 2D Block IO operations.

The subgroup 2D block IO operations read or write two-dimensional blocks of data from or to a two-dimensional region of memory. The Subgroup 2D Block Encoding layout is parameterized by the block width, block height, and block count of the individual load instructions, and by the distribution and replication of loads across warps.

The SPV_INTEL_2d_block_io extension documentation provides more information on the subgroup 2D block IO operations and parameters: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html

For the layout, the following parameters are required:
- `instrShape` : the (height, width) block parameters for the block IO operation
- `numBlocks` : the block count parameter, which allows a single load to fetch multiple blocks in row-major order (useful for increasing cache-line utilization)
Contributor:

If the instrShape=[16, 16] and numBlocks=2, then the tensor is going to be split into [16, 32] blocks, right? Can I specify a [32, 16] or [32, 32] block using the same instrShape? Does order affect this?

Contributor Author:

If the instrShape=[16, 16] and numBlocks=2, then the tensor is going to be split into [16, 32] blocks, right?

Better to think of it as two consecutive [16, 16] blocks in row-major order.

Can I specify a [32, 16] or [32, 32] block using the same instrShape?

You would get different blocks.
32x16x2:

Print layout attribute: #triton_intel_gpu.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks=2, order=[1, 0], kWidth=2, threadsPerWarp=16}>
Warp0:
(  0, 0), (  0, 1), (  0, 2), (  0, 3), (  0, 4), (  0, 5), (  0, 6), (  0, 7), (  0, 8), (  0, 9), (  0,10), (  0,11), (  0,12), (  0,13), (  0,14), (  0,15)
(  1, 0), (  1, 1), (  1, 2), (  1, 3), (  1, 4), (  1, 5), (  1, 6), (  1, 7), (  1, 8), (  1, 9), (  1,10), (  1,11), (  1,12), (  1,13), (  1,14), (  1,15)
(  2, 0), (  2, 1), (  2, 2), (  2, 3), (  2, 4), (  2, 5), (  2, 6), (  2, 7), (  2, 8), (  2, 9), (  2,10), (  2,11), (  2,12), (  2,13), (  2,14), (  2,15)
(  3, 0), (  3, 1), (  3, 2), (  3, 3), (  3, 4), (  3, 5), (  3, 6), (  3, 7), (  3, 8), (  3, 9), (  3,10), (  3,11), (  3,12), (  3,13), (  3,14), (  3,15)
(  4, 0), (  4, 1), (  4, 2), (  4, 3), (  4, 4), (  4, 5), (  4, 6), (  4, 7), (  4, 8), (  4, 9), (  4,10), (  4,11), (  4,12), (  4,13), (  4,14), (  4,15)
(  5, 0), (  5, 1), (  5, 2), (  5, 3), (  5, 4), (  5, 5), (  5, 6), (  5, 7), (  5, 8), (  5, 9), (  5,10), (  5,11), (  5,12), (  5,13), (  5,14), (  5,15)
(  6, 0), (  6, 1), (  6, 2), (  6, 3), (  6, 4), (  6, 5), (  6, 6), (  6, 7), (  6, 8), (  6, 9), (  6,10), (  6,11), (  6,12), (  6,13), (  6,14), (  6,15)
(  7, 0), (  7, 1), (  7, 2), (  7, 3), (  7, 4), (  7, 5), (  7, 6), (  7, 7), (  7, 8), (  7, 9), (  7,10), (  7,11), (  7,12), (  7,13), (  7,14), (  7,15)
(  8, 0), (  8, 1), (  8, 2), (  8, 3), (  8, 4), (  8, 5), (  8, 6), (  8, 7), (  8, 8), (  8, 9), (  8,10), (  8,11), (  8,12), (  8,13), (  8,14), (  8,15)
(  9, 0), (  9, 1), (  9, 2), (  9, 3), (  9, 4), (  9, 5), (  9, 6), (  9, 7), (  9, 8), (  9, 9), (  9,10), (  9,11), (  9,12), (  9,13), (  9,14), (  9,15)
( 10, 0), ( 10, 1), ( 10, 2), ( 10, 3), ( 10, 4), ( 10, 5), ( 10, 6), ( 10, 7), ( 10, 8), ( 10, 9), ( 10,10), ( 10,11), ( 10,12), ( 10,13), ( 10,14), ( 10,15)
( 11, 0), ( 11, 1), ( 11, 2), ( 11, 3), ( 11, 4), ( 11, 5), ( 11, 6), ( 11, 7), ( 11, 8), ( 11, 9), ( 11,10), ( 11,11), ( 11,12), ( 11,13), ( 11,14), ( 11,15)
( 12, 0), ( 12, 1), ( 12, 2), ( 12, 3), ( 12, 4), ( 12, 5), ( 12, 6), ( 12, 7), ( 12, 8), ( 12, 9), ( 12,10), ( 12,11), ( 12,12), ( 12,13), ( 12,14), ( 12,15)
( 13, 0), ( 13, 1), ( 13, 2), ( 13, 3), ( 13, 4), ( 13, 5), ( 13, 6), ( 13, 7), ( 13, 8), ( 13, 9), ( 13,10), ( 13,11), ( 13,12), ( 13,13), ( 13,14), ( 13,15)
( 14, 0), ( 14, 1), ( 14, 2), ( 14, 3), ( 14, 4), ( 14, 5), ( 14, 6), ( 14, 7), ( 14, 8), ( 14, 9), ( 14,10), ( 14,11), ( 14,12), ( 14,13), ( 14,14), ( 14,15)
( 15, 0), ( 15, 1), ( 15, 2), ( 15, 3), ( 15, 4), ( 15, 5), ( 15, 6), ( 15, 7), ( 15, 8), ( 15, 9), ( 15,10), ( 15,11), ( 15,12), ( 15,13), ( 15,14), ( 15,15)
( 16, 0), ( 16, 1), ( 16, 2), ( 16, 3), ( 16, 4), ( 16, 5), ( 16, 6), ( 16, 7), ( 16, 8), ( 16, 9), ( 16,10), ( 16,11), ( 16,12), ( 16,13), ( 16,14), ( 16,15)
( 17, 0), ( 17, 1), ( 17, 2), ( 17, 3), ( 17, 4), ( 17, 5), ( 17, 6), ( 17, 7), ( 17, 8), ( 17, 9), ( 17,10), ( 17,11), ( 17,12), ( 17,13), ( 17,14), ( 17,15)
( 18, 0), ( 18, 1), ( 18, 2), ( 18, 3), ( 18, 4), ( 18, 5), ( 18, 6), ( 18, 7), ( 18, 8), ( 18, 9), ( 18,10), ( 18,11), ( 18,12), ( 18,13), ( 18,14), ( 18,15)
( 19, 0), ( 19, 1), ( 19, 2), ( 19, 3), ( 19, 4), ( 19, 5), ( 19, 6), ( 19, 7), ( 19, 8), ( 19, 9), ( 19,10), ( 19,11), ( 19,12), ( 19,13), ( 19,14), ( 19,15)
( 20, 0), ( 20, 1), ( 20, 2), ( 20, 3), ( 20, 4), ( 20, 5), ( 20, 6), ( 20, 7), ( 20, 8), ( 20, 9), ( 20,10), ( 20,11), ( 20,12), ( 20,13), ( 20,14), ( 20,15)
( 21, 0), ( 21, 1), ( 21, 2), ( 21, 3), ( 21, 4), ( 21, 5), ( 21, 6), ( 21, 7), ( 21, 8), ( 21, 9), ( 21,10), ( 21,11), ( 21,12), ( 21,13), ( 21,14), ( 21,15)
( 22, 0), ( 22, 1), ( 22, 2), ( 22, 3), ( 22, 4), ( 22, 5), ( 22, 6), ( 22, 7), ( 22, 8), ( 22, 9), ( 22,10), ( 22,11), ( 22,12), ( 22,13), ( 22,14), ( 22,15)
( 23, 0), ( 23, 1), ( 23, 2), ( 23, 3), ( 23, 4), ( 23, 5), ( 23, 6), ( 23, 7), ( 23, 8), ( 23, 9), ( 23,10), ( 23,11), ( 23,12), ( 23,13), ( 23,14), ( 23,15)
( 24, 0), ( 24, 1), ( 24, 2), ( 24, 3), ( 24, 4), ( 24, 5), ( 24, 6), ( 24, 7), ( 24, 8), ( 24, 9), ( 24,10), ( 24,11), ( 24,12), ( 24,13), ( 24,14), ( 24,15)
( 25, 0), ( 25, 1), ( 25, 2), ( 25, 3), ( 25, 4), ( 25, 5), ( 25, 6), ( 25, 7), ( 25, 8), ( 25, 9), ( 25,10), ( 25,11), ( 25,12), ( 25,13), ( 25,14), ( 25,15)
( 26, 0), ( 26, 1), ( 26, 2), ( 26, 3), ( 26, 4), ( 26, 5), ( 26, 6), ( 26, 7), ( 26, 8), ( 26, 9), ( 26,10), ( 26,11), ( 26,12), ( 26,13), ( 26,14), ( 26,15)
( 27, 0), ( 27, 1), ( 27, 2), ( 27, 3), ( 27, 4), ( 27, 5), ( 27, 6), ( 27, 7), ( 27, 8), ( 27, 9), ( 27,10), ( 27,11), ( 27,12), ( 27,13), ( 27,14), ( 27,15)
( 28, 0), ( 28, 1), ( 28, 2), ( 28, 3), ( 28, 4), ( 28, 5), ( 28, 6), ( 28, 7), ( 28, 8), ( 28, 9), ( 28,10), ( 28,11), ( 28,12), ( 28,13), ( 28,14), ( 28,15)
( 29, 0), ( 29, 1), ( 29, 2), ( 29, 3), ( 29, 4), ( 29, 5), ( 29, 6), ( 29, 7), ( 29, 8), ( 29, 9), ( 29,10), ( 29,11), ( 29,12), ( 29,13), ( 29,14), ( 29,15)
( 30, 0), ( 30, 1), ( 30, 2), ( 30, 3), ( 30, 4), ( 30, 5), ( 30, 6), ( 30, 7), ( 30, 8), ( 30, 9), ( 30,10), ( 30,11), ( 30,12), ( 30,13), ( 30,14), ( 30,15)
( 31, 0), ( 31, 1), ( 31, 2), ( 31, 3), ( 31, 4), ( 31, 5), ( 31, 6), ( 31, 7), ( 31, 8), ( 31, 9), ( 31,10), ( 31,11), ( 31,12), ( 31,13), ( 31,14), ( 31,15)
(  0,16), (  0,17), (  0,18), (  0,19), (  0,20), (  0,21), (  0,22), (  0,23), (  0,24), (  0,25), (  0,26), (  0,27), (  0,28), (  0,29), (  0,30), (  0,31)
(  1,16), (  1,17), (  1,18), (  1,19), (  1,20), (  1,21), (  1,22), (  1,23), (  1,24), (  1,25), (  1,26), (  1,27), (  1,28), (  1,29), (  1,30), (  1,31)
(  2,16), (  2,17), (  2,18), (  2,19), (  2,20), (  2,21), (  2,22), (  2,23), (  2,24), (  2,25), (  2,26), (  2,27), (  2,28), (  2,29), (  2,30), (  2,31)
(  3,16), (  3,17), (  3,18), (  3,19), (  3,20), (  3,21), (  3,22), (  3,23), (  3,24), (  3,25), (  3,26), (  3,27), (  3,28), (  3,29), (  3,30), (  3,31)
(  4,16), (  4,17), (  4,18), (  4,19), (  4,20), (  4,21), (  4,22), (  4,23), (  4,24), (  4,25), (  4,26), (  4,27), (  4,28), (  4,29), (  4,30), (  4,31)
(  5,16), (  5,17), (  5,18), (  5,19), (  5,20), (  5,21), (  5,22), (  5,23), (  5,24), (  5,25), (  5,26), (  5,27), (  5,28), (  5,29), (  5,30), (  5,31)
(  6,16), (  6,17), (  6,18), (  6,19), (  6,20), (  6,21), (  6,22), (  6,23), (  6,24), (  6,25), (  6,26), (  6,27), (  6,28), (  6,29), (  6,30), (  6,31)
(  7,16), (  7,17), (  7,18), (  7,19), (  7,20), (  7,21), (  7,22), (  7,23), (  7,24), (  7,25), (  7,26), (  7,27), (  7,28), (  7,29), (  7,30), (  7,31)
(  8,16), (  8,17), (  8,18), (  8,19), (  8,20), (  8,21), (  8,22), (  8,23), (  8,24), (  8,25), (  8,26), (  8,27), (  8,28), (  8,29), (  8,30), (  8,31)
(  9,16), (  9,17), (  9,18), (  9,19), (  9,20), (  9,21), (  9,22), (  9,23), (  9,24), (  9,25), (  9,26), (  9,27), (  9,28), (  9,29), (  9,30), (  9,31)
( 10,16), ( 10,17), ( 10,18), ( 10,19), ( 10,20), ( 10,21), ( 10,22), ( 10,23), ( 10,24), ( 10,25), ( 10,26), ( 10,27), ( 10,28), ( 10,29), ( 10,30), ( 10,31)
( 11,16), ( 11,17), ( 11,18), ( 11,19), ( 11,20), ( 11,21), ( 11,22), ( 11,23), ( 11,24), ( 11,25), ( 11,26), ( 11,27), ( 11,28), ( 11,29), ( 11,30), ( 11,31)
( 12,16), ( 12,17), ( 12,18), ( 12,19), ( 12,20), ( 12,21), ( 12,22), ( 12,23), ( 12,24), ( 12,25), ( 12,26), ( 12,27), ( 12,28), ( 12,29), ( 12,30), ( 12,31)
( 13,16), ( 13,17), ( 13,18), ( 13,19), ( 13,20), ( 13,21), ( 13,22), ( 13,23), ( 13,24), ( 13,25), ( 13,26), ( 13,27), ( 13,28), ( 13,29), ( 13,30), ( 13,31)
( 14,16), ( 14,17), ( 14,18), ( 14,19), ( 14,20), ( 14,21), ( 14,22), ( 14,23), ( 14,24), ( 14,25), ( 14,26), ( 14,27), ( 14,28), ( 14,29), ( 14,30), ( 14,31)
( 15,16), ( 15,17), ( 15,18), ( 15,19), ( 15,20), ( 15,21), ( 15,22), ( 15,23), ( 15,24), ( 15,25), ( 15,26), ( 15,27), ( 15,28), ( 15,29), ( 15,30), ( 15,31)
( 16,16), ( 16,17), ( 16,18), ( 16,19), ( 16,20), ( 16,21), ( 16,22), ( 16,23), ( 16,24), ( 16,25), ( 16,26), ( 16,27), ( 16,28), ( 16,29), ( 16,30), ( 16,31)
( 17,16), ( 17,17), ( 17,18), ( 17,19), ( 17,20), ( 17,21), ( 17,22), ( 17,23), ( 17,24), ( 17,25), ( 17,26), ( 17,27), ( 17,28), ( 17,29), ( 17,30), ( 17,31)
( 18,16), ( 18,17), ( 18,18), ( 18,19), ( 18,20), ( 18,21), ( 18,22), ( 18,23), ( 18,24), ( 18,25), ( 18,26), ( 18,27), ( 18,28), ( 18,29), ( 18,30), ( 18,31)
( 19,16), ( 19,17), ( 19,18), ( 19,19), ( 19,20), ( 19,21), ( 19,22), ( 19,23), ( 19,24), ( 19,25), ( 19,26), ( 19,27), ( 19,28), ( 19,29), ( 19,30), ( 19,31)
( 20,16), ( 20,17), ( 20,18), ( 20,19), ( 20,20), ( 20,21), ( 20,22), ( 20,23), ( 20,24), ( 20,25), ( 20,26), ( 20,27), ( 20,28), ( 20,29), ( 20,30), ( 20,31)
( 21,16), ( 21,17), ( 21,18), ( 21,19), ( 21,20), ( 21,21), ( 21,22), ( 21,23), ( 21,24), ( 21,25), ( 21,26), ( 21,27), ( 21,28), ( 21,29), ( 21,30), ( 21,31)
( 22,16), ( 22,17), ( 22,18), ( 22,19), ( 22,20), ( 22,21), ( 22,22), ( 22,23), ( 22,24), ( 22,25), ( 22,26), ( 22,27), ( 22,28), ( 22,29), ( 22,30), ( 22,31)
( 23,16), ( 23,17), ( 23,18), ( 23,19), ( 23,20), ( 23,21), ( 23,22), ( 23,23), ( 23,24), ( 23,25), ( 23,26), ( 23,27), ( 23,28), ( 23,29), ( 23,30), ( 23,31)
( 24,16), ( 24,17), ( 24,18), ( 24,19), ( 24,20), ( 24,21), ( 24,22), ( 24,23), ( 24,24), ( 24,25), ( 24,26), ( 24,27), ( 24,28), ( 24,29), ( 24,30), ( 24,31)
( 25,16), ( 25,17), ( 25,18), ( 25,19), ( 25,20), ( 25,21), ( 25,22), ( 25,23), ( 25,24), ( 25,25), ( 25,26), ( 25,27), ( 25,28), ( 25,29), ( 25,30), ( 25,31)
( 26,16), ( 26,17), ( 26,18), ( 26,19), ( 26,20), ( 26,21), ( 26,22), ( 26,23), ( 26,24), ( 26,25), ( 26,26), ( 26,27), ( 26,28), ( 26,29), ( 26,30), ( 26,31)
( 27,16), ( 27,17), ( 27,18), ( 27,19), ( 27,20), ( 27,21), ( 27,22), ( 27,23), ( 27,24), ( 27,25), ( 27,26), ( 27,27), ( 27,28), ( 27,29), ( 27,30), ( 27,31)
( 28,16), ( 28,17), ( 28,18), ( 28,19), ( 28,20), ( 28,21), ( 28,22), ( 28,23), ( 28,24), ( 28,25), ( 28,26), ( 28,27), ( 28,28), ( 28,29), ( 28,30), ( 28,31)
( 29,16), ( 29,17), ( 29,18), ( 29,19), ( 29,20), ( 29,21), ( 29,22), ( 29,23), ( 29,24), ( 29,25), ( 29,26), ( 29,27), ( 29,28), ( 29,29), ( 29,30), ( 29,31)
( 30,16), ( 30,17), ( 30,18), ( 30,19), ( 30,20), ( 30,21), ( 30,22), ( 30,23), ( 30,24), ( 30,25), ( 30,26), ( 30,27), ( 30,28), ( 30,29), ( 30,30), ( 30,31)
( 31,16), ( 31,17), ( 31,18), ( 31,19), ( 31,20), ( 31,21), ( 31,22), ( 31,23), ( 31,24), ( 31,25), ( 31,26), ( 31,27), ( 31,28), ( 31,29), ( 31,30), ( 31,31)

vs 32x32x1

Print layout attribute: #triton_intel_gpu.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 32], numBlocks=1, order=[1, 0], kWidth=2, threadsPerWarp=16}>
Warp0:
(  0, 0), (  0, 2), (  0, 4), (  0, 6), (  0, 8), (  0,10), (  0,12), (  0,14), (  0,16), (  0,18), (  0,20), (  0,22), (  0,24), (  0,26), (  0,28), (  0,30)
(  0, 1), (  0, 3), (  0, 5), (  0, 7), (  0, 9), (  0,11), (  0,13), (  0,15), (  0,17), (  0,19), (  0,21), (  0,23), (  0,25), (  0,27), (  0,29), (  0,31)
(  1, 0), (  1, 2), (  1, 4), (  1, 6), (  1, 8), (  1,10), (  1,12), (  1,14), (  1,16), (  1,18), (  1,20), (  1,22), (  1,24), (  1,26), (  1,28), (  1,30)
(  1, 1), (  1, 3), (  1, 5), (  1, 7), (  1, 9), (  1,11), (  1,13), (  1,15), (  1,17), (  1,19), (  1,21), (  1,23), (  1,25), (  1,27), (  1,29), (  1,31)
(  2, 0), (  2, 2), (  2, 4), (  2, 6), (  2, 8), (  2,10), (  2,12), (  2,14), (  2,16), (  2,18), (  2,20), (  2,22), (  2,24), (  2,26), (  2,28), (  2,30)
(  2, 1), (  2, 3), (  2, 5), (  2, 7), (  2, 9), (  2,11), (  2,13), (  2,15), (  2,17), (  2,19), (  2,21), (  2,23), (  2,25), (  2,27), (  2,29), (  2,31)
(  3, 0), (  3, 2), (  3, 4), (  3, 6), (  3, 8), (  3,10), (  3,12), (  3,14), (  3,16), (  3,18), (  3,20), (  3,22), (  3,24), (  3,26), (  3,28), (  3,30)
(  3, 1), (  3, 3), (  3, 5), (  3, 7), (  3, 9), (  3,11), (  3,13), (  3,15), (  3,17), (  3,19), (  3,21), (  3,23), (  3,25), (  3,27), (  3,29), (  3,31)
(  4, 0), (  4, 2), (  4, 4), (  4, 6), (  4, 8), (  4,10), (  4,12), (  4,14), (  4,16), (  4,18), (  4,20), (  4,22), (  4,24), (  4,26), (  4,28), (  4,30)
(  4, 1), (  4, 3), (  4, 5), (  4, 7), (  4, 9), (  4,11), (  4,13), (  4,15), (  4,17), (  4,19), (  4,21), (  4,23), (  4,25), (  4,27), (  4,29), (  4,31)
(  5, 0), (  5, 2), (  5, 4), (  5, 6), (  5, 8), (  5,10), (  5,12), (  5,14), (  5,16), (  5,18), (  5,20), (  5,22), (  5,24), (  5,26), (  5,28), (  5,30)
(  5, 1), (  5, 3), (  5, 5), (  5, 7), (  5, 9), (  5,11), (  5,13), (  5,15), (  5,17), (  5,19), (  5,21), (  5,23), (  5,25), (  5,27), (  5,29), (  5,31)
(  6, 0), (  6, 2), (  6, 4), (  6, 6), (  6, 8), (  6,10), (  6,12), (  6,14), (  6,16), (  6,18), (  6,20), (  6,22), (  6,24), (  6,26), (  6,28), (  6,30)
(  6, 1), (  6, 3), (  6, 5), (  6, 7), (  6, 9), (  6,11), (  6,13), (  6,15), (  6,17), (  6,19), (  6,21), (  6,23), (  6,25), (  6,27), (  6,29), (  6,31)
(  7, 0), (  7, 2), (  7, 4), (  7, 6), (  7, 8), (  7,10), (  7,12), (  7,14), (  7,16), (  7,18), (  7,20), (  7,22), (  7,24), (  7,26), (  7,28), (  7,30)
(  7, 1), (  7, 3), (  7, 5), (  7, 7), (  7, 9), (  7,11), (  7,13), (  7,15), (  7,17), (  7,19), (  7,21), (  7,23), (  7,25), (  7,27), (  7,29), (  7,31)
(  8, 0), (  8, 2), (  8, 4), (  8, 6), (  8, 8), (  8,10), (  8,12), (  8,14), (  8,16), (  8,18), (  8,20), (  8,22), (  8,24), (  8,26), (  8,28), (  8,30)
(  8, 1), (  8, 3), (  8, 5), (  8, 7), (  8, 9), (  8,11), (  8,13), (  8,15), (  8,17), (  8,19), (  8,21), (  8,23), (  8,25), (  8,27), (  8,29), (  8,31)
(  9, 0), (  9, 2), (  9, 4), (  9, 6), (  9, 8), (  9,10), (  9,12), (  9,14), (  9,16), (  9,18), (  9,20), (  9,22), (  9,24), (  9,26), (  9,28), (  9,30)
(  9, 1), (  9, 3), (  9, 5), (  9, 7), (  9, 9), (  9,11), (  9,13), (  9,15), (  9,17), (  9,19), (  9,21), (  9,23), (  9,25), (  9,27), (  9,29), (  9,31)
( 10, 0), ( 10, 2), ( 10, 4), ( 10, 6), ( 10, 8), ( 10,10), ( 10,12), ( 10,14), ( 10,16), ( 10,18), ( 10,20), ( 10,22), ( 10,24), ( 10,26), ( 10,28), ( 10,30)
( 10, 1), ( 10, 3), ( 10, 5), ( 10, 7), ( 10, 9), ( 10,11), ( 10,13), ( 10,15), ( 10,17), ( 10,19), ( 10,21), ( 10,23), ( 10,25), ( 10,27), ( 10,29), ( 10,31)
( 11, 0), ( 11, 2), ( 11, 4), ( 11, 6), ( 11, 8), ( 11,10), ( 11,12), ( 11,14), ( 11,16), ( 11,18), ( 11,20), ( 11,22), ( 11,24), ( 11,26), ( 11,28), ( 11,30)
( 11, 1), ( 11, 3), ( 11, 5), ( 11, 7), ( 11, 9), ( 11,11), ( 11,13), ( 11,15), ( 11,17), ( 11,19), ( 11,21), ( 11,23), ( 11,25), ( 11,27), ( 11,29), ( 11,31)
( 12, 0), ( 12, 2), ( 12, 4), ( 12, 6), ( 12, 8), ( 12,10), ( 12,12), ( 12,14), ( 12,16), ( 12,18), ( 12,20), ( 12,22), ( 12,24), ( 12,26), ( 12,28), ( 12,30)
( 12, 1), ( 12, 3), ( 12, 5), ( 12, 7), ( 12, 9), ( 12,11), ( 12,13), ( 12,15), ( 12,17), ( 12,19), ( 12,21), ( 12,23), ( 12,25), ( 12,27), ( 12,29), ( 12,31)
( 13, 0), ( 13, 2), ( 13, 4), ( 13, 6), ( 13, 8), ( 13,10), ( 13,12), ( 13,14), ( 13,16), ( 13,18), ( 13,20), ( 13,22), ( 13,24), ( 13,26), ( 13,28), ( 13,30)
( 13, 1), ( 13, 3), ( 13, 5), ( 13, 7), ( 13, 9), ( 13,11), ( 13,13), ( 13,15), ( 13,17), ( 13,19), ( 13,21), ( 13,23), ( 13,25), ( 13,27), ( 13,29), ( 13,31)
( 14, 0), ( 14, 2), ( 14, 4), ( 14, 6), ( 14, 8), ( 14,10), ( 14,12), ( 14,14), ( 14,16), ( 14,18), ( 14,20), ( 14,22), ( 14,24), ( 14,26), ( 14,28), ( 14,30)
( 14, 1), ( 14, 3), ( 14, 5), ( 14, 7), ( 14, 9), ( 14,11), ( 14,13), ( 14,15), ( 14,17), ( 14,19), ( 14,21), ( 14,23), ( 14,25), ( 14,27), ( 14,29), ( 14,31)
( 15, 0), ( 15, 2), ( 15, 4), ( 15, 6), ( 15, 8), ( 15,10), ( 15,12), ( 15,14), ( 15,16), ( 15,18), ( 15,20), ( 15,22), ( 15,24), ( 15,26), ( 15,28), ( 15,30)
( 15, 1), ( 15, 3), ( 15, 5), ( 15, 7), ( 15, 9), ( 15,11), ( 15,13), ( 15,15), ( 15,17), ( 15,19), ( 15,21), ( 15,23), ( 15,25), ( 15,27), ( 15,29), ( 15,31)
( 16, 0), ( 16, 2), ( 16, 4), ( 16, 6), ( 16, 8), ( 16,10), ( 16,12), ( 16,14), ( 16,16), ( 16,18), ( 16,20), ( 16,22), ( 16,24), ( 16,26), ( 16,28), ( 16,30)
( 16, 1), ( 16, 3), ( 16, 5), ( 16, 7), ( 16, 9), ( 16,11), ( 16,13), ( 16,15), ( 16,17), ( 16,19), ( 16,21), ( 16,23), ( 16,25), ( 16,27), ( 16,29), ( 16,31)
( 17, 0), ( 17, 2), ( 17, 4), ( 17, 6), ( 17, 8), ( 17,10), ( 17,12), ( 17,14), ( 17,16), ( 17,18), ( 17,20), ( 17,22), ( 17,24), ( 17,26), ( 17,28), ( 17,30)
( 17, 1), ( 17, 3), ( 17, 5), ( 17, 7), ( 17, 9), ( 17,11), ( 17,13), ( 17,15), ( 17,17), ( 17,19), ( 17,21), ( 17,23), ( 17,25), ( 17,27), ( 17,29), ( 17,31)
( 18, 0), ( 18, 2), ( 18, 4), ( 18, 6), ( 18, 8), ( 18,10), ( 18,12), ( 18,14), ( 18,16), ( 18,18), ( 18,20), ( 18,22), ( 18,24), ( 18,26), ( 18,28), ( 18,30)
( 18, 1), ( 18, 3), ( 18, 5), ( 18, 7), ( 18, 9), ( 18,11), ( 18,13), ( 18,15), ( 18,17), ( 18,19), ( 18,21), ( 18,23), ( 18,25), ( 18,27), ( 18,29), ( 18,31)
( 19, 0), ( 19, 2), ( 19, 4), ( 19, 6), ( 19, 8), ( 19,10), ( 19,12), ( 19,14), ( 19,16), ( 19,18), ( 19,20), ( 19,22), ( 19,24), ( 19,26), ( 19,28), ( 19,30)
( 19, 1), ( 19, 3), ( 19, 5), ( 19, 7), ( 19, 9), ( 19,11), ( 19,13), ( 19,15), ( 19,17), ( 19,19), ( 19,21), ( 19,23), ( 19,25), ( 19,27), ( 19,29), ( 19,31)
( 20, 0), ( 20, 2), ( 20, 4), ( 20, 6), ( 20, 8), ( 20,10), ( 20,12), ( 20,14), ( 20,16), ( 20,18), ( 20,20), ( 20,22), ( 20,24), ( 20,26), ( 20,28), ( 20,30)
( 20, 1), ( 20, 3), ( 20, 5), ( 20, 7), ( 20, 9), ( 20,11), ( 20,13), ( 20,15), ( 20,17), ( 20,19), ( 20,21), ( 20,23), ( 20,25), ( 20,27), ( 20,29), ( 20,31)
( 21, 0), ( 21, 2), ( 21, 4), ( 21, 6), ( 21, 8), ( 21,10), ( 21,12), ( 21,14), ( 21,16), ( 21,18), ( 21,20), ( 21,22), ( 21,24), ( 21,26), ( 21,28), ( 21,30)
( 21, 1), ( 21, 3), ( 21, 5), ( 21, 7), ( 21, 9), ( 21,11), ( 21,13), ( 21,15), ( 21,17), ( 21,19), ( 21,21), ( 21,23), ( 21,25), ( 21,27), ( 21,29), ( 21,31)
( 22, 0), ( 22, 2), ( 22, 4), ( 22, 6), ( 22, 8), ( 22,10), ( 22,12), ( 22,14), ( 22,16), ( 22,18), ( 22,20), ( 22,22), ( 22,24), ( 22,26), ( 22,28), ( 22,30)
( 22, 1), ( 22, 3), ( 22, 5), ( 22, 7), ( 22, 9), ( 22,11), ( 22,13), ( 22,15), ( 22,17), ( 22,19), ( 22,21), ( 22,23), ( 22,25), ( 22,27), ( 22,29), ( 22,31)
( 23, 0), ( 23, 2), ( 23, 4), ( 23, 6), ( 23, 8), ( 23,10), ( 23,12), ( 23,14), ( 23,16), ( 23,18), ( 23,20), ( 23,22), ( 23,24), ( 23,26), ( 23,28), ( 23,30)
( 23, 1), ( 23, 3), ( 23, 5), ( 23, 7), ( 23, 9), ( 23,11), ( 23,13), ( 23,15), ( 23,17), ( 23,19), ( 23,21), ( 23,23), ( 23,25), ( 23,27), ( 23,29), ( 23,31)
( 24, 0), ( 24, 2), ( 24, 4), ( 24, 6), ( 24, 8), ( 24,10), ( 24,12), ( 24,14), ( 24,16), ( 24,18), ( 24,20), ( 24,22), ( 24,24), ( 24,26), ( 24,28), ( 24,30)
( 24, 1), ( 24, 3), ( 24, 5), ( 24, 7), ( 24, 9), ( 24,11), ( 24,13), ( 24,15), ( 24,17), ( 24,19), ( 24,21), ( 24,23), ( 24,25), ( 24,27), ( 24,29), ( 24,31)
( 25, 0), ( 25, 2), ( 25, 4), ( 25, 6), ( 25, 8), ( 25,10), ( 25,12), ( 25,14), ( 25,16), ( 25,18), ( 25,20), ( 25,22), ( 25,24), ( 25,26), ( 25,28), ( 25,30)
( 25, 1), ( 25, 3), ( 25, 5), ( 25, 7), ( 25, 9), ( 25,11), ( 25,13), ( 25,15), ( 25,17), ( 25,19), ( 25,21), ( 25,23), ( 25,25), ( 25,27), ( 25,29), ( 25,31)
( 26, 0), ( 26, 2), ( 26, 4), ( 26, 6), ( 26, 8), ( 26,10), ( 26,12), ( 26,14), ( 26,16), ( 26,18), ( 26,20), ( 26,22), ( 26,24), ( 26,26), ( 26,28), ( 26,30)
( 26, 1), ( 26, 3), ( 26, 5), ( 26, 7), ( 26, 9), ( 26,11), ( 26,13), ( 26,15), ( 26,17), ( 26,19), ( 26,21), ( 26,23), ( 26,25), ( 26,27), ( 26,29), ( 26,31)
( 27, 0), ( 27, 2), ( 27, 4), ( 27, 6), ( 27, 8), ( 27,10), ( 27,12), ( 27,14), ( 27,16), ( 27,18), ( 27,20), ( 27,22), ( 27,24), ( 27,26), ( 27,28), ( 27,30)
( 27, 1), ( 27, 3), ( 27, 5), ( 27, 7), ( 27, 9), ( 27,11), ( 27,13), ( 27,15), ( 27,17), ( 27,19), ( 27,21), ( 27,23), ( 27,25), ( 27,27), ( 27,29), ( 27,31)
( 28, 0), ( 28, 2), ( 28, 4), ( 28, 6), ( 28, 8), ( 28,10), ( 28,12), ( 28,14), ( 28,16), ( 28,18), ( 28,20), ( 28,22), ( 28,24), ( 28,26), ( 28,28), ( 28,30)
( 28, 1), ( 28, 3), ( 28, 5), ( 28, 7), ( 28, 9), ( 28,11), ( 28,13), ( 28,15), ( 28,17), ( 28,19), ( 28,21), ( 28,23), ( 28,25), ( 28,27), ( 28,29), ( 28,31)
( 29, 0), ( 29, 2), ( 29, 4), ( 29, 6), ( 29, 8), ( 29,10), ( 29,12), ( 29,14), ( 29,16), ( 29,18), ( 29,20), ( 29,22), ( 29,24), ( 29,26), ( 29,28), ( 29,30)
( 29, 1), ( 29, 3), ( 29, 5), ( 29, 7), ( 29, 9), ( 29,11), ( 29,13), ( 29,15), ( 29,17), ( 29,19), ( 29,21), ( 29,23), ( 29,25), ( 29,27), ( 29,29), ( 29,31)
( 30, 0), ( 30, 2), ( 30, 4), ( 30, 6), ( 30, 8), ( 30,10), ( 30,12), ( 30,14), ( 30,16), ( 30,18), ( 30,20), ( 30,22), ( 30,24), ( 30,26), ( 30,28), ( 30,30)
( 30, 1), ( 30, 3), ( 30, 5), ( 30, 7), ( 30, 9), ( 30,11), ( 30,13), ( 30,15), ( 30,17), ( 30,19), ( 30,21), ( 30,23), ( 30,25), ( 30,27), ( 30,29), ( 30,31)
( 31, 0), ( 31, 2), ( 31, 4), ( 31, 6), ( 31, 8), ( 31,10), ( 31,12), ( 31,14), ( 31,16), ( 31,18), ( 31,20), ( 31,22), ( 31,24), ( 31,26), ( 31,28), ( 31,30)
( 31, 1), ( 31, 3), ( 31, 5), ( 31, 7), ( 31, 9), ( 31,11), ( 31,13), ( 31,15), ( 31,17), ( 31,19), ( 31,21), ( 31,23), ( 31,25), ( 31,27), ( 31,29), ( 31,31)

Note the interleaving with 32x32x1 because width = 2*subgroupSize.

Does order affect this?

No, order is only used when determining the broadcast dimension for the warp/block layout. The tile parameters always refer to the row-major ordering expected by the subgroup 2D block IO instructions.

- `threadsPerWarp` : currently a scalar, this parameter allows us to support different subgroup/warp configurations. Because the 2D block IO operation is a subgroup operation, the subgroup size matters when determining the ordering of the loaded tensor.
- `warpsPerCTA` : the number of warps per block / subgroups per workgroup, and their distribution
- `order` : the order within the block, used to determine the dimension along which to broadcast
- `kWidth` : currently unused, but kept because it will likely be needed for layout conversions
- `CTALayout` : describes how blocks are distributed among workgroups/thread blocks
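
For example, a 32x16 load tile fetched two blocks at a time prints as (this mirrors the custom assembly format implemented below):

    #triton_intel_gpu.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks=2, order=[1, 0], kWidth=2, threadsPerWarp=16}>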
}];

let parameters = (
ins
ArrayRefParameter<"unsigned">:$warpsPerCTA,
"CTALayoutAttr":$CTALayout,
ArrayRefParameter<"unsigned">:$instrShape,
"unsigned":$numBlocks,
ArrayRefParameter<"unsigned">:$order,
"unsigned":$kWidth,
"unsigned":$threadsPerWarp
);

let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
}];

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
}

#endif
167 changes: 167 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -495,6 +495,173 @@ void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
<< "}>";
}

//===----------------------------------------------------------------------===//
// Subgroup2DBlockEncodingAttr
//===----------------------------------------------------------------------===//

namespace {
std::optional<CTALayoutAttr> getCTALayoutOrError(
AsmParser &parser, std::optional<SmallVector<unsigned>> CTAsPerCGA,
std::optional<SmallVector<unsigned>> CTASplitNum,
std::optional<SmallVector<unsigned>> CTAOrder, unsigned rank) {
if (CTAsPerCGA && CTASplitNum && CTAOrder) {
return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum,
*CTAOrder);
}
if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) {
return CTALayoutAttr::getDefault(parser.getContext(), rank);
}
parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder "
"must all be present or all be absent");
return std::nullopt;
}

// Print the CTALayout if it's not equal to the default.
void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
CTALayoutAttr layout, unsigned rank) {
if (layout != CTALayoutAttr::getDefault(context, rank)) {
printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]"
<< ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]"
<< ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]";
}
}

} // namespace

LogicalResult Subgroup2DBlockEncodingAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
ArrayRef<unsigned> warpsPerCTA, CTALayoutAttr CTALayout,
ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order,
unsigned kWidth, unsigned threadsPerWarp) {
if (instrShape.size() != 2) {
return emitError() << "instrShape must be rank 2 but was: "
<< instrShape.size();
}
if (order.size() != 2) {
return emitError() << "order must be rank 2 but was " << order.size();
}
if (warpsPerCTA.size() != 2) {
return emitError() << "warpsPerCTA must be rank 2 but was "
<< warpsPerCTA.size();
}
if (!(kWidth == 1 || kWidth == 2 || kWidth == 4)) {
return emitError() << "kWidth must be 1, 2 or 4, but was: " << kWidth;
}
if (threadsPerWarp != 16) {
return emitError() << "threadsPerWarp must be 16, but was: "
<< threadsPerWarp;
}
return success();
}

Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};

SmallVector<unsigned> warpsPerCTA;
std::optional<SmallVector<unsigned>> CTAsPerCGA;
std::optional<SmallVector<unsigned>> CTASplitNum;
std::optional<SmallVector<unsigned>> CTAOrder;
SmallVector<unsigned> instrShape;
unsigned numBlocks = 0;
SmallVector<unsigned> order;
unsigned kWidth = 0;
unsigned threadsPerWarp = 0;

for (const NamedAttribute &attr : dict) {
if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
return {};
}
if (attr.getName() == "CTAsPerCGA") {
if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA")
.failed())
return {};
}
if (attr.getName() == "CTASplitNum") {
if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum")
.failed())
return {};
}
if (attr.getName() == "CTAOrder") {
if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder")
.failed())
return {};
}
if (attr.getName() == "instrShape") {
if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
return {};
}
if (attr.getName() == "numBlocks") {
if (parseUInt(parser, attr, numBlocks, "numBlocks").failed())
return {};
}
if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
}
if (attr.getName() == "kWidth") {
if (parseUInt(parser, attr, kWidth, "kWidth").failed())
return {};
}
if (attr.getName() == "threadsPerWarp") {
if (parseUInt(parser, attr, threadsPerWarp, "threadsPerWarp").failed())
return {};
}
}

std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size());
if (!CTALayout.has_value())
return {};

return parser.getChecked<Subgroup2DBlockEncodingAttr>(
parser.getContext(), warpsPerCTA, *CTALayout, instrShape, numBlocks,
order, kWidth, threadsPerWarp);
}

SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getRepOrder() const {
return getMatrixOrder(getRank(), /*rowMajor*/ true);
}

SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}

SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTAOrder() const {
return SmallVector<unsigned>(getCTALayout().getCTAOrder());
}

SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTASplitNum() const {
return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
}

SmallVector<unsigned>
Subgroup2DBlockEncodingAttr::getRepOrderForOperand(int opIdx) const {
return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true);
}

void Subgroup2DBlockEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{" << "warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]";

maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank());

printer << ", instrShape = [" << getInstrShape()
<< "], numBlocks=" << getNumBlocks() << ", order=[" << getOrder()
<< "], kWidth=" << getKWidth()
<< ", threadsPerWarp=" << getThreadsPerWarp() << "}>";
}

LinearLayout
Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
return subgroup2DBlockToLinearLayout(shape, *this, getKWidth());
}

//===----------------------------------------------------------------------===//
// Dialect Interface
//===----------------------------------------------------------------------===//
@@ -523,4 +523,119 @@ LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout,
return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx());
}

namespace {

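// Builds a layout over `shape`, visiting dimensions in `order`: the
// `broadcastDim` dimension maps to zeros (all positions along it share the
// same input element), while every other dimension is identity-mapped.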
static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
ArrayRef<unsigned> shape,
ArrayRef<unsigned> order,
unsigned broadcastDim,
StringAttr inDimName) {
int rank = shape.size();
auto dimNames = standardOutDimNames(ctx, rank);
LinearLayout layout = LinearLayout::empty();

for (auto d : order) {
if (d == broadcastDim) {
layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]);
} else {
layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
}
}
return layout;
}

using basisT = std::vector<std::vector<int32_t>>;

// Creates a row-major tile layout with register/lane input dimensions according
// to the provided height, width, and threadsPerWarp. The relationship between
// the width and threadsPerWarp determines the packing of rows across lanes:
// - if width == threadsPerWarp:
//   block row elements are mapped to registers in row-major order, i.e. one
// column per lane
// - if width < threadsPerWarp:
// multiple rows are mapped to the first register to fill the warp, i.e.
// width * rowsPerWarp = threadsPerWarp
// - if width > threadsPerWarp:
// multiple elements of each row are assigned to registers such that
// packedElementsPerLane row values exist in consecutive registers for each
// lane
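//
// For example, height=32, width=16, threadsPerWarp=16 yields
// packedElementsPerLane = 1, laneBases = {{0,1},{0,2},{0,4},{0,8}} (one
// column per lane) and regBases = {{1,0},{2,0},{4,0},{8,0},{16,0}}
// (registers step down the rows).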
std::pair<basisT, basisT>
createRegisterLaneBases(const int height, const int width,
const unsigned threadsPerWarp) {
const int packedElementsPerLane =
mlir::ceil<int>(width, static_cast<int>(threadsPerWarp));

basisT laneBases;
for (int i = packedElementsPerLane; i < width; i = i << 1) {
laneBases.push_back({0, i});
}

const int rowsPerWarp =
mlir::ceil<int>(threadsPerWarp, 1 << laneBases.size());
// Place subsequent rows into adjacent lanes until all lanes have been filled
for (int i = 1; i < rowsPerWarp; i = i << 1) {
laneBases.push_back({i, 0});
Contributor:

So the idea is that this describes the columns, and if the mapping is one-to-one (width == threadsPerWarp), then these values are just absent (or err, broadcasted into a single slot)? Is it fair to say in the case of width == threadsPerWarp you have a lower dimensionality here? Or in other words, if I were to set the value anyway, would it just be the one in {0,0}?

Contributor Author:

I am not sure I understand the question. If rowsToLaneRatio (now renamed rowsPerWarp) is 1, then we only have one row per set of registers per lane; this means each lane holds one or more columns of elements (depending on whether the block width is equal to or larger than the number of threads in the warp, i.e. the number of lanes). Regardless, because each lane processes a column, there is never a situation where the row index is non-zero while the register index is zero; the first row of registers always has row index zero.

}

basisT regBases;

// Add packed row-wise elements (width > threadsPerWarp) before adding columns
for (int i = 1; i < packedElementsPerLane; i = i << 1) {
regBases.push_back({0, i});
}

for (int i = 1; i < height / rowsPerWarp; i = i << 1) {
regBases.push_back({i * rowsPerWarp, 0});
}

return std::make_pair(regBases, laneBases);
}

} // namespace

LinearLayout
subgroup2DBlockToLinearLayout(ArrayRef<int64_t> blockShape,
intel::Subgroup2DBlockEncodingAttr layout,
unsigned kWidth) {
auto ctx = layout.getContext();
int rank = blockShape.size();
assert(rank == layout.getRank() &&
       "block shape rank and layout rank must be equal");
auto dimNames = standardOutDimNames(ctx, rank);
auto loadTileSize = layout.getInstrShape();
StringAttr kRegister = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");

// Start by creating register/lane bases corresponding to the desired load
// tile size
auto [regBases, laneBases] = createRegisterLaneBases(
loadTileSize[0], loadTileSize[1], layout.getThreadsPerWarp());

LinearLayout::BasesT bases;
bases[kRegister] = regBases;
bases[kLane] = laneBases;
auto ctaLayout = LinearLayout(bases, dimNames);

assert(ctaLayout.getInDimSize(kLane) <= layout.getThreadsPerWarp() &&
"number of lanes should not exceed threads per warp");

// Increasing the block count always increases the inner dimension of the
// register/lane layout, regardless of order.
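// (e.g. instrShape=[32,16] with numBlocks=2 appends one register basis
// mapping to column offset 16, so the second 16-wide block starts at
// column 16 of the combined tile.)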
ctaLayout *=
LinearLayout::identity1D(layout.getNumBlocks(), kRegister, dimNames[1]);
Comment on lines +623 to +626
Contributor:

I don't quite get the inner dimension part. How would the layout be affected in case we have, say, two blocks? I see in tests the block bases are always empty.

Contributor Author:

The term "blocks" is used twice. Here we refer to the block count, which is a parameter of the block IO operation: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html#_overview. Block count is essentially used to load consecutive blocks to better utilize the cache lines. Tile blocks are loaded consecutively in row-major order, so we simply multiply by the identity layout on the inner dimension, which effectively "copies" the ctaLayout numBlocks times w/r/t the inner dimension.

Contributor:

Ah, ok, makes sense now. Here's the correct link btw: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html#_overview (the ones above seem to be broken).


// Broadcast the layout according to warpsPerCTA, then combine with the
// overall CTALayout and reshape according to the provided blockShape.
auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ true);
auto order = layout.getOrder();
assert(order.size() == 2 && "only rank 2 order supported");
Contributor:

Doesn't this essentially restrict the layout rank to be 2 as well? My question should probably be "why does the first assertion not check the layout rank to be 2 in the first place?". Or even is the attribute verification not enough to catch this?

Contributor Author:

Yes, it is essentially being conservative. I don't want to accidentally hit this code path in any of the dot3d code paths, and it is unclear to me when those dot3d code paths are active. For now, I think the extra asserts are fine; they should help catch issues early in the tests.

unsigned inner = order[0];

ctaLayout *= broadcastedDotOperandLayout(ctx, layout.getWarpsPerCTA(),
warpOrder, inner, kWarp)
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
return combineCtaCgaWithShape(ctaLayout, layout.getCTALayout(), blockShape);
}

} // namespace mlir::triton::gpu
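
For reference, here is a minimal sketch of how the new unit test (LinearLayoutConversionsTest.cpp, added below) might exercise this conversion. The attribute values mirror the 32x16x2 example from the review discussion, and the getter follows the TableGen parameter order declared above; the function name and the exact dialect/test harness are assumptions, not the PR's actual test code:

    // Sketch only: Subgroup2DBlockEncodingAttr::get is assumed to follow the
    // TableGen parameter order (warpsPerCTA, CTALayout, instrShape, numBlocks,
    // order, kWidth, threadsPerWarp).
    void exerciseSubgroup2DBlockLayout(mlir::MLIRContext &ctx) {
      auto ctaLayout = mlir::triton::gpu::CTALayoutAttr::getDefault(&ctx, /*rank=*/2);
      auto layout = mlir::triton::gpu::intel::Subgroup2DBlockEncodingAttr::get(
          &ctx, /*warpsPerCTA=*/{8, 4}, ctaLayout, /*instrShape=*/{32, 16},
          /*numBlocks=*/2, /*order=*/{1, 0}, /*kWidth=*/2, /*threadsPerWarp=*/16);
      // Expand the encoding over a concrete tensor shape.
      mlir::triton::LinearLayout ll = layout.toLinearLayout({256, 128});
      (void)ll; // e.g. compare against an expected LinearLayout here
    }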
14 changes: 14 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1993,6 +1993,20 @@
}
Value elemSizeInBytes = b.i32_val(originalElemBits / 8);

LLVM_DEBUG({
const unsigned numLoads = numRepOuter * numLoadPerOutRepCluster *
numRepInner / numOperandsInnerDimPerLoad;
llvm::dbgs() << "Preparing to dispatch " << numLoads << " loads\n";
llvm::dbgs() << "Outer loads: " << numRepOuter * numLoadPerOutRepCluster
<< " (" << numLoadPerOutRepCluster
<< " per out rep cluster)\n";
llvm::dbgs() << "Inner loads: "
<< numRepInner / numOperandsInnerDimPerLoad << "\n";
llvm::dbgs() << "Load dimension: " << tileHeight << ", "
<< tileWidth * vBlocks << " (" << elemSizeInBits
<< " bits)\n";
});

ValueTable loadVals;
for (int outer = 0; outer < numRepOuter; ++outer) {
for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) {
1 change: 1 addition & 0 deletions unittest/Dialect/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(TritonGPU)
add_subdirectory(TritonIntelGPU)
10 changes: 10 additions & 0 deletions unittest/Dialect/TritonIntelGPU/CMakeLists.txt
@@ -0,0 +1,10 @@
add_triton_ut(
NAME LinearLayoutConversionsIntel
SRCS LinearLayoutConversionsTest.cpp
LIBS
TritonGPUIR
TritonGPUTransforms
TritonIntelAnalysis
TritonIntelGPUTransforms
TritonNvidiaGPUTransforms
Contributor:

Is this coming from the ttg dialect somehow? Or is it just a copy-paste error?

Contributor Author:

The NVIDIA lib is a dependency for TritonGPUIR and TritonGPUTransforms.

      /usr/bin/ld: lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/LowerLoops.cpp.o: in function `mlir::triton::gpu::(anonymous namespace)::mustLoadToRegisters(mlir::Operation*)':
      /home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp:162: undefined reference to `mlir::triton::nvidia_gpu::getEncodingFromDescriptor(mlir::Operation*, mlir::RankedTensorType, mlir::Value)'
      /usr/bin/ld: /home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp:165: undefined reference to `mlir::triton::nvidia_gpu::getEncodingFromDescriptor(mlir::Operation*, mlir::RankedTensorType, mlir::Value)'
      /usr/bin/ld: lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/LowerLoops.cpp.o: in function `llvm::LogicalResult mlir::triton::nvidia_gpu::createTMADesc<mlir::triton::gpu::(anonymous namespace)::OpBuilderForStage>(mlir::Value, mlir::triton::MakeTensorDescOp, mlir::triton::gpu::(anonymous namespace)::OpBuilderForStage&)':
      /home/jovyan/intel-xpu-backend-for-triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h:77: undefined reference to `mlir::triton::nvidia_gpu::getTMAContigDim(mlir::Attribute, llvm::ArrayRef<long>)'
      /usr/bin/ld: /home/jovyan/intel-xpu-backend-for-triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h:99: undefined reference to `mlir::triton::nvidia_gpu::getTMASwizzleMode(mlir::Operation*, mlir::triton::TensorDescType)'
      /usr/bin/ld: /home/jovyan/intel-xpu-backend-for-triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h:125: undefined reference to `mlir::triton::nvidia_gpu::getTMAElementType(mlir::Operation*, mlir::triton::TensorDescType)'
      /usr/bin/ld: lib/Dialect/TritonGPU/Transforms/CMakeFiles/TritonGPUTransforms.dir/Pipeliner/PipeliningUtility.cpp.o: in function `mlir::triton::getSharedEncoding(mlir::Operation*)':
      /home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp:513: undefined reference to `mlir::triton::nvidia_gpu::getEncodingFromDescriptor(mlir::Operation*, mlir::RankedTensorType, mlir::Value)'
      collect2: error: ld returned 1 exit status

)