[mlir][linalg] Restrict linalg.pack to not have artificial padding. #149624
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 59aa079 to 1ba3c59
Thanks for drafting and sharing this, @hanhanW ! Let me make sure that I understand correctly :) Below is your original example with a small modification - the input tensor has 9 rather than 8 elements. This way, we need exactly two tiles to pack it.

func.func @foo(%src: tensor<9xf32>) -> tensor<2x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dest = tensor.empty() : tensor<2x8xf32>
  %pack = linalg.pack %src
      padding_value(%cst : f32)
      inner_dims_pos = [0]
      inner_tiles = [8] into %dest
      : tensor<9xf32> -> tensor<2x8xf32>
  return %pack : tensor<2x8xf32>
}

The 2nd tile will require padding, so:

// x - original value from %src
// * - %cst
x x x x x x x x
x * * * * * * *

This is fine and how I expect it to work. Now, any tile beyond those 2 tiles above (so, e.g., the additional 98 tiles "requested" in your original example) would lead to even more "artificial" padding ("artificial" as in not strictly required by `linalg.pack`). I agree.
I can give a real review later, but the idea of this sounds good to me. Having a linalg.pack with a whole bunch of extra padding is limiting for fusions, and I expect it would be simpler to deal with the pad and pack separately in that case.
Perhaps a separate discussion, but would the same logic apply to linalg.unpack with its extract_slice semantics too? It doesn't seem to me like unpack has the same problems as pack, but do we want to keep the semantics consistent?
The difficulty with pack is that the padding semantics are hard to handle with fusions. Fusing a consumer pad is hard because you may need to produce a result that is larger than the total output spanned by the tiling loop. This means you need to either expand the iteration space of the loop or have some iterations of the loop produce some extra results.
However, the same difficulty does not really arise with tensor.extract_slice, since no new values are created. This means that the problems we see for linalg.pack don't exist for linalg.unpack, so I'd probably argue that we should not update the unpack semantics to match these new pack semantics, but WDYT?
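For reference, a hypothetical sketch (shapes and function name made up, not from the patch) of the unpack-side decomposition being discussed: unpack the full packed data, then take an explicit tensor.extract_slice, instead of relying on the unpack itself to discard data.

```mlir
// Hypothetical sketch: the "dropped" data is expressed as an explicit slice
// after an exact unpack, mirroring how pad and pack can be kept separate on
// the pack side.
func.func @unpack_then_slice(%src: tensor<2x8xf32>) -> tensor<9xf32> {
  %dest = tensor.empty() : tensor<16xf32>
  %unpack = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [8] into %dest
      : tensor<2x8xf32> -> tensor<16xf32>
  %slice = tensor.extract_slice %unpack[0] [9] [1] : tensor<16xf32> to tensor<9xf32>
  return %slice : tensor<9xf32>
}
```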
Yes, that is what I meant; you have a better example! I'll modify the PR description a bit.
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
-  %empty = tensor.empty() : tensor<8x4x16x8xf32>
+func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(%arg0: tensor<32x64xf32>) -> tensor<4x4x16x8xf32> {
+  %empty = tensor.empty() : tensor<4x4x16x8xf32>
   %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-  return %pack : tensor<8x4x16x8xf32>
+  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<4x4x16x8xf32>
+  return %pack : tensor<4x4x16x8xf32>
 }
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
+// CHECK-LABEL: func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME:      output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK:         %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x64x8xf32>
+// CHECK:         %[[PACK:.+]] = linalg.pack %[[ARG0]]
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME:      : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK:         return %[[PACK]] : tensor<8x4x16x8xf32>
@adam-smnk do you remember why we added this test? It looks propagable to me after I fix the shape.
The old shape had padding (without `padding_value`, so UB from the `pack` perspective but allowed), so the transform fails to reassociate dims.
With the new shape and no padding, it should be fine to propagate.
I see. I thought that would be caught by the verifier, but I'm wrong. It looks like a bug in the verifier to me. Anyway, we should be okay with the new semantics. Do we keep the test in this case?
It looks like a bug in verifier to me.
I think it could/should be improved. Right now, IMO the op docs allow it, or perhaps somewhat omit it.
Do we keep the test in this case?
If there's another test case with padding then no need for this one.
I'm going to remove the test case because it is covered by the below test. I'll update the PR description later.
llvm-project/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Lines 1091 to 1105 in 45a6c02
func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x32x16x4xf32> {
  %empty = tensor.empty() : tensor<8x2x32x16x4xf32>
  %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [4] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x32x16x4xf32>
  return %pack : tensor<8x2x32x16x4xf32>
}
// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand(
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32>
// CHECK:         %[[PACK:.+]] = linalg.pack
// CHECK-SAME:      %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]]
// CHECK-SAME:      : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]]
// CHECK-SAME:      output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
// CHECK:         return %[[EXPANDED]] : tensor<8x2x32x16x4xf32>
I like the overall concept.
Creating extra tiles filled only with padding value does not seem critical to the packing itself. The same behavior should be achievable with a separate, follow-up pad. So, the current pad semantics increase ops' complexity with no clear payoff.
The PR description could use clearer wording. `extra padding` is too ambiguous as packing requires some amount of padding.
Once the new desired semantics are fleshed out, the `linalg.pack` op description should be updated too.
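To make the "separate, follow-up pad" idea above concrete, here is a minimal hypothetical sketch (the function name and shapes are made up, reusing the 9-element example from this thread): the pack produces exactly ceildiv(9, 8) = 2 tiles, and any purely artificial tiles are added by an explicit tensor.pad on the packed result.

```mlir
// Hypothetical sketch: pack without artificial padding, then pad the outer
// dimension if extra all-padding tiles are really needed downstream.
func.func @pack_then_pad(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dest = tensor.empty() : tensor<2x8xf32>
  // ceildiv(9, 8) = 2 outer tiles; only the last tile needs padding.
  %pack = linalg.pack %src padding_value(%cst : f32)
      inner_dims_pos = [0] inner_tiles = [8] into %dest
      : tensor<9xf32> -> tensor<2x8xf32>
  // The 98 extra, purely artificial tiles become an explicit pad.
  %padded = tensor.pad %pack low[0, 0] high[98, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<2x8xf32> to tensor<100x8xf32>
  return %padded : tensor<100x8xf32>
}
```

This keeps the relayout itself exact and leaves the extra-padding decision to a separate op that can be handled independently.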
Thanks for the feedback! It looks like the active users and contributors (that I know) are okay with the change, so I'll go ahead and update the docs and corresponding patterns. E.g., the patterns that fold pad/extract_slice into pack/unpack ops should take it into account; otherwise, invalid IR is generated. Please help ping others if I missed someone who would be interested in this. cc @egebeysel who has recently been looking at the ops.
I'll borrow the wording from @banach-space, which looks better to me, thanks!
Thanks for the ping! I'll take a look at this tomorrow :)
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-linalg

Author: Han-Chung Wang (hanhanW)

Changes

The revision restricts the `linalg.pack` op to not have artificial padding semantics. E.g., the below is valid without the change, and it becomes invalid with the change.

func.func @foo(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dest = tensor.empty() : tensor<100x8xf32>
  %pack = linalg.pack %src
      padding_value(%cst : f32)
      inner_dims_pos = [0]
      inner_tiles = [8] into %dest
      : tensor<9xf32> -> tensor<100x8xf32>
  return %pack : tensor<100x8xf32>
}

IMO, it is a misuse if we use pack ops with artificial padding sizes, because the intention of the pack op is to relayout the source based on target intrinsics, etc. The output shape is expected to be `tensor<2x8xf32>`. If people need extra padding sizes, they can create a new pad op followed by the pack op.

This also makes consumer tiling much easier, because the consumer fusion does not support artificial padding sizes. It is very hard to make it work without using ad-hoc patterns, because the tiling sizes are about the source, which implies that you don't have a core_id/thread_id to write padding values to the whole tile.

People may have a question about how/why the pad tiling implementation works. The answer is that it creates an `if-else` branch to handle the case. In my experience, it is a real struggle in transformations, because most of the time people only need one side of the branch, given that the tile sizes are usually greater than the padding sizes. However, the implementation is conservatively correct in terms of semantics. Given that the introduction of the `pack` op is to serve the relayout needs better, having the restriction makes sense to me.

Removed tests:
- `no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate` from `data-layout-propagation.mlir`: it is a dup test of `bubble_up_pack_non_expanded_dims_through_expand` after we fix the shape.
- `fuse_pack_consumer_with_untiled_extra_padding` from `tile-and-fuse-consumer.mlir`: it was created for artificial padding in the consumer fusion implementation.

New changes in PackAndUnpackPatterns.cpp:
- `FoldPadWithPackOp` no longer folds when the fold would introduce artificial padding.
- `FoldUnpackWithExtractSliceOp` no longer folds when any tile would be dropped.

The other changes in lit tests are just fixing the shape.

Patch is 29.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149624.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c384e8b638382..c1a96d5eb1dbe 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -150,9 +150,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
`padding_value` specifies a padding value at the boundary on non-perfectly
divisible dimensions. Padding is optional:
- - If absent, it is UB if the tile does not perfectly divide the dimension.
+ - If absent, it assumes the tile perfectly divides the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
- tile complete.
+ tile complete. Note that it is not allowed to have artificial padding that
+ is not strictly required by linalg.pack.
Example:
```mlir
@@ -167,6 +168,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
//
// Note: Only tiled dimensions can be padded.
```
+
+ Invalid example that has artificial padding:
+ ```mlir
+ %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
+ inner_tiles = [8] into %dest
+ : tensor<9xf32> -> tensor<3x8xf32>
+ // \
+ // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
+ ```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3ea0918..248cefc5d707f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -4599,22 +4600,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
- assert(
- sourceShape.size() == limitShape.size() &&
- "expected source shape rank, and limit of the shape to have same rank");
- return llvm::all_of(
- llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
- int64_t sourceExtent = std::get<0>(it);
- int64_t limit = std::get<1>(it);
- return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent <= limit;
- });
-}
-
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4673,11 +4658,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
- }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4694,6 +4674,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
+ if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape()))) {
+ return op->emitError("the shape of unpacked domain value is not large "
+ "enough to hold the packed data. Expected at least ")
+ << expectedPackedType << ", got " << packedType;
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 2afa2f9b71c2a..cac77e45e8575 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -194,6 +195,28 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
}
};
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy>
+static SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
+ RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
+ SmallVector<int64_t> result(
+ packedType.getShape().take_front(unpackedType.getRank()));
+ if (!packOrUnPack.getOuterDimsPerm().empty()) {
+ applyPermutationToVector(
+ result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+ }
+ return result;
+}
+
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -220,6 +243,29 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
+ // Folding is not allowed if it introduces artificial padding.
+ RankedTensorType unpackedType = packOp.getSourceType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(packOp);
+ for (auto [pos, tileSize, high] :
+ llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+ padOp.getMixedHighPad())) {
+ if (unpackedType.isDynamicDim(pos))
+ return failure();
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return failure();
+ if (ShapedType::isDynamic(tileSize))
+ return failure();
+ std::optional<int64_t> cstHigh = getConstantIntValue(high);
+ if (!cstHigh)
+ return failure();
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
+ // Do not fold the op if it requires artificial padding.
+ if (paddingSize + cstHigh.value() >= tileSize)
+ return failure();
+ }
+
rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
@@ -263,6 +309,24 @@ struct FoldUnpackWithExtractSliceOp
sliceOp, "expects offsets to be 0s and strides to be 1s");
}
+ // Folding is not allowed if any tile is dropped.
+ RankedTensorType unpackedType = sliceOp.getResultType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(unpackOp);
+ for (auto [pos, tileSize] : llvm::zip_equal(
+ unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles())) {
+ if (unpackedType.isDynamicDim(pos))
+ return failure();
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return failure();
+ if (ShapedType::isDynamic(tileSize))
+ return failure();
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
+ if (paddingSize >= tileSize)
+ return failure();
+ }
+
// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
Value output = rewriter.create<tensor::EmptyOp>(
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7284ae7dbd673..dfe3bfd4a967a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1387,42 +1387,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
// CHECK-LABEL: @recursive_effect
// CHECK: linalg.map
+// -----
+
//===----------------------------------------------------------------------===//
// linalg.pack
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @fold_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
%0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 1.000000e-01 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
padding_value(%pad : f32)
outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
-
// -----
// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
// CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 0.0 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
@@ -1430,8 +1431,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [8, 32]
- into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 6fc8d9f152f4e..cc26fa48abf4b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// -----
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
- %empty = tensor.empty() : tensor<8x4x16x8xf32>
- %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
- %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
- return %pack : tensor<8x4x16x8xf32>
-}
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
-// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
-
-// -----
-
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index da1dfc7b6a624..4299a15026f91 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
}
// -----
+
func.func @pack_mismatch_inner_tile_size_and_output_shape(
%input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
// expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1827,24 +1828,24 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+ // expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}
// -----
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
- return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+ // expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
+ %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+ return %0 : tensor<8x7x16x32xf32>
}
// -----
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+ // expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 81fd7a8a947d7..9e7681d1a1b7d 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
- -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+ -> tensor<265x12x16x1xf32> {
// CHECK: tensor.pad {{.*}} low[0, 0]
- // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
+ // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
- // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+ // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
// CHECK: linalg.transpose
- // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
- // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
// CHECK-SAME: permutation = [0, 2, 1, 3]
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.pack %src
padding_value(%cst : f32)
inner_dims_pos = [0, 1]
inner_tiles = [16, 1] into %dest
- : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
- return %0 : tensor<265x16x16x1xf32>
+ : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+ return %0 : tensor<265x12x16x1xf32>
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 16efa73f87a2a..4a97d1df25f15 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,32 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
+ %empty = tensor.empty() : tensor<16656x16xf32>
+ %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
+ %1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
+ return %1 : tensor<16649x16xf32>
+}
+// CHECK-LABEL: func @fold_unpack_slice(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_dynamic_unpack_...
[truncated]
Refining semantics like this is always a much appreciated improvement, thank you!
Some minor comments inline. And a request for a separate PR :)
tile complete. Note that it is not allowed to have artificial padding that
is not strictly required by linalg.pack (i.e., padding past what is needed
to complete the last tile along each packed dimension). It is UB if extra
padding is requested for dynamic cases. For static cases, they are caught
by the verifier.
We don't really cover "what's not supported" in Op docs, I would skip this part.
Btw, the current docs already imply that "artificial" padding is not supported, see e.g.
- The following relationship for the tiled dimensions holds:
`shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
In general, I would focus on what is supported. Anything that doesn't fit the definition should be a verification error.
In essence, you are making the implementation match the existing description 😅
After I read the statement again, I think it implies that artificial padding is not valid if we replace `/` with `CeilDiv`? Thus, I think the statement is still valuable. How about I trim away the dynamic and static part?
I think that you are right. `/` means `FloorDiv` and to me that looks incorrect; it should be `CeilDiv` instead.

Thus, I think the statement is still valuable. How about I trim away the dynamic and static part?

If you find it valuable then let's keep it; these things are very subjective and there's a non-zero likelihood that others find it helpful as well.
Btw, how about updating the existing docs. I am referring to this specifically:
llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Lines 108 to 109 in 85349b4
- The following relationship for the tiled dimensions holds:
  `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.

(i.e., replace `/` with `CeilDiv`).
// -----

func.func @unpack_with_slicing_tiles(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
Why "slicing"? To me it's just a case of the result shape being invalid 🤔
This is more like the unpack version of "artificial padding", where it drops some tiles. I was choosing the words between `slice` and `drop`, and somehow thought `slice` was better because it has slicing semantics. Now I think `drop` is a better terminology. Does it make sense?
Thanks for the explanation! I would be inclined towards re-using "artificial" here - mostly for consistency in terminology. So, perhaps: @unpack_with_artifical_tiles_that_are_dropped
🤷🏻 Naming is hard!
SG, let's use `artificial` for consistency.
I've been thinking the same thing, but it looks like a neutral case to me. We can have a separate PR for the folding patterns, and the current PR can be on top of the other one. However, I found it a struggle to write the PR description, because people may ask -- why do we restrict the folding patterns if they are valid atm? This is one of the reasons that I ended up having these changes within a single PR. I may answer that the separate PR is a preparation for refining the op semantics, as long as people don't ask why not land them together. The other reason is that the PR size is reasonable to me. There are three main changes: the documentation update, the verifier update, and the pad/extract_slice folding pattern update.

As you (and others) are the main reviewers so far, I'm happy to follow what makes you comfortable. Any preference?
Oh, I just saw your other comment, which implies that there is a bug in the current doc. So having a separate PR sounds okay to me now. Let me split the changes out.
…olding.

The revision folds the tensor.pad/extract_slice op into linalg.pack/unpack ops only when it is safe to fold. According to the doc, it is not valid to have artificial padding.

```
- The following relationship for the tiled dimensions holds:
  shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i].
```

The documentation improvement and verifier update will be done in a separate PR (i.e., llvm#149624). The revision is a step towards it.

Signed-off-by: hanhanW <[email protected]>
Force-pushed from 4f2247f to 68b02a0
It looks like GitHub can't inline the comments properly even if I do my best to split the commits. The folding parts are split out to #150127; there are three commits in the PR -- the main commit + two rounds of review comments. @banach-space you can still see your comments if you click into an individual commit, FYI.
I think you should drop @2548809 from this PR, so we could handle #150127 separately :) Or target that branch from this one if you're planning to get that one in first. Otherwise LGTM.
Signed-off-by: hanhanW <[email protected]>
Signed-off-by: hanhanW <[email protected]>
Signed-off-by: hanhanW <[email protected]>
Force-pushed from 68b02a0 to 5a1ae04
I need to land the other one first; otherwise, the CI is not happy. I recreated the other PR because the branch needs to live upstream and it has to be associated with a PR: #150272
Thanks for the updates!
I think that having two separate PRs makes more sense. Code-wise, these are relatively small patches. But the underlying logic that's updated is a bit fiddly and every change deserves a dedicated commit summary :)
I've left some minor comments. I will make sure to prioritise this tomorrow so that you can land it promptly.
to complete the last tile along each packed dimension). It is UB if extra
padding is requested.
It is UB if extra padding is requested.
Shouldn't that be a verification error? And then restore `UB` for the previous point, no?
Shouldn't that be a verification error?

It's not possible to enforce that with a dynamic source. `UB` is more of a "catch all" here and allows `linalg::lowerPack` to remain as is.

restore UB for the previous point

It could remain there to reinforce the message. But no strong preference here.
Enforcing that is more like a dynamic check, but not a static check. I.e., it can only be asserted during runtime for most cases, IMO. E.g., you don't do any out-of-bound checks for tensor.extract_slice for dynamic sizes/dims, but you do out-of-bound checks when the op has static sizes and offsets.
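As a hypothetical illustration (not from the patch): with a dynamic source dimension, the verifier cannot tell whether the destination carries artificial padding, so an over-sized destination can only be diagnosed at runtime and is treated as UB.

```mlir
// Hypothetical example: %src has a dynamic size, so the verifier cannot prove
// whether %dest holds exactly ceildiv(dim(%src, 0), 8) outer tiles. If %dest
// is over-sized at runtime, the op is UB rather than a verification error.
func.func @dynamic_pack(%src: tensor<?xf32>, %dest: tensor<?x8xf32>) -> tensor<?x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = linalg.pack %src padding_value(%cst : f32)
      inner_dims_pos = [0] inner_tiles = [8] into %dest
      : tensor<?xf32> -> tensor<?x8xf32>
  return %pack : tensor<?x8xf32>
}
```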
Restored the UB for the previous point, because it also happens with dynamic shapes, which I forgot. I also added one more sentence about why they are UB. Does it look better?
The overall direction and the implementation make sense to me, thanks for this improvement!
I've left a few minor comments, but these are non-blocking. Since this is restricting the Op semantics (and since others commented here as well), could you wait for one more "+1" before landing? Thanks!
Should we also update:
- The following relationship for the tiled dimensions holds:
`shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
as
- The following relationship for the tiled dimensions holds:
`shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] ⌈/⌉ inner_tiles[i]` (⌈/⌉ - CeilDiv).
Yes, updated. Thanks for pointing it out!
Looks good, great change 👍
Signed-off-by: hanhanW <[email protected]>
…olding. (#150272)

The revision folds the tensor.pad/extract_slice op into linalg.pack/unpack ops only when it is safe to fold. It is not valid to have artificial padding. The documentation improvement and verifier update will be done in a separate PR (i.e., #149624). The revision is a step towards it.

Signed-off-by: hanhanW <[email protected]>
I think I have the approvals I need here. @banach-space and @adam-smnk, could you approve the other one?
…ck/unpack folding. (#150272)

The revision folds the tensor.pad/extract_slice op into linalg.pack/unpack ops only when it is safe to fold. It is not valid to have artificial padding. The documentation improvement and verifier update will be done in a separate PR (i.e., llvm/llvm-project#149624). The revision is a step towards it.

Signed-off-by: hanhanW <[email protected]>
The revision restricts the `linalg.pack` op to not have artificial padding semantics. E.g., the below is valid without the change, and it becomes invalid with the change.

func.func @foo(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dest = tensor.empty() : tensor<100x8xf32>
  %pack = linalg.pack %src
      padding_value(%cst : f32)
      inner_dims_pos = [0]
      inner_tiles = [8] into %dest
      : tensor<9xf32> -> tensor<100x8xf32>
  return %pack : tensor<100x8xf32>
}

IMO, it is a misuse if we use pack ops with artificial padding sizes, because the intention of the pack op is to relayout the source based on target intrinsics, etc. The output shape is expected to be `tensor<2x8xf32>`. If people need extra padding sizes, they can create a new pad op followed by the pack op (see the sketch below).

This also makes consumer tiling much easier, because the consumer fusion does not support artificial padding sizes. It is very hard to make it work without using ad-hoc patterns, because the tiling sizes are about the source, which implies that you don't have a core_id/thread_id to write padding values to the whole tile.
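For illustration only (not part of the patch), a minimal sketch of that pad-then-pack decomposition, reusing the 9-element example; the function name and shapes are made up:

```mlir
// Hypothetical decomposition of the now-invalid example above: express the
// artificial padding with an explicit tensor.pad, and keep the pack exact.
func.func @pad_then_pack(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Pad 9 -> 800 elements (100 * 8) so that every tile is fully covered.
  %padded = tensor.pad %src low[0] high[791] {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<9xf32> to tensor<800xf32>
  %dest = tensor.empty() : tensor<100x8xf32>
  // No padding_value is needed: 800 is divisible by the inner tile size 8.
  %pack = linalg.pack %padded inner_dims_pos = [0] inner_tiles = [8] into %dest
      : tensor<800xf32> -> tensor<100x8xf32>
  return %pack : tensor<100x8xf32>
}
```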
People may have a question about how/why the pad tiling implementation works. The answer is that it creates an `if-else` branch to handle the case. In my experience, it is a real struggle in transformations, because most of the time people only need one side of the branch, given that the tile sizes are usually greater than the padding sizes. However, the implementation is conservatively correct in terms of semantics. Given that the introduction of the `pack` op is to serve the relayout needs better, having the restriction makes sense to me.

Removed tests:
- `no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate` from `data-layout-propagation.mlir`: it is a dup test of `bubble_up_pack_non_expanded_dims_through_expand` after we fix the shape.
- `fuse_pack_consumer_with_untiled_extra_padding` from `tile-and-fuse-consumer.mlir`: it was created for artificial padding in the consumer fusion implementation.