
Commit 52d5c1f

rename transform to reify-result-shapes
1 parent 465c660 commit 52d5c1f

7 files changed: +221 additions, -150 deletions

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 33 additions & 6 deletions
@@ -182,13 +182,40 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
   ];
 }
 
-def InferStaticShapesPass : Pass<"infer-static-shapes"> {
-  let summary = "Resolve memref.dim of result values";
+def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
+  let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
   let description = [{
-    The pass resolves memref.dim of result of operations that
-    implement the `InferShapedTypeOpInterface` or
-    `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
-    operands.
+    This pass reifies the result shapes of every `ReifyRankedShapedTypeOpInterface`
+    operation with ranked `memref` and `tensor` results, replacing the
+    operations with their reified versions and inserting casts when the
+    result shapes change.
+
+    Example:
+    ```mlir
+    #map = affine_map<(d0) -> (-d0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map(%arg1)
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+      return %padded : tensor<1x?x64xf32>
+    }
+
+    // mlir-opt --reify-result-shapes
+    #map = affine_map<()[s0] -> (-s0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map()[%arg1]
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+      %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+      return %cast : tensor<1x?x64xf32>
+    }
+    ```
   }];
   let dependentDialects = [
     "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,7 @@ class RewritePatternSet;
 class RewriterBase;
 class Value;
 class ValueRange;
+class ReifyRankedShapedTypeOpInterface;
 
 namespace arith {
 class WideIntEmulationConverter;
@@ -213,6 +214,17 @@ memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
 
+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
+/// otherwise it always succeeds. Users of this transform should always expect
+/// it to modify the IR, even when it fails. If any of the result types changes,
+/// the transform will insert cast operations to the old type to keep the IR
+/// consistent.
+///
+/// Note: This transform only works on ranked `memref` or `tensor` results,
+/// other types are ignored.
+LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op);
 } // namespace memref
 } // namespace mlir
 
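As a usage note, the new `reifyOpResultShapes` entry point expects the rewriter's insertion point to sit at the op being reified. The sketch below is illustrative only; the `reifyAllInFunc` helper and the choice of `func::FuncOp` are assumptions, not part of this commit, and it mirrors the doc comment's caveat that the IR may be modified even when the call reports failure.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

// Reify every ReifyRankedShapedTypeOpInterface op inside a function.
static void reifyAllInFunc(mlir::func::FuncOp funcOp) {
  mlir::IRRewriter rewriter(funcOp.getContext());
  // Collect the ops first: reifyOpResultShapes may replace the op it is given.
  llvm::SmallVector<mlir::ReifyRankedShapedTypeOpInterface> ops;
  funcOp.walk(
      [&](mlir::ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
  for (mlir::ReifyRankedShapedTypeOpInterface op : ops) {
    rewriter.setInsertionPoint(op);
    // Per the doc comment, the IR may change even when this reports failure.
    (void)mlir::memref::reifyOpResultShapes(rewriter, op);
  }
}
```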
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   IndependenceTransforms.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
+  ReifyResultShapes.cpp
   ResolveShapedTypeResultDims.cpp
   RuntimeOpVerification.cpp
 
mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op) {
+  LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+  // Get the reified out shapes.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+      reifiedResultShapes.empty()) {
+    return op.emitError() << "failed to get the reified shapes";
+  }
+
+  bool modified = false;
+  // Compute the new output types.
+  SmallVector<Type> outTypes;
+  for (const auto &[oldTy, reifiedShape] :
+       llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+    // Skip if it's not a memref or tensor type.
+    if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+
+    ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
+
+    SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+    for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // If the reified dim is dynamic set it appropriately.
+      if (!maybeCst.has_value()) {
+        dim = ShapedType::kDynamic;
+        continue;
+      }
+      // Set the static dim.
+      dim = *maybeCst;
+    }
+
+    // If the shape didn't change continue.
+    if (shape == shapedTy.getShape()) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    modified = true;
+    outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+  }
+
+  // Return if we don't need to update.
+  if (!modified) {
+    LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+    return success();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+           << " \n";
+    DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+  });
+
+  // We now have outTypes that need to be turned to cast ops.
+  Location loc = op->getLoc();
+  SmallVector<Value> newResults;
+  Operation *newOp = rewriter.clone(*op);
+  for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
+    OpResult newRes = newOp->getResult(oldRes.getResultNumber());
+    Type oldTy = oldRes.getType();
+    // Continue if the type remained invariant or is not shaped.
+    if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
+      newResults.push_back(newRes);
+      continue;
+    }
+
+    // Update the type.
+    newRes.setType(reifiedTy);
+    if (isa<RankedTensorType>(reifiedTy)) {
+      newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+    } else {
+      assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
+      newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- reified results " << llvm::interleaved_array(newResults)
+           << "\n";
+  });
+  rewriter.replaceOp(op, newResults);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ReifyResultShapesPass final
+    : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ReifyResultShapesPass::runOnOperation() {
+  SmallVector<ReifyRankedShapedTypeOpInterface> ops;
+  getOperation()->walk(
+      [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
+  IRRewriter rewriter(&getContext());
+  for (ReifyRankedShapedTypeOpInterface op : ops) {
+    rewriter.setInsertionPoint(op);
+    if (failed(memref::reifyOpResultShapes(rewriter, op)))
+      return signalPassFailure();
+  }
+}

mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

Lines changed: 0 additions & 126 deletions
@@ -20,22 +20,13 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/InterleavedRange.h"
-
-#define DEBUG_TYPE "resolve-shaped-type"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 namespace mlir {
 namespace memref {
 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
-#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 } // namespace memref
 } // namespace mlir
@@ -114,99 +105,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
   }
 };
 
-struct ReifyToInferStaticShapePattern
-    : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
-                                PatternRewriter &rewriter) const override {
-    LLVM_DEBUG(
-        { DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; });
-
-    bool rewriteToMoreStatic = false;
-    ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
-        reifiedResultShapes.empty()) {
-      LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; });
-      return failure();
-    }
-
-    SmallVector<Type> newTypes;
-    for (auto [t, reifiedShape] :
-         llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
-      ShapedType st = dyn_cast<ShapedType>(t);
-      if (!st)
-        continue;
-
-      SmallVector<int64_t> newShape;
-      for (const auto &[s, ofr] :
-           llvm::zip_equal(st.getShape(), reifiedShape)) {
-        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
-        // Reification does not add static information, just use existing shape.
-        if (!maybeCst.has_value()) {
-          newShape.push_back(s);
-          continue;
-        }
-        int64_t cst = *maybeCst;
-        assert((ShapedType::isDynamic(s) || s == cst) &&
-               "constants must agree!");
-        newShape.push_back(cst);
-      }
-
-      if (newShape == st.getShape()) {
-        newTypes.push_back(t);
-        continue;
-      }
-
-      rewriteToMoreStatic = true;
-      Type newType = st.cloneWith(newShape, st.getElementType());
-      newTypes.push_back(newType);
-    }
-
-    LLVM_DEBUG({
-      DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes())
-             << " \n";
-      DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n";
-    });
-    if (!rewriteToMoreStatic) {
-      LLVM_DEBUG({ DBGS() << "not more static\n"; });
-      return failure();
-    }
-
-    // We now have newTypes that need to be turned to tensor::CastOp.
-    Location loc = op->getLoc();
-    SmallVector<Value> newResults;
-    Operation *newOp = rewriter.clone(*op);
-    for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
-      Type ot = oldVal.getType();
-      OpResult newResult = newOp->getResult(oldVal.getResultNumber());
-      if (ot == nt) {
-        newResults.push_back(newResult);
-        continue;
-      }
-      newResult.setType(nt);
-      if (isa<RankedTensorType>(nt)) {
-        newResults.push_back(
-            rewriter.create<tensor::CastOp>(loc, ot, newResult));
-      } else if (isa<MemRefType>(nt)) {
-        newResults.push_back(
-            rewriter.create<memref::CastOp>(loc, ot, newResult));
-      } else {
-        llvm_unreachable("expected RankedTensorType or MemRefType");
-      }
-    }
-
-    LLVM_DEBUG({
-      op->getParentOp()->dump();
-      DBGS() << "replace op " << *op << "\n";
-      DBGS() << "with newResults " << llvm::interleaved_array(newResults)
-             << "\n\n\n\n";
-    });
-    rewriter.replaceAllOpUsesWith(op, newResults);
-    return success();
-  }
-};
-
 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
 ///
 /// ```
@@ -277,11 +175,6 @@ struct ResolveShapedTypeResultDimsPass final
   void runOnOperation() override;
 };
 
-struct InferStaticShapesPass final
-    : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
-  void runOnOperation() override;
-};
-
 } // namespace
 
 void memref::populateResolveRankedShapedTypeResultDimsPatterns(
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -299,11 +192,6 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
299192
patterns.getContext());
300193
}
301194

302-
void memref::populateReifyToInferStaticShapePatterns(
303-
RewritePatternSet &patterns) {
304-
patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
305-
}
306-
307195
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
308196
RewritePatternSet patterns(&getContext());
309197
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -318,17 +206,3 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
-
-void InferStaticShapesPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  patterns.add<ReifyToInferStaticShapePattern>(&getContext());
-  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-
-  SmallVector<Operation *> opsToSimplify;
-  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
-    opsToSimplify.push_back(op);
-  });
-  (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns,
-                                GreedyRewriteConfig().setStrictness(
-                                    GreedyRewriteStrictness::ExistingOps));
-}
