@@ -183,19 +183,41 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
}
def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
- let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
+ let summary = [{
+ Reifies the results of `ReifyRankedShapedTypeOpInterface` operations that
+ do not have an explicit output tensor or implicit tied operands.
+ }];
let description = [{
- This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
- operation with ranked `memref` and `tensor` results. Replacing the
- operations with their reified versions, and inserting casts when results
- shapes are updated.
+ This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
+ ops with `tensor` results.
+
+ The pass currently only supports result shape type reification for:
+ - tensor::PadOp
+ - tensor::ConcatOp
+ It addresses a representation gap where implicit op semantics are needed to
+ infer static result types from dynamic operands. However, it does so by using
+ `ReifyRankedShapedTypeOpInterface` as the source of truth rather than the op
+ itself; as a consequence, the approach does not generalize today.
+
+ TODO: in the future, we should consider coupling this information with op
+ "transfer functions" (e.g. `IndexingMapOpInterface`) to provide a source of
+ truth that can work across result shape inference, canonicalization and op
+ verifiers.
+
+ The pass replaces the operations with their reified versions when more
+ static information can be derived, and inserts casts when result shapes
+ are updated.
Example:
```mlir
#map = affine_map<(d0) -> (-d0 + 256)>
- func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+ -> tensor<1x?x64xf32>
+ {
%0 = affine.apply #map(%arg1)
- %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
@@ -205,9 +227,12 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
// mlir-opt --reify-result-shapes
#map = affine_map<()[s0] -> (-s0 + 256)>
- func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+ -> tensor<1x?x64xf32>
+ {
%0 = affine.apply #map()[%arg1]
- %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
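
The embedded example in the patch exercises only `tensor.pad`. Since the description also lists `tensor::ConcatOp` as supported, a minimal sketch of what reification could look like for a concat is shown below; the function name, the operand shapes, and the placement of the reconciling `tensor.cast` are illustrative assumptions, not taken from this patch.

```mlir
// Hypothetical input (not from the patch): the result type is fully dynamic
// even though the leading dimension is statically known to be 3.
func.func @concat_example(%a: tensor<3x?xf32>, %b: tensor<3x5xf32>) -> tensor<?x?xf32> {
  %0 = tensor.concat dim(1) %a, %b : (tensor<3x?xf32>, tensor<3x5xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// Possible output of mlir-opt --reify-result-shapes (sketch): the concat is
// rebuilt with the more static result type, and a cast reconciles the
// original function signature.
func.func @concat_example(%a: tensor<3x?xf32>, %b: tensor<3x5xf32>) -> tensor<?x?xf32> {
  %0 = tensor.concat dim(1) %a, %b : (tensor<3x?xf32>, tensor<3x5xf32>) -> tensor<3x?xf32>
  %cast = tensor.cast %0 : tensor<3x?xf32> to tensor<?x?xf32>
  return %cast : tensor<?x?xf32>
}
```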