Skip to content

[mlir][linalg] Restrict linalg.pack to not have artificial padding. #149624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also update:

     - The following relationship for the tiled dimensions holds:
     `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.

as

     - The following relationship for the tiled dimensions holds:
     `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] ⌈/⌉ inner_tiles[i]` (⌈/⌉ - CeilDiv).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, updated. Thanks for pointing it out!

Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
result tensor in the order in which they appear, i.e.
`shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`.
- The following relationship for the tiled dimensions holds:
`shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
`shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`,
where (⌈/⌉ indicates CeilDiv).


Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
`...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
Expand Down Expand Up @@ -150,9 +152,17 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [

`padding_value` specifies a padding value at the boundary on non-perfectly
divisible dimensions. Padding is optional:
- If absent, it is UB if the tile does not perfectly divide the dimension.
- If absent, it is assumed that for all inner tiles,
`shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner
tiles divide perfectly the corresponding outer dimension in the result
tensor. It is UB if the tile does not perfectly divide the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
tile complete.
tile complete. Note that it is not allowed to have artificial padding that
is not strictly required by linalg.pack (i.e., padding past what is needed
to complete the last tile along each packed dimension). It is UB if extra
padding is requested.
Comment on lines +162 to +163
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is UB if extra padding is requested.

Shouldn't that be verification error? And then restore UB for the previous point, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't that be verification error?

It's not possible to enforce that with dynamic source.
UB is more of "catch all" here and allows linalg::lowerPack to remain as is.

restore UB for the previous point

It could remain there to reinforce the message. But no strong preference here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enforcing that is more like a dynamic check, but not a static check. I.e., it can only be asserted during runtime for most cases, IMO. E.g., you don't do any out-of-bound checks for tensor.extract_slice for dynamic sizes/dims, but you do out-of-bound checks when the op has static sizes and offsets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restored the UB for the previous point, because it also happens in dynamic shape that I forgot. I also added one more sentence about why they are UB. Does it look better?

It is not possible to verify the requirements statically with dynamic
shapes, so they are treated as UB.

Example:
```mlir
Expand All @@ -167,6 +177,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
//
// Note: Only tiled dimensions can be padded.
```

Invalid example that has artificial padding:
```mlir
%0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
inner_tiles = [8] into %dest
: tensor<9xf32> -> tensor<3x8xf32>
// \
// expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
Expand Down
28 changes: 7 additions & 21 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

Expand Down Expand Up @@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}

/// Returns true if the dimension of `sourceShape` is smaller than the dimension
/// of the `limitShape`.
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> limitShape) {
assert(
sourceShape.size() == limitShape.size() &&
"expected source shape rank, and limit of the shape to have same rank");
return llvm::all_of(
llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
int64_t sourceExtent = std::get<0>(it);
int64_t limit = std::get<1>(it);
return ShapedType::isDynamic(sourceExtent) ||
ShapedType::isDynamic(limit) || sourceExtent <= limit;
});
}

template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
Expand Down Expand Up @@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
return op->emitError("the shape of output is not large enough to hold the "
"packed data. Expected at least ")
<< expectedPackedType << ", got " << packedType;
}
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
Expand All @@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
packedType.getShape()))) {
return op->emitError("expected ")
<< expectedPackedType << " for the packed domain value, got "
<< packedType;
}
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
Expand Down
25 changes: 13 additions & 12 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1387,51 +1387,52 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
// CHECK-LABEL: @recursive_effect
// CHECK: linalg.map

// -----

//===----------------------------------------------------------------------===//
// linalg.pack
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @fold_pack_constant_splat
// CHECK-NOT: linalg.pack
// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
%0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
return %0 : tensor<8x16x8x32xf32>
inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
return %0 : tensor<4x8x8x32xf32>
}

// -----

// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
// CHECK-NOT: linalg.pack
// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 1.000000e-01 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
padding_value(%pad : f32)
outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
return %0 : tensor<8x16x8x32xf32>
inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
return %0 : tensor<4x8x8x32xf32>
}


// -----

// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
// CHECK: linalg.pack
func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 0.0 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
padding_value(%pad : f32)
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [8, 32]
into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
return %0 : tensor<8x16x8x32xf32>
into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
return %0 : tensor<4x8x8x32xf32>
}

// -----
Expand Down
18 changes: 0 additions & 18 deletions mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(

// -----

func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
%empty = tensor.empty() : tensor<8x4x16x8xf32>
%expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
%pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
return %pack : tensor<8x4x16x8xf32>
}
// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>

// -----

func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
Expand Down
37 changes: 29 additions & 8 deletions mlir/test/Dialect/Linalg/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
}

// -----

func.func @pack_mismatch_inner_tile_size_and_output_shape(
%input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
// expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
Expand Down Expand Up @@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t

// -----

func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
%cst = arith.constant 0.0 : f32
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
%0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
inner_tiles = [8] into %output
: tensor<9xf32> -> tensor<3x8xf32>
return %0 : tensor<3x8xf32>
}

// -----

// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
// expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
// expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}

// -----

func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
// expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
return %0 : tensor<8x8x32x16xf32>
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
// expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
return %0 : tensor<8x7x16x32xf32>
}

// -----

func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
%0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
: tensor<3x8xf32> -> tensor<9xf32>
return %0 : tensor<9xf32>
}

// -----

func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
// expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
// expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
// -----

// CHECK-LABEL: func.func @pack_with_pad(
func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
-> tensor<265x16x16x1xf32> {
func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
-> tensor<265x12x16x1xf32> {
// CHECK: tensor.pad {{.*}} low[0, 0]
// CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
// CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
// CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
// CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
// CHECK: linalg.transpose
// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
// CHECK-SAME: permutation = [0, 2, 1, 3]
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.pack %src
padding_value(%cst : f32)
inner_dims_pos = [0, 1]
inner_tiles = [16, 1] into %dest
: tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
return %0 : tensor<265x16x16x1xf32>
: tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
return %0 : tensor<265x12x16x1xf32>
}

module attributes {transform.with_named_sequence} {
Expand Down
81 changes: 0 additions & 81 deletions mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -646,87 +646,6 @@ module attributes {transform.with_named_sequence} {

// -----

// It is valid to fuse the pack if the dimension is not tiled even when it needs
// extra padding.

func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
%0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
%src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
%dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
%2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
}
}
%1 = tensor.empty() : tensor<33x2x3x16xf32>
%cst = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
return %pack : tensor<33x2x3x16xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
// CHECK: %[[ELEM:.*]] = linalg.exp
// CHECK-SAME: ins(%[[ELEM_SRC]]
// CHECK-SAME: outs(%[[ELEM_DEST]]
// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
// CHECK-SAME: into %[[TILED_PACK_DEST]]
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]

// -----

// If the dimension is tiled and it needs extra padding, do not fuse the pack
// op.

func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
%0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
%src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
%dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
%2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
scf.forall.in_parallel {
// expected-error @below {{failed to fuse consumer of slice}}
tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
}
}
%1 = tensor.empty() : tensor<23x32x3x16xf32>
%cst = arith.constant 0.000000e+00 : f32
%pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
return %pack : tensor<23x32x3x16xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

// Imperfect tiling is not supported in pack op consumer fusion.

#map = affine_map<(d0) -> (d0 * 5)>
Expand Down
Loading