Skip to content

Commit 475e919

Browse files
[mlir][memref] Add a new InderStaticShapes pass for ReifyRankedShapedTypeOpInterface
1 parent d83457e commit 475e919

File tree

4 files changed

+140
-0
lines changed

4 files changed

+140
-0
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,19 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
182182
];
183183
}
184184

185+
def InferStaticShapesPass : Pass<"infer-static-shapes"> {
186+
let summary = "Resolve memref.dim of result values";
187+
let description = [{
188+
The pass resolves memref.dim of result of operations that
189+
implement the `InferShapedTypeOpInterface` or
190+
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
191+
operands.
192+
}];
193+
let dependentDialects = [
194+
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
195+
];
196+
}
197+
185198
def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
186199
let summary = "Expand memref operations into easier to analyze constructs";
187200
let description = [{

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
5757
/// terms of shapes of its input operands.
5858
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
5959

60+
/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
61+
/// shapes more static.
62+
void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
63+
6064
/// Appends patterns for expanding memref operations that modify the metadata
6165
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
6266
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,18 @@
2020
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
2121
#include "mlir/Dialect/SCF/IR/SCF.h"
2222
#include "mlir/Dialect/Tensor/IR/Tensor.h"
23+
#include "mlir/IR/BuiltinTypeInterfaces.h"
24+
#include "mlir/IR/BuiltinTypes.h"
25+
#include "mlir/IR/Value.h"
2326
#include "mlir/Interfaces/InferTypeOpInterface.h"
2427
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28+
#include "llvm/Support/ErrorHandling.h"
2529

2630
namespace mlir {
2731
namespace memref {
2832
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
2933
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
34+
#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
3035
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
3136
} // namespace memref
3237
} // namespace mlir
@@ -105,6 +110,83 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
105110
}
106111
};
107112

113+
struct ReifyToInferStaticShapePattern
114+
: public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
115+
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
116+
117+
LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
118+
PatternRewriter &rewriter) const override {
119+
120+
bool rewriteToMoreStatic = false;
121+
ReifiedRankedShapedTypeDims reifiedResultShapes;
122+
if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
123+
reifiedResultShapes.empty())
124+
return failure();
125+
126+
SmallVector<Type> newTypes;
127+
for (auto [t, reifiedShape] :
128+
llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
129+
ShapedType st = dyn_cast<ShapedType>(t);
130+
if (!st)
131+
continue;
132+
133+
SmallVector<int64_t> newShape;
134+
for (const auto &[s, ofr] :
135+
llvm::zip_equal(st.getShape(), reifiedShape)) {
136+
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
137+
// Reification does not add static information, just use existing shape.
138+
if (!maybeCst.has_value()) {
139+
newShape.push_back(s);
140+
continue;
141+
}
142+
int64_t cst = *maybeCst;
143+
assert((ShapedType::isDynamic(s) || s == cst) &&
144+
"constants must agree!");
145+
newShape.push_back(cst);
146+
}
147+
148+
if (newShape == st.getShape()) {
149+
newTypes.push_back(t);
150+
continue;
151+
}
152+
153+
rewriteToMoreStatic = true;
154+
Type oldType = st;
155+
Type newType = st.cloneWith(newShape, st.getElementType());
156+
newTypes.push_back(newType);
157+
}
158+
159+
if (!rewriteToMoreStatic)
160+
return failure();
161+
162+
// We now have newTypes that need to be turned to tensor::CastOp.
163+
Location loc = op->getLoc();
164+
SmallVector<Value> newResults;
165+
Operation *newOp = rewriter.clone(*op);
166+
for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
167+
Type ot = oldVal.getType();
168+
OpResult newResult = newOp->getResult(oldVal.getResultNumber());
169+
if (ot == nt) {
170+
newResults.push_back(newResult);
171+
continue;
172+
}
173+
newResult.setType(nt);
174+
if (isa<RankedTensorType>(nt)) {
175+
newResults.push_back(
176+
rewriter.create<tensor::CastOp>(loc, nt, newResult));
177+
} else if (isa<MemRefType>(nt)) {
178+
newResults.push_back(
179+
rewriter.create<memref::CastOp>(loc, nt, newResult));
180+
} else {
181+
llvm_unreachable("expected RankedTensorType or MemRefType");
182+
}
183+
}
184+
185+
rewriter.replaceOp(op, newResults);
186+
return success();
187+
}
188+
};
189+
108190
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
109191
///
110192
/// ```
@@ -175,6 +257,11 @@ struct ResolveShapedTypeResultDimsPass final
175257
void runOnOperation() override;
176258
};
177259

260+
struct InferStaticShapesPass final
261+
: public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
262+
void runOnOperation() override;
263+
};
264+
178265
} // namespace
179266

180267
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -192,6 +279,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
192279
patterns.getContext());
193280
}
194281

282+
void memref::populateReifyToInferStaticShapePatterns(
283+
RewritePatternSet &patterns) {
284+
patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
285+
}
286+
195287
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
196288
RewritePatternSet patterns(&getContext());
197289
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -206,3 +298,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
206298
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
207299
return signalPassFailure();
208300
}
301+
302+
void InferStaticShapesPass::runOnOperation() {
303+
RewritePatternSet patterns(&getContext());
304+
305+
SmallVector<Operation *> opsToSimplify;
306+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
307+
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
308+
opsToSimplify.push_back(op);
309+
});
310+
(void)applyOpPatternsGreedily(
311+
opsToSimplify, frozenPatterns,
312+
GreedyRewriteConfig().setStrictness(
313+
GreedyRewriteStrictness::ExistingAndNewOps));
314+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @pad_reification
4+
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
5+
-> tensor<1x?x64xf32> {
6+
%pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
7+
%es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
8+
: tensor<64x?x64xf32> to tensor<1x?x64xf32>
9+
10+
// CHECK: tensor.pad
11+
// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
12+
%padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
13+
^bb0(%a: index, %b: index, %c: index):
14+
tensor.yield %cst : f32
15+
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
16+
17+
return %padded : tensor<1x?x64xf32>

0 commit comments

Comments
 (0)