Skip to content

Commit 6b7217a

Browse files
Review comments
1 parent 08a6823 commit 6b7217a

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,6 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
209209
memref::AllocaOp allocToAlloca(
210210
RewriterBase &rewriter, memref::AllocOp alloc,
211211
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
212-
213-
/// Reifies the results of `op`, potentially replacing `op` with a reified
214-
/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
215-
/// otherwise it always succeeds. Users of this transform should always expect
216-
/// it to modify the IR, even when it fails. If any of the result types changes,
217-
/// the transform will insert cast operations to the old type to keep the IR
218-
/// consistent.
219-
///
220-
/// Note: This transform only works on ranked `memref` or `tensor` results,
221-
/// other types are ignored.
222-
LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
223-
ReifyRankedShapedTypeOpInterface op);
224212
} // namespace memref
225213
} // namespace mlir
226214

mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,14 @@ namespace memref {
3333

3434
using namespace mlir;
3535

36-
LogicalResult
37-
mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
38-
ReifyRankedShapedTypeOpInterface op) {
36+
/// Reifies the results of `op`, potentially replacing `op` with a reified
37+
/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
38+
/// otherwise it always succeeds. Users of this transform should always expect
39+
/// it to modify the IR, even when it fails. If any of the result types changes,
40+
/// the transform will insert cast operations to the old type to keep the IR
41+
/// consistent.
42+
static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
43+
ReifyRankedShapedTypeOpInterface op) {
3944
LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
4045
// Get the reified out shapes.
4146
ReifiedRankedShapedTypeDims reifiedResultShapes;
@@ -93,6 +98,11 @@ mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
9398
// We now have outTypes that need to be turned to cast ops.
9499
Location loc = op->getLoc();
95100
SmallVector<Value> newResults;
101+
// TODO: `mlir::reifyResultShapes` and op verifiers may not agree atm.
102+
// This is a confluence problem that will need to be addressed.
103+
// For now, we know PadOp and ConcatOp are fine.
104+
assert((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation())) &&
105+
"incorrect op");
96106
Operation *newOp = rewriter.clone(*op);
97107
for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
98108
OpResult newRes = newOp->getResult(oldRes.getResultNumber());
@@ -137,8 +147,6 @@ void ReifyResultShapesPass::runOnOperation() {
137147
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
138148
// Handle ops that are not DPS and that do not carry an tied operand shapes.
139149
// For now, limit to tensor::PadOp and tensor::ConcatOp.
140-
if (isa<DestinationStyleOpInterface>(op.getOperation()))
141-
return;
142150
if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
143151
return;
144152
ops.push_back(op);

0 commit comments

Comments
 (0)