
[mlir] Fix padding shape computation in PadTilingInterface #149576


Open · wants to merge 1 commit into main

Conversation

yzhang93
Contributor

This PR fixes the computation of padded shapes for convolution-style affine maps (e.g., d0 + d1) in PadTilingInterface. Previously, the code used the direct sum of the loop upper bounds, which led to over-padding. For example, for the following conv_2d_nhwc_fhwc op, padding only the c dimension to a multiple of 16 also incorrectly padded the convolved dimensions, generating the wrong input shapes:

%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 1, 1, 12] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst : f32
} : tensor<1x16x16x4xf32> to tensor<1x17x17x16xf32>
%padded_0 = tensor.pad %arg1 low[0, 0, 0, 0] high[0, 0, 0, 12] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst : f32
} : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%padded, %padded_0 : tensor<1x17x17x16xf32>, tensor<16x3x3x16xf32>) outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>

The new implementation instead substitutes the maximum accessed index (i.e., size - 1) for each dimension of the affine map, and adds 1 after aggregating all the terms to get the final padded size. This fixes #148679.
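
For concreteness, here is a minimal arithmetic sketch (plain C++, independent of the MLIR implementation) contrasting the old and new padded-size computations for the convolved h/w dimensions above; the sizes 14 (output spatial loop) and 3 (filter loop) come from the example, and all variable names are illustrative only:

#include <cstdio>

int main() {
  // The conv input spatial dim is indexed by d0 + d1, where d0 iterates
  // the output spatial loop (size 14) and d1 the filter loop (size 3).
  // Only the channel dim is being padded, so these bounds are unchanged.
  int outSpatial = 14, filter = 3;

  // Old: substitute the loop upper bounds directly into d0 + d1.
  int oldPaddedSize = outSpatial + filter;                 // 17 -> over-pads 16 to 17

  // New: substitute the maximum accessed index (bound - 1) per dimension,
  // then add 1 once after summing all terms.
  int newPaddedSize = (outSpatial - 1) + (filter - 1) + 1; // 16 -> no padding needed

  std::printf("old = %d, new = %d\n", oldPaddedSize, newPaddedSize);
  return 0;
}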

@llvmbot
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Vivian Zhang (yzhang93)

Full diff: https://github.com/llvm/llvm-project/pull/149576.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp (+13-4)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir (+79-8)
  • (modified) mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir (+12-12)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 5eb3761f7aca1..c465383771617 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -114,24 +114,31 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
                             /*compressDims=*/true);
 
       // If we are padding to the next multiple of, compose with ceil(sz) * sz.
+      OpFoldResult paddingDimOfr;
       if (options.padToMultipleOf) {
         AffineExpr d0, s0;
         bindDims(rewriter.getContext(), d0);
         bindSymbols(rewriter.getContext(), s0);
         AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
         AffineMap composedMap = projectedMap.compose(ceilMap);
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, composedMap,
             {indexingSizes[paddingDim], paddingSize},
             /*composeAffineMin=*/true);
-        terms.push_back(paddingDimOfr);
       } else {
         // Otherwise just set to paddingSize.
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, projectedMap, paddingSize);
-        terms.push_back(paddingDimOfr);
       }
 
+      // Adjust for the maximum accessed index which is (padding_size - 1).
+      AffineExpr d0;
+      bindDims(rewriter.getContext(), d0);
+      AffineMap subtractOneMap = AffineMap::get(1, 0, d0 - 1);
+      OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, subtractOneMap, {paddingDimOfr});
+      terms.push_back(maxAccessIdx);
+
       LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
     }
 
@@ -148,6 +155,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
+    // Add 1 to the maximum accessed index and get the final padded size.
+    sumExpr = sumExpr + rewriter.getAffineConstantExpr(1);
     OpFoldResult paddedDimOfr =
         affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
     paddedShape[resultIndex] = paddedDimOfr;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b682673e..53cb7d7767b9a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME:      %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x12xf32>)
-  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x11xf32>)
+  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
   //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
   //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
   //      CHECK:   : tensor<7x5xf32> to tensor<9x5xf32>
   //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
-  //      CHECK:   : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+  //      CHECK:   : tensor<7x11x11xf32> to tensor<9x15x13xf32>
   // CHECK-NEXT: linalg.generic
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
     ^bb0(%in: f32, %out: f32):
       linalg.yield %in : f32
-    } -> tensor<7x11x12xf32>
-    return %0 : tensor<7x11x12xf32>
+    } -> tensor<7x11x11xf32>
+    return %0 : tensor<7x11x11xf32>
   }
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,74 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+//     CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+  //      CHECK:   : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+  //      CHECK:   : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+  return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+//     CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+  //  CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  //      CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  //      CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+  //      CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+  //      CHECK:   : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+  //      CHECK:   : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+  //      CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  //      CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+  //      CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+  //      CHECK:   : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+  //      CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+  // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+  %0 = linalg.conv_2d_nhwc_fhwc
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+      ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+    outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+  return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+      padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed309c05..f7418769f79ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
 
 // CHECK-LABEL: @generic
 // CHECK-SAME:      %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x12xf32>)
-  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x11xf32>)
+  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
 
   //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
 
   //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
   //      CHECK:   : tensor<7x5xf32> to tensor<8x5xf32>
   //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
-  //      CHECK:   : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+  //      CHECK:   : tensor<7x11x11xf32> to tensor<8x14x12xf32>
   // CHECK-NEXT: linalg.generic
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
-  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
     ^bb0(%in: f32, %out: f32):
       linalg.yield %in : f32
-    } -> tensor<7x11x12xf32>
-    return %0 : tensor<7x11x12xf32>
+    } -> tensor<7x11x11xf32>
+    return %0 : tensor<7x11x11xf32>
   }
   module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
 
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
   //      CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
   //      CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
   //      CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
-  //      CHECK:   : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+  //      CHECK:   : tensor<?x11x?xf32> to tensor<8x14x12xf32>
   //
   //      CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
   //      CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
-  //      CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
-  //      CHECK: } -> tensor<8x14x13xf32>
-  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+  //      CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+  //      CHECK: } -> tensor<8x14x12xf32>
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
   //
   %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
     ^bb0(%in: f32, %out: f32):

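As a sanity check on the folded dynamic-shape map in the pad_conv_dynamic test above (#[[$MAP0]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>), here is a small self-contained C++ sketch of how the high-padding amount arises, assuming stride/dilation 1 and filter size 3; the helper name highPad is hypothetical:

#include <cassert>

// Mirrors ()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2), where
// s0 is the dynamic output width and s1 the dynamic input width.
int highPad(int s0, int s1) {
  int paddedOut = ((s0 + 15) / 16) * 16;         // (s0 ceildiv 16) * 16
  int paddedIn = (paddedOut - 1) + (3 - 1) + 1;  // max accessed index + 1
  return paddedIn - s1;                          // required high padding
}

int main() {
  // Static shapes from the pad_conv test: output w = 14, input w = 16,
  // which folds to a high padding of 2 (tensor<1x16x16x4> -> 1x16x18x16).
  assert(highPad(14, 16) == 2);
  return 0;
}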