@@ -546,30 +546,34 @@ static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
546
546
547
547
using basisT = std::vector<std::vector<int32_t >>;
548
548
549
- // Creates a row major tile layout with register/lane in dimension according to
550
- // the provided height, width, and threadsPerWarp. The relationship between the
551
- // width and threadsPerWarp determines the packing of registers in lanes.
552
- // * if width == threadsPerWarp, registers are distributed to lanes in row major
553
- // order, i.e. one column per lane
554
- // * if width < threadsPerWarp, rows are distributed to lanes according to
555
- // rowToLaneration, i.e. width * rowToLaneRatio = threadsPerWarp
556
- // * if width > threadsPerWarp, row values are packed in lanes such that
557
- // packedElementsPerLane row values exist in consecutive registers for each lane
549
+ // Creates a row major tile layout with register/lane input dimensions according
550
+ // to the provided height, width, and threadsPerWarp. The relationship between
551
+ // the width and threadsPerWarp determines the packing of rows across lanes:
552
+ // - if width == threadsPerWarp:
553
+ // block row elements are mapped to registers in row major order, i.e. one
554
+ // column per lane
555
+ // - if width < threadsPerWarp:
556
+ // multiple rows are mapped to the first register to fill the warp, i.e.
557
+ // width * rowsPerWarp = threadsPerWarp
558
+ // - if width > threadsPerWarp:
559
+ // multiple elements of each row are assigned to registers such that
560
+ // packedElementsPerLane row values exist in consecutive registers for each
561
+ // lane
558
562
std::pair<basisT, basisT>
559
563
createRegisterLaneBases (const int height, const int width,
560
564
const unsigned threadsPerWarp) {
561
- const unsigned packedElementsPerLane =
562
- mlir::ceil <unsigned >(width, threadsPerWarp);
565
+ const int packedElementsPerLane =
566
+ mlir::ceil <int >(width, static_cast < int >( threadsPerWarp) );
563
567
564
568
basisT laneBases;
565
569
for (int i = packedElementsPerLane; i < width; i = i << 1 ) {
566
570
laneBases.push_back ({0 , i});
567
571
}
568
572
569
- const int rowsToLaneRatio =
573
+ const int rowsPerWarp =
570
574
mlir::ceil <int >(threadsPerWarp, 1 << laneBases.size ());
571
575
// Place subsequent rows into adjacent lanes until all lanes have been filled
572
- for (int i = 1 ; i < rowsToLaneRatio ; i = i << 1 ) {
576
+ for (int i = 1 ; i < rowsPerWarp ; i = i << 1 ) {
573
577
laneBases.push_back ({i, 0 });
574
578
}
575
579
@@ -580,8 +584,8 @@ createRegisterLaneBases(const int height, const int width,
580
584
regBases.push_back ({0 , i});
581
585
}
582
586
583
- for (int i = 1 ; i < height / rowsToLaneRatio ; i = i << 1 ) {
584
- regBases.push_back ({i * rowsToLaneRatio , 0 });
587
+ for (int i = 1 ; i < height / rowsPerWarp ; i = i << 1 ) {
588
+ regBases.push_back ({i * rowsPerWarp , 0 });
585
589
}
586
590
587
591
return std::make_pair (regBases, laneBases);
0 commit comments