Skip to content

Commit 1bcf634

Browse files
committed
fixup comments, uniformly use int in create tile reg/lane bases
1 parent 1ecb35b commit 1bcf634

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -546,30 +546,34 @@ static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
546546

547547
// List of basis vectors for one input dimension. It is used below for both the
// register bases and the lane bases, where every entry is a two-element
// {row offset, column offset} pair.
using basisT = std::vector<std::vector<int32_t>>;
548548

549-
// Creates a row major tile layout with register/lane in dimension according to
550-
// the provided height, width, and threadsPerWarp. The relationship between the
551-
// width and threadsPerWarp determines the packing of registers in lanes.
552-
// * if width == threadsPerWarp, registers are distributed to lanes in row major
553-
// order, i.e. one column per lane
554-
// * if width < threadsPerWarp, rows are distributed to lanes according to
555-
// rowsToLaneRatio, i.e. width * rowsToLaneRatio = threadsPerWarp
556-
// * if width > threadsPerWarp, row values are packed in lanes such that
557-
// packedElementsPerLane row values exist in consecutive registers for each lane
549+
// Creates a row major tile layout with register/lane input dimensions according
550+
// to the provided height, width, and threadsPerWarp. The relationship between
551+
// the width and threadsPerWarp determines the packing of rows across lanes:
552+
// - if width == threadsPerWarp:
553+
// block row elements are mapped to registers in row major order, i.e. one
554+
// column per lane
555+
// - if width < threadsPerWarp:
556+
// multiple rows are mapped to the first register to fill the warp, i.e.
557+
// width * rowsPerWarp = threadsPerWarp
558+
// - if width > threadsPerWarp:
559+
// multiple elements of each row are assigned to registers such that
560+
// packedElementsPerLane row values exist in consecutive registers for each
561+
// lane
558562
std::pair<basisT, basisT>
559563
createRegisterLaneBases(const int height, const int width,
560564
const unsigned threadsPerWarp) {
561-
const unsigned packedElementsPerLane =
562-
mlir::ceil<unsigned>(width, threadsPerWarp);
565+
const int packedElementsPerLane =
566+
mlir::ceil<int>(width, static_cast<int>(threadsPerWarp));
563567

564568
basisT laneBases;
565569
for (int i = packedElementsPerLane; i < width; i = i << 1) {
566570
laneBases.push_back({0, i});
567571
}
568572

569-
const int rowsToLaneRatio =
573+
const int rowsPerWarp =
570574
mlir::ceil<int>(threadsPerWarp, 1 << laneBases.size());
571575
// Place subsequent rows into adjacent lanes until all lanes have been filled
572-
for (int i = 1; i < rowsToLaneRatio; i = i << 1) {
576+
for (int i = 1; i < rowsPerWarp; i = i << 1) {
573577
laneBases.push_back({i, 0});
574578
}
575579

@@ -580,8 +584,8 @@ createRegisterLaneBases(const int height, const int width,
580584
regBases.push_back({0, i});
581585
}
582586

583-
for (int i = 1; i < height / rowsToLaneRatio; i = i << 1) {
584-
regBases.push_back({i * rowsToLaneRatio, 0});
587+
for (int i = 1; i < height / rowsPerWarp; i = i << 1) {
588+
regBases.push_back({i * rowsPerWarp, 0});
585589
}
586590

587591
return std::make_pair(regBases, laneBases);

0 commit comments

Comments
 (0)