Skip to content

Commit ff4bfda

Browse files
committed
additional formatting and doc cleanups
1 parent b0658ef commit ff4bfda

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,8 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
300300
- `threadsPerWarp` : Currently a scalar. This parameter allows us to support different subgroup / warp configurations. Because the 2D block IO operation is a subgroup operation, the size of the subgroup is important in determining the ordering of the loaded tensor.
301301
- `warpsPerCTA` : the number of warps per block / subgroups per workgroup and their distribution
302302
- `order` : The order within the block, used to determine along which dimension to broadcast.
303-
- `kWidth` : used?
304-
- `numReps` : unused?
305-
- `CTALayout` : ??
303+
- `kWidth` : Currently unused, but kept because it will likely be needed for layout conversions.
304+
- `CTALayout` : Describes how blocks are distributed among work-groups/thread blocks.
306305
}];
307306

308307
let parameters = (

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,8 @@ void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
531531
LogicalResult Subgroup2DBlockEncodingAttr::verify(
532532
function_ref<InFlightDiagnostic()> emitError,
533533
ArrayRef<unsigned> warpsPerCTA, CTALayoutAttr CTALayout,
534-
ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order, unsigned kWidth,
535-
unsigned threadsPerWarp) {
534+
ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order,
535+
unsigned kWidth, unsigned threadsPerWarp) {
536536
if (instrShape.size() != 2) {
537537
return emitError() << "instrShape must be rank 2 but was: "
538538
<< instrShape.size();
@@ -652,7 +652,8 @@ void Subgroup2DBlockEncodingAttr::print(AsmPrinter &printer) const {
652652
maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank());
653653

654654
printer << ", instrShape = [" << getInstrShape()
655-
<< "], numBlocks=" << getNumBlocks() << ", order=[" << getOrder() << "], kWidth=" << getKWidth()
655+
<< "], numBlocks=" << getNumBlocks() << ", order=[" << getOrder()
656+
<< "], kWidth=" << getKWidth()
656657
<< ", threadsPerWarp=" << getThreadsPerWarp() << "}>";
657658
}
658659

unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class LinearLayoutConversionsTest : public ::testing::Test {
3535
CTALayoutAttr::get(
3636
&ctx, dpasLayout.getCTAsPerCGA(), // TODO: add to DpasLayout?
3737
dpasLayout.getCTASplitNum(), dpasLayout.getCTAOrder()),
38-
instrShape, numBlocks, getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth,
38+
instrShape, numBlocks,
39+
getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth,
3940
dpasLayout.getThreadsPerWarp());
4041
return layout;
4142
}

0 commit comments

Comments
 (0)