@@ -68,13 +68,11 @@ class TMALoadLowering : public OpRewritePattern<DescriptorLoadOp> {
6868 LogicalResult matchAndRewrite (DescriptorLoadOp op,
6969 PatternRewriter &rewriter) const override {
7070 auto loc = op.getLoc ();
71- auto createLoad = [&](Value tmaPtr , Value barrierAlloc, Value alloc,
71+ auto createLoad = [&](Value desc , Value barrierAlloc, Value alloc,
7272 Value pred) {
73- auto indices = translateTMAIndices (
74- rewriter, op.getLoc (),
75- op.getDesc ().getType ().getBlockType ().getEncoding (), op.getIndices ());
7673 triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create (
77- rewriter, op.getLoc (), tmaPtr, indices, barrierAlloc, alloc, pred);
74+ rewriter, op.getLoc (), desc, op.getIndices (), barrierAlloc, alloc,
75+ pred);
7876 };
7977 lowerTMALoad (op, op.getType (), op.getDesc (), createLoad, rewriter);
8078 return success ();
@@ -86,10 +84,10 @@ struct TMAGatherLowering : public OpRewritePattern<DescriptorGatherOp> {
8684
8785 LogicalResult matchAndRewrite (DescriptorGatherOp op,
8886 PatternRewriter &rewriter) const override {
89- auto createLoad = [&](Value tmaPtr , Value barrierAlloc, Value alloc,
87+ auto createLoad = [&](Value desc , Value barrierAlloc, Value alloc,
9088 Value pred) {
9189 triton::nvidia_gpu::AsyncTMAGatherOp::create (
92- rewriter, op.getLoc (), tmaPtr , op.getXOffsets (), op.getYOffset (),
90+ rewriter, op.getLoc (), desc , op.getXOffsets (), op.getYOffset (),
9391 barrierAlloc, alloc, pred);
9492 };
9593 lowerTMALoad (op, op.getType (), op.getDesc (), createLoad, rewriter);
@@ -122,12 +120,9 @@ struct TMAStoreLowering : public OpRewritePattern<DescriptorStoreOp> {
122120
123121 LogicalResult matchAndRewrite (DescriptorStoreOp op,
124122 PatternRewriter &rewriter) const override {
125- auto createStore = [&](Value tmaPtr, Value alloc) {
126- auto indices = translateTMAIndices (
127- rewriter, op.getLoc (),
128- op.getDesc ().getType ().getBlockType ().getEncoding (), op.getIndices ());
123+ auto createStore = [&](Value desc, Value alloc) {
129124 triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create (
130- rewriter, op.getLoc (), tmaPtr, indices , alloc);
125+ rewriter, op.getLoc (), desc, op. getIndices () , alloc);
131126 };
132127 lowerTMAStore (op, op.getSrc (), op.getDesc (), createStore, rewriter);
133128 return success ();
@@ -139,12 +134,9 @@ struct TMAReduceLowering : public OpRewritePattern<DescriptorReduceOp> {
139134
140135 LogicalResult matchAndRewrite (DescriptorReduceOp op,
141136 PatternRewriter &rewriter) const override {
142- auto createStore = [&](Value tmaPtr, Value alloc) {
143- auto indices = translateTMAIndices (
144- rewriter, op.getLoc (),
145- op.getDesc ().getType ().getBlockType ().getEncoding (), op.getIndices ());
137+ auto createStore = [&](Value desc, Value alloc) {
146138 triton::nvidia_gpu::AsyncTMAReduceOp::create (
147- rewriter, op.getLoc (), op.getKind (), tmaPtr, indices , alloc);
139+ rewriter, op.getLoc (), op.getKind (), desc, op. getIndices () , alloc);
148140 };
149141 lowerTMAStore (op, op.getSrc (), op.getDesc (), createStore, rewriter);
150142 return success ();
@@ -156,9 +148,9 @@ struct TMAScatterLowering : public OpRewritePattern<DescriptorScatterOp> {
156148
157149 LogicalResult matchAndRewrite (DescriptorScatterOp op,
158150 PatternRewriter &rewriter) const override {
159- auto createStore = [&](Value tmaPtr , Value alloc) {
160- triton::nvidia_gpu::AsyncTMAScatterOp::create (rewriter, op.getLoc (),
161- tmaPtr, op.getXOffsets (),
151+ auto createStore = [&](Value desc , Value alloc) {
152+ triton::nvidia_gpu::AsyncTMAScatterOp::create (rewriter, op.getLoc (), desc,
153+ op.getXOffsets (),
162154 op.getYOffset (), alloc);
163155 };
164156 lowerTMAStore (op, op.getSrc (), op.getDesc (), createStore, rewriter);
0 commit comments