@@ -546,30 +546,34 @@ static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
546546
547547using basisT = std::vector<std::vector<int32_t >>;
548548
549- // Creates a row major tile layout with register/lane in dimension according to
550- // the provided height, width, and threadsPerWarp. The relationship between the
551- // width and threadsPerWarp determines the packing of registers in lanes.
552- // * if width == threadsPerWarp, registers are distributed to lanes in row major
553- // order, i.e. one column per lane
554- // * if width < threadsPerWarp, rows are distributed to lanes according to
555- // rowToLaneration, i.e. width * rowToLaneRatio = threadsPerWarp
556- // * if width > threadsPerWarp, row values are packed in lanes such that
557- // packedElementsPerLane row values exist in consecutive registers for each lane
549+ // Creates a row major tile layout with register/lane input dimensions according
550+ // to the provided height, width, and threadsPerWarp. The relationship between
551+ // the width and threadsPerWarp determines the packing of rows across lanes:
552+ // - if width == threadsPerWarp:
553+ // block row elements are mapped to registers in row major order, i.e. one
554+ // column per lane
555+ // - if width < threadsPerWarp:
556+ // multiple rows are mapped to the first register to fill the warp, i.e.
557+ // width * rowsPerWarp = threadsPerWarp
558+ // - if width > threadsPerWarp:
559+ // multiple elements of each row are assigned to registers such that
560+ // packedElementsPerLane row values exist in consecutive registers for each
561+ // lane
558562std::pair<basisT, basisT>
559563createRegisterLaneBases (const int height, const int width,
560564 const unsigned threadsPerWarp) {
561- const unsigned packedElementsPerLane =
562- mlir::ceil<unsigned >(width, threadsPerWarp);
565+ const int packedElementsPerLane =
566+ mlir::ceil<int >(width, static_cast < int >( threadsPerWarp) );
563567
564568 basisT laneBases;
565569 for (int i = packedElementsPerLane; i < width; i = i << 1 ) {
566570 laneBases.push_back ({0 , i});
567571 }
568572
569- const int rowsToLaneRatio =
573+ const int rowsPerWarp =
570574 mlir::ceil<int >(threadsPerWarp, 1 << laneBases.size ());
571575 // Place subsequent rows into adjacent lanes until all lanes have been filled
572- for (int i = 1 ; i < rowsToLaneRatio ; i = i << 1 ) {
576+ for (int i = 1 ; i < rowsPerWarp ; i = i << 1 ) {
573577 laneBases.push_back ({i, 0 });
574578 }
575579
@@ -580,8 +584,8 @@ createRegisterLaneBases(const int height, const int width,
580584 regBases.push_back ({0 , i});
581585 }
582586
583- for (int i = 1 ; i < height / rowsToLaneRatio ; i = i << 1 ) {
584- regBases.push_back ({i * rowsToLaneRatio , 0 });
587+ for (int i = 1 ; i < height / rowsPerWarp ; i = i << 1 ) {
588+ regBases.push_back ({i * rowsPerWarp , 0 });
585589 }
586590
587591 return std::make_pair (regBases, laneBases);
0 commit comments