[mlir][linalg] Take artificial padding into account for pack/unpack folding. #150272

Merged: 7 commits, Jul 24, 2025

14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_IR_LINALG_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -144,4 +145,17 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"

namespace mlir {
namespace linalg {

/// Returns the outer shape in the packed domain before applying the
/// transposition.
template <typename OpTy,
typename = std::enable_if_t<std::is_same_v<OpTy, linalg::PackOp> ||
std::is_same_v<OpTy, linalg::UnPackOp>>>
SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);

} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_IR_LINALG_H
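
To make the doc comment above concrete, here is a minimal standalone sketch of what the helper computes, using plain std types rather than the MLIR API; the shapes and permutation in `main` are made-up examples, not taken from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified model of getPackedOuterShapeWithoutTransposition: take the
// packed type's leading (outer) dims and undo outer_dims_perm so the result
// lines up with the unpacked tensor's dimension order.
std::vector<int64_t>
outerShapeWithoutTransposition(const std::vector<int64_t> &packedOuterDims,
                               const std::vector<int64_t> &outerDimsPerm) {
  if (outerDimsPerm.empty())
    return packedOuterDims;
  std::vector<int64_t> result(packedOuterDims.size());
  // Packed position i stores original dim outerDimsPerm[i], so scattering
  // through the permutation applies its inverse.
  for (std::size_t i = 0; i < outerDimsPerm.size(); ++i)
    result[outerDimsPerm[i]] = packedOuterDims[i];
  return result;
}

int main() {
  // Hypothetical pack: source tensor<30x40xf32>, inner_tiles = [16, 16],
  // outer_dims_perm = [1, 0], dest tensor<3x2x16x16xf32>. The dest's outer
  // dims are [3, 2]; in source order they are [2, 3].
  std::vector<int64_t> result = outerShapeWithoutTransposition({3, 2}, {1, 0});
  assert((result == std::vector<int64_t>{2, 3}));
  return 0;
}
```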
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);

/// Returns true if it is statically known that the `sliceOp` result shape
/// is compatible with the `unPackOp`. I.e., it does not drop any tile.
bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);

/// Check if this UnPackOp is like a simple unpad operation.
/// In other words, this operation:
/// 1. drops useless dimensions (dimension of size 1), and
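
To make "does not drop any tile" concrete: a tiny standalone sketch of the per-dimension criterion (plain C++, not the actual MLIR implementation). The numbers mirror the canonicalization tests later in this diff, where an unpacked dim of 28 is covered by two tiles of 16.

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the criterion behind UnPackOp::canFoldSliceOp for one dimension:
// the slice may only remove (part of) the padding in the last tile, never a
// whole tile. outerDim is the number of tiles in the packed source.
bool sliceKeepsAllTiles(int64_t outerDim, int64_t tileSize,
                        int64_t sliceSize) {
  int64_t paddingSize = outerDim * tileSize - sliceSize;
  return paddingSize < tileSize;
}

int main() {
  // Two tiles of 16 cover 32 elements; the unpacked dim in the tests is 28.
  assert(sliceKeepsAllTiles(/*outerDim=*/2, /*tileSize=*/16, /*sliceSize=*/17));
  assert(sliceKeepsAllTiles(2, 16, 28));
  // Slicing down to 16 would drop the second tile entirely -> not foldable.
  assert(!sliceKeepsAllTiles(2, 16, 16));
  return 0;
}
```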
55 changes: 50 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4490,6 +4490,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//

template <typename OpTy, typename>
SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getDestType()
: packOrUnPack.getSourceType();
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();
SmallVector<int64_t> result(
packedType.getShape().take_front(unpackedType.getRank()));
if (!packOrUnPack.getOuterDimsPerm().empty()) {
applyPermutationToVector(
result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
}
return result;
}
template SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
template SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);

// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
@@ -5447,11 +5470,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
if (unPackOp->hasOneUse()) {
auto extractSliceUser =
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
-   if (extractSliceUser &&
-       areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
-       areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
-       extractSliceUser.getSourceType().getRank() ==
-           extractSliceUser.getResultType().getRank()) {
+   if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5494,6 +5513,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return failure();
}

bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
// Rank-reduced folding is not supported.
if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
return false;
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
return false;
RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(*this);
for (auto [pos, tileSize] :
llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
if (unpackedTypeAfterFold.isDynamicDim(pos))
return false;
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return false;
if (ShapedType::isDynamic(tileSize))
return false;
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
unpackedTypeAfterFold.getDimSize(pos);
if (paddingSize >= tileSize)
return false;
}
return true;
}

bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
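
A worked model of canFoldSliceOp above, using plain integers instead of the MLIR types and the shapes from the canonicalization tests further down (tensor<28x2x1x16x16> unpacked into tensor<28x28x15>). It only models the per-dimension loop; the real method additionally requires zero offsets, unit strides, and no rank reduction. `kDynamic` is a stand-in for ShapedType::kDynamic.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

// Simplified model of UnPackOp::canFoldSliceOp: for every tiled dimension the
// sliced size, the outer dim count, and the tile size must be static, and the
// slice may only trim padding in the last tile, never a whole tile.
bool canFoldSliceOp(const std::vector<int64_t> &sliceShape,
                    const std::vector<int64_t> &outerShapeNoTranspose,
                    const std::vector<int64_t> &innerDimsPos,
                    const std::vector<int64_t> &innerTiles) {
  for (std::size_t i = 0; i < innerDimsPos.size(); ++i) {
    int64_t pos = innerDimsPos[i];
    int64_t tile = innerTiles[i];
    if (sliceShape[pos] == kDynamic ||
        outerShapeNoTranspose[pos] == kDynamic || tile == kDynamic)
      return false; // padding size cannot be computed statically
    int64_t paddingSize = outerShapeNoTranspose[pos] * tile - sliceShape[pos];
    if (paddingSize >= tile)
      return false; // the slice would drop at least one full tile
  }
  return true;
}

int main() {
  // linalg.unpack tensor<28x2x1x16x16> into tensor<28x28x15> with
  // inner_dims_pos = [1, 2] and inner_tiles = [16, 16].
  std::vector<int64_t> outer = {28, 2, 1};
  std::vector<int64_t> pos = {1, 2};
  std::vector<int64_t> tiles = {16, 16};
  assert(canFoldSliceOp({28, 17, 15}, outer, pos, tiles));  // folds
  assert(!canFoldSliceOp({28, 16, 15}, outer, pos, tiles)); // drops a tile
  assert(!canFoldSliceOp({28, 28, kDynamic}, outer, pos, tiles)); // dynamic
  return 0;
}
```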
40 changes: 29 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -220,6 +220,33 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();

// Folding is not allowed if it were to introduce artificial padding.
// Folding is also disabled in the case of dynamic dimensions and/or tile
// sizes - that is because it would be impossible to compute the padding
// size and hence to establish whether "artificial" padding would be
// created.
RankedTensorType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
for (auto [pos, tileSize, high] :
llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
padOp.getMixedHighPad())) {
if (unpackedType.isDynamicDim(pos))
return failure();
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return failure();
if (ShapedType::isDynamic(tileSize))
return failure();
std::optional<int64_t> cstHigh = getConstantIntValue(high);
if (!cstHigh)
return failure();
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
unpackedType.getDimSize(pos);
// Do not fold the op if it requires artificial padding.
if (paddingSize + cstHigh.value() >= tileSize)
return failure();
}

rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
@@ -251,17 +278,8 @@ struct FoldUnpackWithExtractSliceOp
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();

-   if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
-     return rewriter.notifyMatchFailure(
-         sliceOp, "rank-reduced folding is not supported");
-   }
-
-   // Check all offsets are zeros, and all strides are ones.
-   if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
-       !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
-     return rewriter.notifyMatchFailure(
-         sliceOp, "expects offsets to be 0s and strides to be 1s");
-   }
+   if (!unpackOp.canFoldSliceOp(sliceOp))
+     return failure();

// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
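
The guard added above can be read as: the padding introduced by the tensor.pad plus the padding the pack itself would add must stay within a single tile, otherwise the folded pack would materialize a tile consisting purely of padding. A small standalone model for one dimension (plain C++ with hypothetical sizes; the real pattern also bails out on dynamic shapes and non-constant high padding):

```cpp
#include <cassert>
#include <cstdint>

// Simplified model of the guard in FoldPadWithPackOp for one tiled dimension.
// paddedDim is the pad op's result size (the pack's current source size),
// highPad the amount the pad op added at the end of that dimension.
bool canFoldPadIntoPack(int64_t paddedDim, int64_t tileSize, int64_t highPad) {
  int64_t outerDim = (paddedDim + tileSize - 1) / tileSize; // ceil division
  // Padding the pack itself adds on top of the already padded tensor.
  int64_t packPadding = outerDim * tileSize - paddedDim;
  // Folding is legal only if all padding stays within the last tile.
  return packPadding + highPad < tileSize;
}

int main() {
  // Original dim 15 padded by 1 to 16, tile 16: the folded pack would simply
  // pad 15 -> 16 itself, which is fine.
  assert(canFoldPadIntoPack(/*paddedDim=*/16, /*tileSize=*/16, /*highPad=*/1));
  // Original dim 15 padded by 17 to 32, tile 16: the folded pack would have
  // to pad 15 -> 32, i.e. produce a tile of pure padding -> not folded.
  assert(!canFoldPadIntoPack(/*paddedDim=*/32, /*tileSize=*/16, /*highPad=*/17));
  return 0;
}
```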
75 changes: 64 additions & 11 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1889,31 +1889,84 @@ func.func @fold_cast_unpack_dynamic_tile_size(
// linalg.unpack + tensor.extract_slice
//===----------------------------------------------------------------------===//

-func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
-    into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+    into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
-    [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
+    [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
}
// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME: into %[[DEST_SLICE]]
// CHECK: return %[[UNPACK]]

// -----

// The valid slice sizes along dim 1 are within [17, 32], because the unpacked dim size 28 occupies CeilDiv(28, 16) == 2 tiles of 16; slicing below 17 would drop a whole tile.

-// CHECK-LABEL: func @fold_extract_slice_into_unpack
-// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
-// CHECK-SAME: %[[SIZE:.+]]: index
func.func @fold_extract_slice_into_unpack_slicing_dim_1(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x17x15xf32> {
%unpack = linalg.unpack %src
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
[0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
return %extracted_slice : tensor<28x17x15xf32>
}
// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK-SAME: [0, 0, 0] [28, 17, 15] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME: into %[[DEST_SLICE]]
// CHECK: return %[[UNPACK]]

// -----

// The valid slice sizes along dim 1 are within [17, 32], because the unpacked dim size 28 occupies CeilDiv(28, 16) == 2 tiles of 16; slicing below 17 would drop a whole tile.

func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> {
%unpack = linalg.unpack %src
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
[0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
return %extracted_slice : tensor<28x16x15xf32>
}
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
// CHECK: linalg.unpack
// CHECK: tensor.extract_slice

// -----

func.func @no_fold_extract_slice_into_unpack_dynamic(
%src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
) -> tensor<28x28x?xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
%extracted_slice = tensor.extract_slice %unpack
[0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
return %extracted_slice : tensor<28x28x?xf32>
}
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
// CHECK: linalg.unpack
// CHECK: tensor.extract_slice

// -----

func.func @no_fold_extract_slice_into_unpack_rank_reducing(
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28xf32> {