Skip to content

[mlir][vector] Add a check to ensure input vector rank equals target shape rank #149239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Jul 17, 2025

The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an
input vector of higher rank using a target vector of lower rank, which is not supported. Fixes #148368.

@llvmbot
Copy link
Member

llvmbot commented Jul 17, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an
input vector of higher rank using a target vector of lower rank, which is not supported. Fixed #148368.


Full diff: https://github.com/llvm/llvm-project/pull/149239.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+15-1)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+22)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 693f4f955994d..be911901c2afc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -169,7 +169,13 @@ struct UnrollTransferReadPattern
     auto sourceVectorType = readOp.getVectorType();
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = readOp.getLoc();
-    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
+    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `InsertStridedSliceOp` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          readOp, "expected source vector rank to match target shape rank");
 
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +230,14 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `ExtractStridedSlice` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
+    if (originalSize.size() != targetShape->size())
+      return rewriter.notifyMatchFailure(
+          writeOp,
+          "expected source input vector rank to match target shape rank");
+
     SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                        writeOp.getIndices().end());
     SmallVector<int64_t> loopOrder =
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..8f6945468feb3 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,25 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
   // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
   // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
   // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+func.func @vector_transfer_read(%arg0: memref<6x34x62xi8>) -> vector<6x34x62xi8> {
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %c0_i8 : memref<6x34x62xi8>, vector<6x34x62xi8>
+  return %0 : vector<6x34x62xi8>
+}
+// CHECK-LABEL: func @vector_transfer_read
+//   CHECK-NOT: vector.intert_strided_slice
+//       CHECK: vector.transfer_read
+//       CHECK: return
+
+func.func @vector_transfer_write(%arg0: vector<6x34x62xi8>) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<6x34x62xi8>
+  vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
+  return
+}
+// CHECK-LABEL: func @vector_transfer_write
+//   CHECK-NOT: vector.extract_strided_slice
+//       CHECK: vector.transfer_write
+//       CHECK: return

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for providing a fix!

@@ -420,3 +420,25 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a file vector-transfer-unroll.mlir which might be a better place for these tests.

Please add comments above the tests describing what is being tested.

Such tests generally are prefixed with negative_

ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
// Bail-out if rank(source) != rank(target). The main limitation here is the
// fact that `InsertStridedSliceOp` requires the rank for the input and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

…shape rank

The crash is caused because, during IR transformation, the
vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an
input vector of higher rank using a target vector of lower rank, which
is not supported.
@newling
Copy link
Contributor

newling commented Jul 23, 2025

Hi @CoTinker thanks for changes. Just to note: in general it is preferable to not squash commits, it is easier for reviewers to follow what's happening if the commits are kept. When PR's are merged, the commits get squashed into 1 by default, but in the review process it's nice to see the incremental changes. Just for future reference. Thanks!

@CoTinker
Copy link
Contributor Author

Hi @CoTinker thanks for changes. Just to note: in general it is preferable to not squash commits, it is easier for reviewers to follow what's happening if the commits are kept. When PR's are merged, the commits get squashed into 1 by default, but in the review process it's nice to see the incremental changes. Just for future reference. Thanks!

I see, thanks for your advice.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG % request for a comment in test. Thanks!

@newling, are you OK with this change?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
5 participants