Skip to content

Commit 1497991

Browse files
Update pass documentation
1 parent ba51026 commit 1497991

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,19 +183,41 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
183183
}
184184

185185
def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
186-
let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
186+
let summary = [{
187+
Reifies the results of `ReifyRankedShapedTypeOpInterface` operations that
188+
do not have an explicit output tensor or implicit tied operands.
189+
}];
187190
let description = [{
188-
This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
189-
operation with ranked `memref` and `tensor` results. Replacing the
190-
operations with their reified versions, and inserting casts when results
191-
shapes are updated.
191+
This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
192+
ops with `tensor` results.
193+
194+
The pass currently only supports result shape type reification for:
195+
- tensor::PadOp
196+
- tensor::ConcatOp
197+
It addresses a representation gap where implicit op semantics are needed to
198+
infer static result types from dynamic operands.
199+
But it does so by using `ReifyRankedShapedTypeOpInterface` as the source of
200+
truth rather than the op itself. As a consequence, this cannot generalize
201+
today.
202+
203+
TODO: in the future, we should consider coupling this information with op
204+
"transfer functions" (e.g. `IndexingMapOpInterface`) to provide a source of
205+
truth that can work across result shape inference, canonicalization and op
206+
verifiers.
207+
208+
The pass replaces the operations with their reified versions, when more
209+
static information can be derived, and inserts casts when results shapes
210+
are updated.
192211

193212
Example:
194213
```mlir
195214
#map = affine_map<(d0) -> (-d0 + 256)>
196-
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
215+
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
216+
-> tensor<1x?x64xf32>
217+
{
197218
%0 = affine.apply #map(%arg1)
198-
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
219+
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
220+
: tensor<64x?x64xf32> to tensor<1x?x64xf32>
199221
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
200222
^bb0(%arg3: index, %arg4: index, %arg5: index):
201223
tensor.yield %arg0 : f32
@@ -205,9 +227,12 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
205227

206228
// mlir-opt --reify-result-shapes
207229
#map = affine_map<()[s0] -> (-s0 + 256)>
208-
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
230+
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
231+
-> tensor<1x?x64xf32>
232+
{
209233
%0 = affine.apply #map()[%arg1]
210-
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
234+
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
235+
: tensor<64x?x64xf32> to tensor<1x?x64xf32>
211236
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
212237
^bb0(%arg3: index, %arg4: index, %arg5: index):
213238
tensor.yield %arg0 : f32

mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
1919
#include "mlir/Dialect/Tensor/IR/Tensor.h"
20+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
2122
#include "llvm/Support/InterleavedRange.h"
2223

@@ -134,8 +135,10 @@ struct ReifyResultShapesPass final
134135
void ReifyResultShapesPass::runOnOperation() {
135136
SmallVector<ReifyRankedShapedTypeOpInterface> ops;
136137
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
137-
// Some ops have rigid type checkers and need to update their operands.
138-
// Only admit the ones that are explicitly supported for now.
138+
// Handle ops that are not DPS and that do not carry an tied operand shapes.
139+
// For now, limit to tensor::PadOp and tensor::ConcatOp.
140+
if (isa<DestinationStyleOpInterface>(op.getOperation()))
141+
return;
139142
if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
140143
return;
141144
ops.push_back(op);

0 commit comments

Comments
 (0)