[mlir][linalg] Take artificial padding into account for pack/unpack folding. #150272


Merged: 7 commits into main on Jul 24, 2025

Conversation

@hanhanW (Contributor) commented on Jul 23, 2025

The revision folds the tensor.pad/extract_slice op into linalg.pack/unpack ops only when it is safe to fold. It is not valid to have artificial padding.

The documentation improvement and verifier update will be done in a separate PR (#149624); this revision is a step towards it.
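
For illustration, a minimal sketch of what artificial padding means, using hypothetical shapes that are not taken from this PR: with an outer dimension of 2 and an inner tile of 16, the unpacked dimension can hold between 17 and 32 elements, and anything smaller means at least one whole tile is pure padding.

```
// Valid: 2 * 16 - 28 = 4 elements of padding, less than the tile size 16.
%ok = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [16]
    into %d0 : tensor<2x16xf32> -> tensor<28xf32>

// Artificial padding: 2 * 16 - 10 = 22 >= 16, so a whole tile would be
// dropped. Such an op is not valid per the relationship quoted below.
// %bad = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [16]
//     into %d1 : tensor<2x16xf32> -> tensor<10xf32>
```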

[mlir][linalg] Take artificial padding into account for pack/unpack folding.

The revision folds the tensor.pad/extract_slice op into
linalg.pack/unpack ops only when it is safe to fold.

According to the doc, it is not valid to have artificial padding.

```
- The following relationship for the tiled dimensions holds:
shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i].
```

The documentation improvement and verifier update will be done in a
separate PR (#149624); this revision is a step towards it.

Signed-off-by: hanhanW <[email protected]>
@llvmbot (Member) commented on Jul 23, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

The revision folds the tensor.pad/extract_slice op into linalg.pack/unpack ops only when it is safe to fold.

According to the doc, it is not valid to have artificial padding.

- The following relationship for the tiled dimensions holds:
shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i].

The documentation improvement and verifier update will be done in a separate PR (#149624); this revision is a step towards it.
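
To make the check concrete, here is a worked sketch of the arithmetic (numbers taken from the updated pad_pack test in this diff; the folded op shown is what the pattern is expected to produce, not code copied from the PR):

```
// The packed layout holds 2082 * 8 = 16656 elements along dim 0 while the
// original source has 16649, so folding the pad into the pack leaves
// 16656 - 16649 = 7 elements of padding; 7 < 8 (one tile): safe to fold.
%pack = linalg.pack %src padding_value(%cst : f32)
    inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
    : tensor<16649x16xf32> -> tensor<2082x1x8x32xf32>
// The old test used a 16641x16 source: 16656 - 16641 = 15 >= 8, i.e. at
// least one full tile of artificial padding, so that fold is now rejected.
```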


Full diff: https://github.com/llvm/llvm-project/pull/150272.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+6)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td (+4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+50-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp (+27-11)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+27-10)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+44-16)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index bb0ac414bcc2d..6941939c8db5a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_IR_LINALG_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -89,6 +90,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
                                int64_t dim);
 
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c384e8b638382..fa572024ff72b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
                                    ArrayRef<int64_t> innerPermutation,
                                    ArrayRef<int64_t> outerPermutation);
 
+    /// Returns true if it is statically known that the `sliceOp` result shape
+    /// is compatible with the `unPackOp`. I.e., it does not drop any tile.
+    bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
     /// Check if this UnPackOp is like a simple unpad operation.
     /// In other words, this operation:
     /// 1. drops useless dimensions (dimension of size 1), and
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3ea0918..046a73c90f110 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4490,6 +4490,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
+  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                      ? packOrUnPack.getSourceType()
+                                      : packOrUnPack.getDestType();
+  SmallVector<int64_t> result(
+      packedType.getShape().take_front(unpackedType.getRank()));
+  if (!packOrUnPack.getOuterDimsPerm().empty()) {
+    applyPermutationToVector(
+        result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+  }
+  return result;
+}
+template SmallVector<int64_t>
+    getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+    getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
 // updated mixed-tile-sizes attribute. A tile size is updated only
 // when:
@@ -5447,11 +5470,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   if (unPackOp->hasOneUse()) {
     auto extractSliceUser =
         dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
-    if (extractSliceUser &&
-        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
-        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
-        extractSliceUser.getSourceType().getRank() ==
-            extractSliceUser.getResultType().getRank()) {
+    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
       OpBuilder::InsertionGuard g(rewriter);
       rewriter.setInsertionPoint(unPackOp);
       auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5494,6 +5513,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   return failure();
 }
 
+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+  // Rank-reduced folding is not supported.
+  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+    return false;
+  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+    return false;
+  RankedTensorType unpackedType = sliceOp.getResultType();
+  SmallVector<int64_t> outerShapeWithoutTranspose =
+      getPackedOuterShapeWithoutTransposition(*this);
+  for (auto [pos, tileSize] :
+       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+    if (unpackedType.isDynamicDim(pos))
+      return false;
+    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+      return false;
+    if (ShapedType::isDynamic(tileSize))
+      return false;
+    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                          unpackedType.getDimSize(pos);
+    if (paddingSize >= tileSize)
+      return false;
+  }
+  return true;
+}
+
 bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 2afa2f9b71c2a..73e157b42235a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -220,6 +220,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
       if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
         return failure();
 
+    // Folding is not allowed if it introduces artificial padding. It is not
+    // safe to fold the ops if any dynamic dimension or tile size is present,
+    // because we can not infer the padding size.
+    RankedTensorType unpackedType = packOp.getSourceType();
+    SmallVector<int64_t> outerShapeWithoutTranspose =
+        getPackedOuterShapeWithoutTransposition(packOp);
+    for (auto [pos, tileSize, high] :
+         llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+                         padOp.getMixedHighPad())) {
+      if (unpackedType.isDynamicDim(pos))
+        return failure();
+      if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+        return failure();
+      if (ShapedType::isDynamic(tileSize))
+        return failure();
+      std::optional<int64_t> cstHigh = getConstantIntValue(high);
+      if (!cstHigh)
+        return failure();
+      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                            unpackedType.getDimSize(pos);
+      // Do not fold the op if it requires artificial padding.
+      if (paddingSize + cstHigh.value() >= tileSize)
+        return failure();
+    }
+
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
         packOp.getMixedTiles(), constantPaddingValue,
@@ -251,17 +276,8 @@ struct FoldUnpackWithExtractSliceOp
     if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();
 
-    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "rank-reduced folding is not supported");
-    }
-
-    // Check all offsets are zeros, and all strides are ones.
-    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
-        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "expects offsets to be 0s and strides to be 1s");
-    }
+    if (!unpackOp.canFoldSliceOp(sliceOp))
+      return failure();
 
     // Create a new empty output tensor.
     Type elementType = unpackOp.getDestType().getElementType();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7284ae7dbd673..cd14bc3d1948b 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1890,30 +1890,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
 //===----------------------------------------------------------------------===//
 
 func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x28x10xf32> {
   %unpack = linalg.unpack %src
       outer_dims_perm = [0, 1, 2]
       inner_dims_pos = [1, 2]
       inner_tiles = [16, 16]
-      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
   %extracted_slice = tensor.extract_slice %unpack
-      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
+      [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
 }
-
 // CHECK-LABEL: func @fold_extract_slice_into_unpack
-//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32x?xf32>
-//  CHECK-SAME:     %[[SIZE:.+]]: index
+//  CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+//  CHECK-SAME:     %[[SIZE:[a-zA-Z0-9]+]]
 //       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-//  CHECK-SAME:     [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+//  CHECK-SAME:     [0, 0, 0] [28, 28, 10] [1, 1, 1]
 //       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
 //  CHECK-SAME:       into %[[DEST_SLICE]]
 //       CHECK:   return %[[UNPACK]]
 
 // -----
 
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
+
+// -----
+
 func.func @no_fold_extract_slice_into_unpack_rank_reducing(
     %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
 ) -> tensor<28xf32> {
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 16efa73f87a2a..4a97d1df25f15 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,32 @@
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack  %s | FileCheck %s
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control  %s | FileCheck %s --check-prefix=CONTROL
 
-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
+  %empty = tensor.empty() : tensor<16656x16xf32>
+  %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
+  %1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
+  return %1 : tensor<16649x16xf32>
+}
+// CHECK-LABEL: func @fold_unpack_slice(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+//      CHECK:    %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
+//      CHECK:    %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+// CHECK-SAME:        into %[[INIT]]
+//      CHECK:    return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
   %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
       : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
   %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-//      CHECK: func @fold_unpack_slice(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x8x4xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
-//      CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
-// CHECK-SAME:       into %[[INIT]]
-//      CHECK:   return %[[UNPACK]]
+// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
 
 // -----
 
@@ -59,13 +69,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
 
 // -----
 
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[7, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
   %empty = tensor.empty() : tensor<2082x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
@@ -81,10 +91,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
 
 // -----
 
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
   } : tensor<16641x16xf32> to tensor<16656x16xf32>
@@ -93,7 +103,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
   return %pack : tensor<2082x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
+// CHECK:         tensor.pad
+// CHECK:         linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack(
 // CHECK:         tensor.pad
 // CHECK:         linalg.pack
 

@banach-space (Contributor) left a comment

Thanks, just some minor questions.

Signed-off-by: hanhanW <[email protected]>
@hanhanW requested a review from banach-space on July 24, 2025 at 02:16
Comment on lines 5534 to 5537
```
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
                      unpackedType.getDimSize(pos);
if (paddingSize >= tileSize)
  return false;
```
@banach-space (Contributor) commented:

It would be good to add a test where a non-trailing dim is "sliced", e.g.

```
func.func @fold_extract_slice_into_unpack(
    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
) -> tensor<28x16x15xf32> {
  %unpack = linalg.unpack %src
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [16, 16]
      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
  %extracted_slice = tensor.extract_slice %unpack
      [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
  return %extracted_slice : tensor<28x16x15xf32>
}
```
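
(For reference: the sliced dimension 1 has outer size 2 and tile size 16 here, so the slice implies 2 * 16 - 16 = 16 elements of padding, which is not less than the tile size; the fold would introduce artificial padding, making this a negative test, as the author confirms below.)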

@egebeysel (Contributor) commented:

Also, if I'm not missing something, I don't see tests for when we hit this case where we would need artificial padding. It would be nice to add these negative cases :)

@hanhanW (Author) replied:

Thanks, I missed that!

I think @banach-space's example is a negative test. Anyway, I added two more tests that should address your comments.

@adam-smnk (Contributor) left a comment

Overall looks fine

@egebeysel (Contributor) left a comment

LGTM overall, a couple things that other reviewers mentioned and one missing test case I guess.


@egebeysel (Contributor) left a comment

One comment about a test case, after that LGTM!


```
// The available dimension size is [17, 32], because CeilDiv(%d1, 16) == 2.
func.func @no_fold_extract_slice_into_unpack_artificial_padding(
```
@egebeysel (Contributor) commented:

Thanks for adding the test case :) I guess we could add a similar one over to the fold-into-pack-unpack-patterns :)

@banach-space (Contributor) left a comment

LGTM, thanks! Great work @hanhanW 🙏🏻

@hanhanW (Author) commented on Jul 24, 2025

Thanks for the review, really appreciate it! 🙏🏻

@hanhanW merged commit 1ff6d9d into main on Jul 24, 2025
9 checks passed
@hanhanW deleted the users/hanhanW/restrict-pack-unpack-folding-patterns branch on July 24, 2025 at 20:55