
Commit 59aa079

[mlir][linalg] Restrict linalg.pack to not have extra padding sizes.
Signed-off-by: hanhanW <[email protected]>
1 parent 9878ef3 commit 59aa079

8 files changed: +102, -133 lines


mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 8 deletions
@@ -4601,8 +4601,8 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
 
 /// Returns true if the dimension of `sourceShape` is smaller than the dimension
 /// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
-                          ArrayRef<int64_t> limitShape) {
+static bool isCompatibleShape(ArrayRef<int64_t> sourceShape,
+                              ArrayRef<int64_t> limitShape) {
   assert(
       sourceShape.size() == limitShape.size() &&
       "expected source shape rank, and limit of the shape to have same rank");
@@ -4611,7 +4611,7 @@ static bool areAllInBound(ArrayRef<int64_t> sourceShape,
         int64_t sourceExtent = std::get<0>(it);
         int64_t limit = std::get<1>(it);
         return ShapedType::isDynamic(sourceExtent) ||
-               ShapedType::isDynamic(limit) || sourceExtent <= limit;
+               ShapedType::isDynamic(limit) || sourceExtent == limit;
       });
 }
 
@@ -4673,11 +4673,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // represents full tiles.
   RankedTensorType expectedPackedType = PackOp::inferPackedType(
       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
-  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
-    return op->emitError("the shape of output is not large enough to hold the "
-                         "packed data. Expected at least ")
-           << expectedPackedType << ", got " << packedType;
-  }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
@@ -4694,6 +4689,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
   }
+  if (!isCompatibleShape(expectedPackedType.getShape(),
+                         packedType.getShape())) {
+    return op->emitError("the shape of output is not large enough to hold the "
+                         "packed data. Expected at least ")
+           << expectedPackedType << ", got " << packedType;
+  }
   return success();
 }
 
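
Note on the verifier change: commonVerifierPackAndUnPackOp now requires the static
outer dimensions of the packed type to match the inferred packed type exactly,
instead of only being large enough to hold the data. A minimal sketch (not part of
the patch; the function name and shapes are illustrative) of a pack that still
verifies under the stricter rule:

func.func @pack_exact_outer_dims(%src: tensor<64x128xf32>,
                                 %dest: tensor<8x4x8x32xf32>) -> tensor<8x4x8x32xf32> {
  // inner_tiles = [8, 32] on a 64x128 source infers outer dims [64/8, 128/32] = [8, 4],
  // so the destination must be exactly tensor<8x4x8x32xf32>.
  %0 = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
      into %dest : tensor<64x128xf32> -> tensor<8x4x8x32xf32>
  return %0 : tensor<8x4x8x32xf32>
}

A destination such as tensor<9x4x8x32xf32>, which carries an extra outer tile of
padding, was accepted by the old areAllInBound check but is rejected now.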

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 29 additions & 0 deletions
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -220,6 +221,34 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
       return failure();
 
+    RankedTensorType srcType = packOp.getSourceType();
+    RankedTensorType destType = packOp.getDestType();
+    SmallVector<int64_t> outerShapeWithoutTranspose(
+        destType.getShape().take_front(srcType.getRank()));
+    if (!packOp.getOuterDimsPerm().empty()) {
+      applyPermutationToVector(
+          outerShapeWithoutTranspose,
+          invertPermutationVector(packOp.getOuterDimsPerm()));
+    }
+    for (auto [pos, tileSize, high] :
+         llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+                         padOp.getMixedHighPad())) {
+      if (srcType.isDynamicDim(pos))
+        return failure();
+      if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+        return failure();
+      if (ShapedType::isDynamic(tileSize))
+        return failure();
+      std::optional<int64_t> cstHigh = getConstantIntValue(high);
+      if (!cstHigh)
+        return failure();
+      int64_t paddingSize =
+          outerShapeWithoutTranspose[pos] * tileSize - srcType.getDimSize(pos);
+      // Do not fold the ops if it requires extra padding sizes.
+      if (paddingSize + cstHigh.value() >= tileSize)
+        return failure();
+    }
+
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
         packOp.getMixedTiles(), constantPaddingValue,
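
Note on the fold restriction: for each tiled dimension the pattern computes
paddingSize = outerShapeWithoutTranspose[pos] * tileSize - srcType.getDimSize(pos)
and bails out when paddingSize + high reaches tileSize, so folding the pad never
introduces an extra padded tile. A minimal sketch (shapes are illustrative, not
taken from the tests) of a pad/pack pair that still folds under the new check:

func.func @fold_small_high_pad(%src: tensor<13x16xf32>) -> tensor<2x1x8x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Dim 0: outer dim 2 * tile 8 - padded size 16 = 0, and 0 + high 3 < 8,
  // so folding the tensor.pad into the linalg.pack needs no extra padded tile.
  %padded = tensor.pad %src low[0, 0] high[3, 0] {
  ^bb0(%arg0: index, %arg1: index):
    tensor.yield %cst : f32
  } : tensor<13x16xf32> to tensor<16x16xf32>
  %empty = tensor.empty() : tensor<2x1x8x32xf32>
  %pack = linalg.pack %padded padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
      : tensor<16x16xf32> -> tensor<2x1x8x32xf32>
  return %pack : tensor<2x1x8x32xf32>
}

With high[11, 0] on a tensor<5x16xf32> source the sum would reach the tile size
(0 + 11 >= 8), so the pattern now returns failure() instead of folding.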

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 13 additions & 12 deletions
@@ -1387,51 +1387,52 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
 // CHECK-LABEL: @recursive_effect
 // CHECK: linalg.map
 
+// -----
+
 //===----------------------------------------------------------------------===//
 // linalg.pack
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @fold_pack_constant_splat
 // CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
   %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @fold_padding_value_pack_constant_splat
 // CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %pad = arith.constant 1.000000e-01 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
     padding_value(%pad : f32)
     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
 // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
 // CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
  %pad = arith.constant 0.0 : f32
  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
  %0 = linalg.pack %cst
    padding_value(%pad : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 32]
-    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 11 additions & 11 deletions
@@ -1295,21 +1295,21 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
 
 // -----
 
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
-  %empty = tensor.empty() : tensor<8x4x16x8xf32>
+func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(%arg0: tensor<32x64xf32>) -> tensor<4x4x16x8xf32> {
+  %empty = tensor.empty() : tensor<4x4x16x8xf32>
   %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-  return %pack : tensor<8x4x16x8xf32>
+  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<4x4x16x8xf32>
+  return %pack : tensor<4x4x16x8xf32>
 }
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
+// CHECK-LABEL: func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x64x8xf32>
+// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG0]]
 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<4x64x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
+// CHECK-SAME: output_shape [4, 4, 16, 8] : tensor<4x64x8xf32> into tensor<4x4x16x8xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x4x16x8xf32>
 
 // -----

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 8 additions & 7 deletions
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
 }
 
 // -----
+
 func.func @pack_mismatch_inner_tile_size_and_output_shape(
   %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
   // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1834,17 +1835,17 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
 
 // -----
 
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
-  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
-  return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+  return %0 : tensor<8x7x16x32xf32>
 }
 
 // -----
 
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
-  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
 }

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 8 additions & 8 deletions
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
-    -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+    -> tensor<265x12x16x1xf32> {
   // CHECK: tensor.pad {{.*}} low[0, 0]
-  // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
+  // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
   // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
-  // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+  // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
   // CHECK: linalg.transpose
-  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
-  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
   // CHECK-SAME: permutation = [0, 2, 1, 3]
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.pack %src
     padding_value(%cst : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [16, 1] into %dest
-    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
-  return %0 : tensor<265x16x16x1xf32>
+    : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+  return %0 : tensor<265x12x16x1xf32>
 }
 
 module attributes {transform.with_named_sequence} {

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 24 additions & 6 deletions
@@ -59,13 +59,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
 
 // -----
 
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[7, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
   %empty = tensor.empty() : tensor<2082x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
@@ -81,10 +81,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
 
 // -----
 
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_extra_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
   } : tensor<16641x16xf32> to tensor<16656x16xf32>
@@ -93,7 +93,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
   return %pack : tensor<2082x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_extra_padding(
+// CHECK: tensor.pad
+// CHECK: linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack(
 // CHECK: tensor.pad
 // CHECK: linalg.pack
