#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
+#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
@@ -105,6 +110,83 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
  }
};

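+/// Use reified result shapes to make the shaped results of an op implementing
+/// `ReifyRankedShapedTypeOpInterface` more static: any dynamic extent whose
+/// reified value is a constant is replaced by that constant. The op is cloned
+/// with the refined result types and each changed result is cast back to its
+/// original type (`tensor.cast` / `memref.cast`) so users are unaffected.
+///
+/// Illustrative example (assuming the padded size reifies to a constant):
+///
+/// ```
+/// %0 = tensor.pad %src low[0, 0] high[1, 1] { ... }
+///     : tensor<9x9xf32> to tensor<?x?xf32>
+/// ```
+///
+/// could be rewritten to
+///
+/// ```
+/// %new = tensor.pad %src low[0, 0] high[1, 1] { ... }
+///     : tensor<9x9xf32> to tensor<10x10xf32>
+/// %0 = tensor.cast %new : tensor<10x10xf32> to tensor<?x?xf32>
+/// ```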
+struct ReifyToInferStaticShapePattern
+    : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    bool rewriteToMoreStatic = false;
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+        reifiedResultShapes.empty())
+      return failure();
+
+    SmallVector<Type> newTypes;
+    for (auto [t, reifiedShape] :
+         llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+      ShapedType st = dyn_cast<ShapedType>(t);
+      if (!st) {
+        newTypes.push_back(t);
+        continue;
+      }
+
+      SmallVector<int64_t> newShape;
+      for (const auto &[s, ofr] :
+           llvm::zip_equal(st.getShape(), reifiedShape)) {
+        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+        // Reification did not produce a constant; keep the existing extent.
+        if (!maybeCst.has_value()) {
+          newShape.push_back(s);
+          continue;
+        }
+        int64_t cst = *maybeCst;
+        assert((ShapedType::isDynamic(s) || s == cst) &&
+               "constants must agree!");
+        newShape.push_back(cst);
+      }
+
+      if (newShape == st.getShape()) {
+        newTypes.push_back(t);
+        continue;
+      }
+
+      rewriteToMoreStatic = true;
+      Type newType = st.cloneWith(newShape, st.getElementType());
+      newTypes.push_back(newType);
+    }
+
+    if (!rewriteToMoreStatic)
+      return failure();
+
+    // Clone the op with the more static result types and cast each changed
+    // result back to its original type so existing users are left unchanged.
+    Location loc = op->getLoc();
+    SmallVector<Value> newResults;
+    Operation *newOp = rewriter.clone(*op);
+    for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
+      Type ot = oldVal.getType();
+      OpResult newResult = newOp->getResult(oldVal.getResultNumber());
+      if (ot == nt) {
+        newResults.push_back(newResult);
+        continue;
+      }
+      newResult.setType(nt);
+      if (isa<RankedTensorType>(nt)) {
+        newResults.push_back(
+            rewriter.create<tensor::CastOp>(loc, nt, newResult));
+      } else if (isa<MemRefType>(nt)) {
+        newResults.push_back(
+            rewriter.create<memref::CastOp>(loc, nt, newResult));
+      } else {
+        llvm_unreachable("expected RankedTensorType or MemRefType");
+      }
+    }
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
@@ -175,6 +257,11 @@ struct ResolveShapedTypeResultDimsPass final
  void runOnOperation() override;
};

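+/// Pass that walks all `ReifyRankedShapedTypeOpInterface` ops and rewrites
+/// them to have more static result shapes where reification allows it.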
+struct InferStaticShapesPass final
+    : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
+  void runOnOperation() override;
+};
+
} // namespace

void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -192,6 +279,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
      patterns.getContext());
}

+void memref::populateReifyToInferStaticShapePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
+}
+
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -206,3 +298,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    return signalPassFailure();
}
+
+void InferStaticShapesPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  memref::populateReifyToInferStaticShapePatterns(patterns);
+
+  SmallVector<Operation *> opsToSimplify;
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+    opsToSimplify.push_back(op);
+  });
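+  // Restrict the greedy rewrite to the collected interface ops (plus any ops
+  // created while rewriting them) so unrelated operations are not visited.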
+  (void)applyOpPatternsGreedily(
+      opsToSimplify, frozenPatterns,
+      GreedyRewriteConfig().setStrictness(
+          GreedyRewriteStrictness::ExistingAndNewOps));
+}