diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 693f4f955994d..734a8590eedb7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -169,7 +169,7 @@ struct UnrollTransferReadPattern auto sourceVectorType = readOp.getVectorType(); SmallVector strides(targetShape->size(), 1); Location loc = readOp.getLoc(); - ArrayRef originalSize = readOp.getVectorType().getShape(); + ArrayRef originalSize = sourceVectorType.getShape(); // Prepare the result vector; Value result = rewriter.create( @@ -224,6 +224,14 @@ struct UnrollTransferWritePattern SmallVector strides(targetShape->size(), 1); Location loc = writeOp.getLoc(); ArrayRef originalSize = sourceVectorType.getShape(); + // Bail-out if rank(source) != rank(target). The main limitation here is the + // fact that `ExtractStridedSlice` requires the rank for the input and + // output to match. If needed, we can relax this later. + if (originalSize.size() != targetShape->size()) + return rewriter.notifyMatchFailure( + writeOp, + "expected source input vector rank to match target shape rank"); + SmallVector originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector loopOrder = diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir index 5dd65ea132d08..c7025044a0e1b 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -68,6 +68,22 @@ func.func @transfer_write_unroll(%mem : memref<4x4xf32>, %vec : vector<4x4xf32>) // ----- +// Ensure that cases with mismatched target and source +// shape ranks do not lead to a crash. + +// CHECK-LABEL: func @negative_transfer_write +// CHECK-NOT: vector.extract_strided_slice +// CHECK: vector.transfer_write +// CHECK: return +func.func @negative_transfer_write(%vec: vector<6x34x62xi8>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<6x34x62xi8> + vector.transfer_write %vec, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8> + return +} + +// ----- + // CHECK-LABEL: func @transfer_readwrite_unroll // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index