@@ -33,9 +33,14 @@ namespace memref {
33
33
34
34
using namespace mlir ;
35
35
36
- LogicalResult
37
- mlir::memref::reifyOpResultShapes (RewriterBase &rewriter,
38
- ReifyRankedShapedTypeOpInterface op) {
36
+ // / Reifies the results of `op`, potentially replacing `op` with a reified
37
+ // / version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
38
+ // / otherwise it always succeeds. Users of this transform should always expect
39
+ // / it to modify the IR, even when it fails. If any of the result types changes,
40
+ // / the transform will insert cast operations to the old type to keep the IR
41
+ // / consistent.
42
+ static LogicalResult reifyOpResultShapes (RewriterBase &rewriter,
43
+ ReifyRankedShapedTypeOpInterface op) {
39
44
LLVM_DEBUG ({ DBGS () << " reifying op: " << op << " \n " ; });
40
45
// Get the reified out shapes.
41
46
ReifiedRankedShapedTypeDims reifiedResultShapes;
@@ -93,6 +98,11 @@ mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
93
98
// We now have outTypes that need to be turned to cast ops.
94
99
Location loc = op->getLoc ();
95
100
SmallVector<Value> newResults;
101
+ // TODO: `mlir::reifyResultShapes` and op verifiers may not agree atm.
102
+ // This is a confluence problem that will need to be addressed.
103
+ // For now, we know PadOp and ConcatOp are fine.
104
+ assert ((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation ())) &&
105
+ " incorrect op" );
96
106
Operation *newOp = rewriter.clone (*op);
97
107
for (auto [reifiedTy, oldRes] : llvm::zip (outTypes, op->getResults ())) {
98
108
OpResult newRes = newOp->getResult (oldRes.getResultNumber ());
@@ -137,8 +147,6 @@ void ReifyResultShapesPass::runOnOperation() {
137
147
getOperation ()->walk ([&](ReifyRankedShapedTypeOpInterface op) {
138
148
// Handle ops that are not DPS and that do not carry an tied operand shapes.
139
149
// For now, limit to tensor::PadOp and tensor::ConcatOp.
140
- if (isa<DestinationStyleOpInterface>(op.getOperation ()))
141
- return ;
142
150
if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation ()))
143
151
return ;
144
152
ops.push_back (op);
0 commit comments