Skip to content

Commit 15d16d3

Browse files
Merge OpenAI Triton commit e1162ee (#3913)
This PR change the Triton base from a0cc214 to e1162ee (Apr 10). Pass rate: 88.42%->88.52%
2 parents 8776dd9 + ed497eb commit 15d16d3

File tree

27 files changed

+1564
-24
lines changed

27 files changed

+1564
-24
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
101101
mlir::registerTritonAMDGPUConvertToBufferOps();
102102
mlir::registerTritonAMDGPUInThreadTranspose();
103103
mlir::registerTritonAMDGPUCoalesceAsyncCopy();
104+
mlir::registerTritonAMDGPUUpdateAsyncWaitCount();
104105
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
105106
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
106107
mlir::registerTritonAMDFoldTrueCmpI();

include/triton/Analysis/Alias.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ class SharedMemoryAliasAnalysis
8989
visitOperation(Operation *op,
9090
ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
9191
ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
92+
93+
void visitNonControlFlowArguments(
94+
Operation *op, const RegionSuccessor &successor,
95+
ArrayRef<dataflow::Lattice<AliasInfo> *> argLattices,
96+
unsigned firstIndex) override;
9297
};
9398

9499
} // namespace mlir

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,23 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
218218
"mlir::arith::ArithDialect"];
219219
}
220220

221+
def TritonGPUWGMMAPrefetch : Pass<"tritongpu-wgmma-prefetch", "mlir::ModuleOp"> {
222+
let summary = "prefetch for wgmma mixed precision";
223+
224+
let description = [{
225+
This pass attempts to prefetch from shared memory for mixed-precision
226+
wgmma when operand A is in the shared memory and needs to be loaded
227+
to the local registers.
228+
}];
229+
230+
let dependentDialects = [ "mlir::triton::gpu::TritonGPUDialect",
231+
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
232+
"mlir::scf::SCFDialect",
233+
"mlir::arith::ArithDialect"];
234+
}
235+
236+
237+
221238
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
222239
let summary = "accelerate matmul";
223240

lib/Analysis/Alias.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,30 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation(
5858
return success();
5959
}
6060

61+
void SharedMemoryAliasAnalysis::visitNonControlFlowArguments(
62+
Operation *op, const RegionSuccessor &successor,
63+
ArrayRef<dataflow::Lattice<AliasInfo> *> argLattices, unsigned firstIndex) {
64+
auto wsOp = dyn_cast<triton::gpu::WarpSpecializePartitionsOp>(op);
65+
if (!wsOp) {
66+
setAllToEntryStates(argLattices.take_front(firstIndex));
67+
setAllToEntryStates(argLattices.drop_front(
68+
firstIndex + successor.getSuccessorInputs().size()));
69+
return;
70+
}
71+
72+
// Propagate aliases from the parent operation's operands to the block
73+
// arguments.
74+
assert(!successor.isParent());
75+
ProgramPoint *point = getProgramPointAfter(wsOp);
76+
77+
for (auto [capture, argLattice] :
78+
llvm::zip(wsOp.getParentOp().getExplicitCaptures(), argLattices)) {
79+
propagateIfChanged(
80+
argLattice,
81+
argLattice->join(getLatticeElementFor(point, capture)->getValue()));
82+
}
83+
}
84+
6185
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
6286
// TODO: implement
6387
return AliasResult::MayAlias;

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ class AllocationAnalysis {
332332
solver->load<SharedMemoryAliasAnalysis>();
333333
// Run the analysis rooted at every isolated from above operation, including
334334
// the top-level function but also any nested regions.
335-
operation->walk([&](Operation *op) {
335+
operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
336336
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
337337
failed(solver->initializeAndRun(op))) {
338338
// TODO: return error instead of bailing out..

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_triton_library(TritonGPUTransforms
2626
Pipeliner/PipeliningUtility.cpp
2727
Pipeliner/Schedule.cpp
2828
Prefetch.cpp
29+
WGMMAPrefetch.cpp
2930
RemoveLayoutConversions.cpp
3031
ReorderInstructions.cpp
3132
CoalesceAsyncCopy.cpp

0 commit comments

Comments
 (0)