Skip to content

Commit 5adae03

Browse files
committed
[mlir][memref][vector] Add alignment attribute to memref/vector.load/store
Alignment information is important to allow LLVM backends such as AMDGPU to select wide memory accesses (e.g., dwordx4 or b128). Since this info is not always inferable, it's better to inform LLVM backends explicitly about it. Furthermore, alignment is not necessarily a property of the element type, but of each individual memory access op (we can have overaligned and underaligned accesses compared to the natural/preferred alignment of the element type). This patch introduces `alignment` attribute to memref/vector.load/store ops. Follow-up PRs will 1. Introduce `alignment` attribute to other vector memory access ops: vector.gather + vector.scatter vector.transfer_read + vector.transfer_write vector.compressstore + vector.expandload vector.maskedload + vector.maskedstore 2. Propagate these attributes to LLVM/SPIR-V. 3. Replace `--convert-vector-to-llvm='use-vector-alignment=1` with a simple pass to populate alignment attributes based on the vector types. 4. Retire `memref.assume_alignment` op.
1 parent 6a94814 commit 5adae03

File tree

8 files changed

+173
-6
lines changed

8 files changed

+173
-6
lines changed

mlir/docs/DefiningDialects/Operations.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ Right now, the following primitive constraints are supported:
306306
* `IntPositive`: Specifying an integer attribute whose value is positive
307307
* `IntNonNegative`: Specifying an integer attribute whose value is
308308
non-negative
309+
* `IntPowerOf2`: Specifying an integer attribute whose value is a power of
310+
two > 0
309311
* `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
310312
elements
311313
* `ArrayMaxCount<N>`: Specifying an array attribute to have at most `N`

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,11 @@ def LoadOp : MemRef_Op<"load",
12161216
be reused in the cache. For details, refer to the
12171217
[https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).
12181218

1219+
An optional `alignment` attribute allows to specify the byte alignment of the
1220+
load operation. It must be a positive power of 2. The operation must access
1221+
memory at an address aligned to this boundary. Violations may lead to
1222+
architecture-specific faults or performance penalties.
1223+
A value of 0 indicates no specific alignment requirement.
12191224
Example:
12201225

12211226
```mlir
@@ -1226,7 +1231,39 @@ def LoadOp : MemRef_Op<"load",
12261231
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
12271232
[MemRead]>:$memref,
12281233
Variadic<Index>:$indices,
1229-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1234+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1235+
ConfinedAttr<OptionalAttr<I64Attr>,
1236+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
1237+
1238+
let builders = [
1239+
OpBuilder<(ins "Value":$memref,
1240+
"ValueRange":$indices,
1241+
CArg<"bool", "false">:$nontemporal,
1242+
CArg<"uint64_t", "0">:$alignment), [{
1243+
return build($_builder, $_state, memref, indices, nontemporal,
1244+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1245+
nullptr);
1246+
}]>,
1247+
OpBuilder<(ins "Type":$resultType,
1248+
"Value":$memref,
1249+
"ValueRange":$indices,
1250+
CArg<"bool", "false">:$nontemporal,
1251+
CArg<"uint64_t", "0">:$alignment), [{
1252+
return build($_builder, $_state, resultType, memref, indices, nontemporal,
1253+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1254+
nullptr);
1255+
}]>,
1256+
OpBuilder<(ins "TypeRange":$resultTypes,
1257+
"Value":$memref,
1258+
"ValueRange":$indices,
1259+
CArg<"bool", "false">:$nontemporal,
1260+
CArg<"uint64_t", "0">:$alignment), [{
1261+
return build($_builder, $_state, resultTypes, memref, indices, nontemporal,
1262+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1263+
nullptr);
1264+
}]>
1265+
];
1266+
12301267
let results = (outs AnyType:$result);
12311268

12321269
let extraClassDeclaration = [{
@@ -1912,6 +1949,11 @@ def MemRef_StoreOp : MemRef_Op<"store",
19121949
be reused in the cache. For details, refer to the
19131950
[https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction).
19141951

1952+
An optional `alignment` attribute allows to specify the byte alignment of the
1953+
store operation. It must be a positive power of 2. The operation must access
1954+
memory at an address aligned to this boundary. Violations may lead to
1955+
architecture-specific faults or performance penalties.
1956+
A value of 0 indicates no specific alignment requirement.
19151957
Example:
19161958

19171959
```mlir
@@ -1923,13 +1965,25 @@ def MemRef_StoreOp : MemRef_Op<"store",
19231965
Arg<AnyMemRef, "the reference to store to",
19241966
[MemWrite]>:$memref,
19251967
Variadic<Index>:$indices,
1926-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1968+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1969+
ConfinedAttr<OptionalAttr<I64Attr>,
1970+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
19271971

19281972
let builders = [
1973+
OpBuilder<(ins "Value":$valueToStore,
1974+
"Value":$memref,
1975+
"ValueRange":$indices,
1976+
CArg<"bool", "false">:$nontemporal,
1977+
CArg<"uint64_t", "0">:$alignment), [{
1978+
return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
1979+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1980+
nullptr);
1981+
}]>,
19291982
OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
19301983
$_state.addOperands(valueToStore);
19311984
$_state.addOperands(memref);
1932-
}]>];
1985+
}]>
1986+
];
19331987

