
[mlir][vector][memref] Add alignment attribute to memory access ops #144344


Merged
merged 2 commits into from
Jul 17, 2025

Conversation

tyb0807
Contributor

@tyb0807 tyb0807 commented Jun 16, 2025

Alignment information is important to allow LLVM backends such as AMDGPU
to select wide memory accesses (e.g., dwordx4 or b128). Since this info
is not always inferable, it's better to inform LLVM backends explicitly
about it. Furthermore, alignment is not necessarily a property of the
element type, but of each individual memory access op (we can have
overaligned and underaligned accesses compared to the natural/preferred
alignment of the element type).

This patch introduces an alignment attribute on the memref.load/store and
vector.load/store ops.

Follow-up PRs will:

  1. Propagate the attribute to LLVM/SPIR-V.

  2. Introduce the alignment attribute on other vector memory access ops:
    vector.gather + vector.scatter
    vector.transfer_read + vector.transfer_write
    vector.compressstore + vector.expandload
    vector.maskedload + vector.maskedstore

  3. Replace --convert-vector-to-llvm='use-vector-alignment=1' with a
    simple pass that populates alignment attributes based on the vector
    types.
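
For illustration, the attribute appears directly on the access ops, mirroring the lit tests added below (function and value names here are placeholders):

```mlir
func.func @example(%mem: memref<4xi32>) {
  %c0 = arith.constant 0 : index
  // Scalar access, overaligned to 16 bytes relative to the i32 element.
  %s = memref.load %mem[%c0] { alignment = 16 : i32 } : memref<4xi32>
  // Vector access with explicit 16-byte alignment, enabling wide
  // loads/stores (e.g. dwordx4 / b128) on backends such as AMDGPU.
  %v = vector.load %mem[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
  vector.store %v, %mem[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
  return
}
```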

@llvmbot
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: None (tyb0807)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144344.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+58-3)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+48-2)
  • (added) mlir/test/Dialect/MemRef/load-store-alignment.mlir (+27)
  • (added) mlir/test/Dialect/Vector/load-store-alignment.mlir (+27)
  • (modified) mlir/unittests/Dialect/CMakeLists.txt (+1)
  • (modified) mlir/unittests/Dialect/MemRef/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp (+88)
  • (added) mlir/unittests/Dialect/Vector/CMakeLists.txt (+7)
  • (added) mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp (+95)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 77e3074661abf..160b04e452c5a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1227,7 +1227,45 @@ def LoadOp : MemRef_Op<"load",
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, memref, indices, false, alignment);
+    }]>,
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, memref, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1924,13 +1962,30 @@ def MemRef_StoreOp : MemRef_Op<"store",
                        Arg<AnyMemRef, "the reference to store to",
                            [MemWrite]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
 
   let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
       $_state.addOperands(valueToStore);
       $_state.addOperands(memref);
-    }]>];
+    }]>
+  ];
 
   let extraClassDeclaration = [{
       Value getValueToStore() { return getOperand(0); }
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..3cd71491bcc04 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1739,7 +1739,34 @@ def Vector_LoadOp : Vector_Op<"load"> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
@@ -1825,9 +1852,28 @@ def Vector_StoreOp : Vector_Op<"store"> {
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment
   );
 
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>
+  ];
+
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
diff --git a/mlir/test/Dialect/MemRef/load-store-alignment.mlir b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f5a5461e0ac0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: memref.load {{.*}} {alignment = 16 : i32}
+// CHECK: memref.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = memref.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  memref.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.load' 'memref.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.store' 'memref.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  memref.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/load-store-alignment.mlir b/mlir/test/Dialect/Vector/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f54d989dd190
--- /dev/null
+++ b/mlir/test/Dialect/Vector/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: vector.load {{.*}} {alignment = 16 : i32}
+// CHECK: vector.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = vector.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  vector.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index aea247547473d..34c9fb7317443 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -18,3 +18,4 @@ add_subdirectory(SPIRV)
 add_subdirectory(SMT)
 add_subdirectory(Transform)
 add_subdirectory(Utils)
+add_subdirectory(Vector)
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..87d33854fadcd 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
+  LoadStoreAlignment.cpp
 )
 mlir_target_link_libraries(MLIRMemRefTests
   PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..f0b8e93c2d0e1
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
@@ -0,0 +1,88 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}
diff --git a/mlir/unittests/Dialect/Vector/CMakeLists.txt b/mlir/unittests/Dialect/Vector/CMakeLists.txt
new file mode 100644
index 0000000000000..b23d9c2df3870
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRVectorTests
+  LoadStoreAlignment.cpp
+)
+mlir_target_link_libraries(MLIRVectorTests
+  PRIVATE
+  MLIRVectorDialect
+  )
diff --git a/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..745dd8632fe4d
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
@@ -0,0 +1,95 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}

@tyb0807 tyb0807 requested review from kuhar and ftynse June 16, 2025 12:57
@krzysz00 krzysz00 changed the title Add attribute to MemRef/Vector memory access ops Add alignment attribute to MemRef/Vector memory access ops Jun 16, 2025
Contributor

@krzysz00 krzysz00 left a comment


Needs more documentation and lit tests, doesn't need the unit tests

I'd be OK with updating the LLVM (and SPIR-V, if applicable) lowerings in this patch or in a stacked-on follow-up

@kuhar kuhar changed the title Add alignment attribute to MemRef/Vector memory access ops [mlir][vector][memref] Add alignment attribute to MemRef/Vector memory access ops Jun 16, 2025
@kuhar kuhar changed the title [mlir][vector][memref] Add alignment attribute to MemRef/Vector memory access ops [mlir][vector][memref] Add alignment attribute to memory access ops Jun 16, 2025
Member

@kuhar kuhar left a comment


Thanks for adding this, this will be very useful to have in the IREE codegen.

This also needs llvm/spirv lowering changes and lit tests. We don't need unit tests. See https://mlir.llvm.org/getting_started/TestingGuide/#test-categories

@tyb0807
Contributor Author

tyb0807 commented Jun 16, 2025

Thanks for the review. Actually, I already have all this covered in lit tests. I just wanted to make sure the new builders work as intended. I guess I can just remove the unit tests?

@tyb0807 tyb0807 requested review from krzysz00 and kuhar June 16, 2025 23:54
Member

@matthias-springer matthias-springer left a comment


Is it possible to build an analysis based on memref.assume_alignment instead of adding an alignment attribute to every load/store operation? That's the approach that Triton took (AxisInfo.cpp).

@tyb0807
Contributor Author

tyb0807 commented Jun 17, 2025

Indeed, but I'm not sure we can always infer the alignment of a load/store op solely from the indexing math. In cases where this is not possible, we would need a (less automatic) way to specify this constraint, right?

@matthias-springer
Member

Can you show an example where that would not work?

@banach-space
Contributor

Do you ever expect to need something like this (different alignment for different Ops):

func.func @test_load_store_alignment(%memref: memref<4xi32>) {
  %c0 = arith.constant 0 : index
  %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
  vector.store %val, %memref[%c0] { alignment = 32 } : memref<4xi32>, vector<4xi32>
  return
}

I am just wondering, why do we need to "decorate" every Op with this attribute? And what logic is meant to take care of it? Why couldn't the alignment be a parameter that's passed to e.g. a conversion pass?

@matthias-springer
Copy link
Member

Alignment is a property of the memref SSA value. But we don't encode it in the memref type. We have memref.assume_alignment as a way to attach alignment information to an SSA value. The alignment can then be queried by a dataflow analysis.

There are two alternatives to this approach:

  1. Make the alignment information part of the memref type.
  2. Add attributes to each load/store op. (That's what this PR is doing.)

I'd like to make sure that we have a consistent story for dealing with alignment. Having both memref.assume_alignment and attributes on various ops seems a bit odd...
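
For reference, the existing mechanism attaches the alignment to the SSA value rather than to each access; a sketch (the exact `memref.assume_alignment` syntax varies across MLIR versions, and the names here are placeholders):

```mlir
func.func @via_assume(%mem: memref<?xf32>) {
  %c0 = arith.constant 0 : index
  // Alignment is asserted once on the value...
  %aligned = memref.assume_alignment %mem, 64 : memref<?xf32>
  // ...and a dataflow analysis would propagate it to each access.
  %v = vector.load %aligned[%c0] : memref<?xf32>, vector<4xf32>
  return
}
```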

@kuhar
Member

kuhar commented Jun 17, 2025

Is it possible to build an analysis based on memref.assume_alignment instead of adding an alignment attribute to every load/store operation?

Alignment is a property of the memref SSA value. But we don't encode it in the memref type.

I don't think this is the case. You can have a memref of ?xi8 that doesn't have any inherent static alignment, and the alignment is really a property of each load/store op. You may end up with a memref of bytes as you lower and merge allocations etc. This is also the case with lower level IRs like llvm or spirv, e.g.: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Memory_Operands.
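
A hypothetical example of the situation described: one flat i8 buffer backing several merged allocations, where the buffer type carries no useful static alignment and only each access site knows its own (offsets and names here are illustrative, and assume this PR's attribute):

```mlir
func.func @flat_pool(%pool: memref<?xi8>, %i: index, %j: index) {
  // Same byte buffer, two accesses with different provable alignments.
  %a = vector.load %pool[%i] { alignment = 16 : i32 } : memref<?xi8>, vector<16xi8>
  %b = vector.load %pool[%j] { alignment = 4 : i32 } : memref<?xi8>, vector<4xi8>
  return
}
```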

@matthias-springer
Member

Good point, looks like we may want to have it on the load/store ops.

@ftynse
Member

ftynse commented Jun 18, 2025

You can have a memref of ?xi8 that doesn't have any inherent static alignment and the alignment is really a property at each load/store op.

At that point, maybe you should use !ptr.ptr instead and not a memref. Not necessarily opposing this change, but I don't want to blindly replicate notions from lower-level abstractions like LLVM IR and SPIR-V to a higher-level abstraction.

A stronger argument may be that alignment goes both ways and we can have overaligned and underaligned accesses compared to the natural/preferred alignment of the element type, and those should be reflected somewhere, which is not necessarily a property of the type. Underaligned accesses are more interesting because those may be an optimization hint (aligned accesses are faster) or plainly forbidden by the architecture.

Note that the attribute approach is not precluding a dataflow analysis. We can have an analysis that propagates alignment information to individual operations, e.g., by looking at the structure of subscripts and attributes on previous operations accessing the same value. Attributes can be seen as a way to preserve analysis results.

Good point, looks like we may want to have it on the load/store ops.

Should we also remove memref.assume_alignment? This operation is rather confusing because nothing precludes one from using it repeatedly on the same value and the fact that it is side-effecting (so DCE doesn't remove it) without actually having side effects has been pointed out.

@krzysz00
Contributor

Yeah, I remember when landing --convert-vector-to-llvm='use-vector-alignment=1' I had the sense that it was a temporary solution to work around the fact that we couldn't put alignment on vector loads/stores.

I'd argue for removing it in favor of some flavor of vector-declare-natural-alignment transform over the vector dialect to reduce that same redundancy.

memref.assume_alignment is a rather weird op that, as far as I can tell, exists to allow backends that try to reason about pointer alignment to not stumble over the mysterious pointer-out-of-nowhere that can exist inside memrefs's base.

@banach-space banach-space requested a review from dcaballe June 24, 2025 12:46
@banach-space
Contributor

The alignment attribute is general enough to subsume use-vector-alignment -- we can write a simple pass to populate alignment attributes based on the vector types. So in this sense, we may be able to reduce feature duplication in the future.

I'd argue for removing it in favor of some flavor of vector-declare-natural-alignment transform over the vector dialect to reduce that same redundancy.

+1 Let's make sure that there is a clear TODO to that end.

Btw, there are more Ops that access memory:

  • vector.gather + vector.scatter
  • vector.transfer_read + vector.transfer_write
  • vector.compressstore + vector.expandload

What about these?

@krzysz00
Contributor

Yeah, we probably should get them all - and probably hit them all with nontemporal too while we're here

@electriclilies
Contributor

Hi! Author of #137389 here :)

Yes, it was a workaround-- we were getting perf issues from vectors being aligned to scalar alignments. I think an alignment attribute is a better long term solution, and I'm all for getting rid of --convert-vector-to-llvm='use-vector-alignment=1' as long as we're still able to set vector alignments properly. Also, we do have use cases for setting the alignment higher up in the stack, for example during memory planning.

One concern I have with using memref ops for this is that they exist in so many different places in the stack. I could see this becoming messy-- let's say I have a memory planner which tries to pick some alignments, then everything gets lowered into LLVM IR, and then LLVM passes also try to set alignment. It's not clear to me if we should let lower level passes "override" alignment set by higher level passes, or vice versa. And this gets complicated as we interleave default llvm passes with custom ones. It might be worthwhile to expose an interface to let users override how default passes set alignment without making an upstream change. Just something to think about as you start implementing logic to set the alignments.

Contributor

@krzysz00 krzysz00 left a comment


LGTM aside from the minor test relocation nits

@banach-space
Copy link
Contributor

Yeah, we probably should get them all

Let's do this here, to avoid a split state where some Ops support it while others don't.

@kuhar
Copy link
Member

kuhar commented Jun 26, 2025

Yeah, we probably should get them all

Let's do this here, to avoid a split state where some Ops support it while others don't.

I'd be -1 on this -- I think this PR is already self contained, and similar to the nontemporal support, we can extend higher-level ops separately. Smaller PRs are easier to review, and landing this bottom-up makes sense to me.

@banach-space
Copy link
Contributor

I'd be -1 on this -- I think this PR is already self contained, and similar to the nontemportal support, we can extend higher-level ops separately.

I see this differently.

This is a non-blocker for me, but please make sure that the summary provides a justification for not updating all the Ops. Let's also document the next steps for --convert-vector-to-llvm='use-vector-alignment=1' in the summary. Thanks!

banach-space added a commit that referenced this pull request Jun 27, 2025
…aliases (#145235)

This patch adds additional checks to the hoisting logic to prevent hoisting of
`vector.transfer_read` / `vector.transfer_write` pairs when the underlying
memref has users that introduce aliases via operations implementing
`ViewLikeOpInterface`.

Note: This may conservatively block some valid hoisting opportunities and could
affect performance. However, as demonstrated by the included tests, the current
logic is too permissive and can lead to incorrect transformations.

If this change prevents hoisting in cases that are provably safe, please share
a minimal repro - I'm happy to explore ways to relax the check.

Special treatment is given to `memref.assume_alignment`, mainly to accommodate
recent updates in:

* #139521

Note that such special casing does not scale and should generally be avoided.
The current hoisting logic lacks robust alias analysis. While better support
would require more work, the broader semantics of `memref.assume_alignment`
remain somewhat unclear. It's possible this op may eventually be replaced with
the "alignment" attribute added in:

* #144344
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jun 27, 2025
…resence of aliases (#145235)
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
…aliases (llvm#145235)
@banach-space banach-space left a comment


Thanks for updating the summary, that makes a lot of sense to me 🙏🏻

  1. Retire memref.assume_alignment op.

I'm not quite sure whether we have agreement on this? I do support it though and am happy for you to keep it there.

All in all, LGTM, but I left a few nits to improve our testing "hygiene". Nothing major. I will "approve" once that's addressed.

Thanks for working on this!

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates! I've left a couple of nits, but otherwise LGTM, so approving as is (but please address my comments before merging).

One side note: please note that rebasing and force-pushing are discouraged in LLVM:

While rebase is sometimes unavoidable, please avoid squashing commits in a PR - that obfuscates the history and makes reviewing harder. Thanks!

tyb0807 added 2 commits July 18, 2025 00:06
…store

Alignment information is important to allow LLVM backends such as AMDGPU
to select wide memory accesses (e.g., dwordx4 or b128). Since this info
is not always inferable, it's better to inform LLVM backends explicitly
about it. Furthermore, alignment is not necessarily a property of the
element type, but of each individual memory access op (we can have
overaligned and underaligned accesses compared to the natural/preferred
alignment of the element type).

This patch introduces `alignment` attribute to memref/vector.load/store
ops.

Follow-up PRs will

1. Propagate these attributes to LLVM/SPIR-V.

2. Introduce `alignment` attribute to other vector memory access ops:
    vector.gather + vector.scatter
    vector.transfer_read + vector.transfer_write
    vector.compressstore + vector.expandload
    vector.maskedload + vector.maskedstore

3. Replace `--convert-vector-to-llvm='use-vector-alignment=1'` with a
   simple pass to populate alignment attributes based on the vector
   types.
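As a sketch, the new attribute on the affected ops is expected to look like the following (the exact assembly syntax and operand names here are illustrative, not copied from the landed patch):

```mlir
// Alignment is a property of the individual access, not of the element
// type: e.g. a 16-byte overaligned load next to a 2-byte underaligned
// store of the same f32 elements.
%v = vector.load %base[%c0] { alignment = 16 } : memref<64xf32>, vector<8xf32>
vector.store %v, %base[%c0] { alignment = 2 } : memref<64xf32>, vector<8xf32>
%s = memref.load %base[%c0] { alignment = 16 } : memref<64xf32>
memref.store %s, %base[%c0] { alignment = 4 } : memref<64xf32>
```

When lowered, this attribute can be propagated to the `align` argument of the corresponding LLVM load/store, letting backends like AMDGPU select wide accesses without having to re-infer alignment.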
@kuhar kuhar merged commit aa39785 into llvm:main Jul 17, 2025
9 checks passed
@tyb0807 tyb0807 deleted the alignment branch July 17, 2025 17:39