19341988
let extraClassDeclaration = [{
19351989
Value getValueToStore() { return getOperand(0); }

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,12 +1809,42 @@ def Vector_LoadOp : Vector_Op<"load", [
18091809
```mlir
18101810
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
18111811
```
1812+
1813+
An optional `alignment` attribute allows to specify the byte alignment of the
1814+
load operation. It must be a positive power of 2. The operation must access
1815+
memory at an address aligned to this boundary. Violations may lead to
1816+
architecture-specific faults or performance penalties.
1817+
A value of 0 indicates no specific alignment requirement.
18121818
}];
18131819

18141820
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
18151821
[MemRead]>:$base,
18161822
Variadic<Index>:$indices,
1817-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1823+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1824+
ConfinedAttr<OptionalAttr<I64Attr>,
1825+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
1826+
1827+
let builders = [
1828+
OpBuilder<(ins "VectorType":$resultType,
1829+
"Value":$base,
1830+
"ValueRange":$indices,
1831+
CArg<"bool", "false">:$nontemporal,
1832+
CArg<"uint64_t", "0">:$alignment), [{
1833+
return build($_builder, $_state, resultType, base, indices, nontemporal,
1834+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1835+
nullptr);
1836+
}]>,
1837+
OpBuilder<(ins "TypeRange":$resultTypes,
1838+
"Value":$base,
1839+
"ValueRange":$indices,
1840+
CArg<"bool", "false">:$nontemporal,
1841+
CArg<"uint64_t", "0">:$alignment), [{
1842+
return build($_builder, $_state, resultTypes, base, indices, nontemporal,
1843+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1844+
nullptr);
1845+
}]>
1846+
];
1847+
18181848
let results = (outs AnyVectorOfAnyRank:$result);
18191849

18201850
let extraClassDeclaration = [{
@@ -1895,15 +1925,34 @@ def Vector_StoreOp : Vector_Op<"store", [
18951925
```mlir
18961926
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
18971927
```
1928+
1929+
An optional `alignment` attribute allows to specify the byte alignment of the
1930+
store operation. It must be a positive power of 2. The operation must access
1931+
memory at an address aligned to this boundary. Violations may lead to
1932+
architecture-specific faults or performance penalties.
1933+
A value of 0 indicates no specific alignment requirement.
18981934
}];
18991935

19001936
let arguments = (ins
19011937
AnyVectorOfAnyRank:$valueToStore,
19021938
Arg<AnyMemRef, "the reference to store to",
19031939
[MemWrite]>:$base,
19041940
Variadic<Index>:$indices,
1905-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
1906-
);
1941+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1942+
ConfinedAttr<OptionalAttr<I64Attr>,
1943+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
1944+
1945+
let builders = [
1946+
OpBuilder<(ins "Value":$valueToStore,
1947+
"Value":$base,
1948+
"ValueRange":$indices,
1949+
CArg<"bool", "false">:$nontemporal,
1950+
CArg<"uint64_t", "0">:$alignment), [{
1951+
return build($_builder, $_state, valueToStore, base, indices, nontemporal,
1952+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1953+
nullptr);
1954+
}]>
1955+
];
19071956

19081957
let extraClassDeclaration = [{
19091958
MemRefType getMemRefType() {

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,10 @@ def IntPositive : AttrConstraint<
796796
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">,
797797
"whose value is positive">;
798798

799+
def IntPowerOf2 : AttrConstraint<
800+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
801+
"whose value is a power of two > 0">;
802+
799803
class ArrayMaxCount<int n> : AttrConstraint<
800804
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
801805
"with at most " # n # " elements">;

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,3 +1139,21 @@ func.func @expand_shape_invalid_output_shape(
11391139
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
11401140
return
11411141
}
1142+
1143+
// -----
1144+
1145+
func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
1146+
%c0 = arith.constant 0 : index
1147+
// expected-error @below {{'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1148+
%val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
1149+
return
1150+
}
1151+
1152+
// -----
1153+
1154+
func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: i32) {
1155+
%c0 = arith.constant 0 : index
1156+
// expected-error @below {{'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1157+
memref.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>
1158+
return
1159+
}

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,15 @@ func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affin
613613
%dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
614614
return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
615615
}
616+
617+
// -----
618+
619+
// CHECK-LABEL: func @test_load_store_alignment
620+
// CHECK: memref.load {{.*}} {alignment = 16 : i64}
621+
// CHECK: memref.store {{.*}} {alignment = 16 : i64}
622+
func.func @test_load_store_alignment(%memref: memref<4xi32>) {
623+
%c0 = arith.constant 0 : index
624+
%val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32>
625+
memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>
626+
return
627+
}

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,3 +2005,21 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
20052005
vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
20062006
return
20072007
}
2008+
2009+
// -----
2010+
2011+
func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
2012+
%c0 = arith.constant 0 : index
2013+
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
2014+
%val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
2015+
return
2016+
}
2017+
2018+
// -----
2019+
2020+
func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
2021+
%c0 = arith.constant 0 : index
2022+
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
2023+
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
2024+
return
2025+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,3 +1218,13 @@ func.func @step() {
12181218
%1 = vector.step : vector<[4]xindex>
12191219
return
12201220
}
1221+
1222+
// CHECK-LABEL: func @test_load_store_alignment
1223+
func.func @test_load_store_alignment(%memref: memref<4xi32>) {
1224+
%c0 = arith.constant 0 : index
1225+
// CHECK: vector.load {{.*}} {alignment = 16 : i64}
1226+
%val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
1227+
// CHECK: vector.store {{.*}} {alignment = 16 : i64}
1228+
vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
1229+
return
1230+
}

0 commit comments

Comments
 (0)