[mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull #151485
Conversation
This patch enables (de)serialization to/from OpConstantNull for null TensorARM.
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
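For context, a minimal sketch of the round trip this change targets, assuming the public mlir::spirv::serialize and mlir::spirv::deserialize entry points from mlir/Target/SPIRV keep their documented signatures; everything below is illustrative and not part of the patch.

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "llvm/ADT/SmallVector.h"

// Serializes a spirv.module to a SPIR-V binary and deserializes it back.
// With this patch, a splat-zero !spirv.arm.tensor constant inside `module`
// is encoded as OpConstantNull and still comes back as an equivalent
// spirv.Constant after deserialization.
mlir::OwningOpRef<mlir::spirv::ModuleOp>
roundTrip(mlir::spirv::ModuleOp module, mlir::MLIRContext &context) {
  llvm::SmallVector<uint32_t, 0> binary;
  if (mlir::spirv::serialize(module, binary).failed())
    return {};
  return mlir::spirv::deserialize(binary, &context);
}

The new functions added to mlir/test/Target/SPIRV/constant.mlir in the diff exercise exactly this round trip.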
@llvm/pr-subscribers-mlir
Author: Mohammadreza Ameri Mahabadian (mahabadm)
Changes: This patch enables (de)serialization to/from OpConstantNull for null TensorARM.
Full diff: https://github.com/llvm/llvm-project/pull/151485.diff
3 Files Affected:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88931b53a6889..333046a8e5d6f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1779,7 +1779,7 @@ LogicalResult
spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc,
- "OpConstantNull must have type <id> and result <id>");
+ "OpConstantNull must only have type <id> and result <id>");
}
Type resultType = getType(operands[0]);
@@ -1789,8 +1789,17 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
+ Attribute attr;
if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
- auto attr = opBuilder.getZeroAttr(resultType);
+ attr = opBuilder.getZeroAttr(resultType);
+ } else if (isa<TensorArmType>(resultType)) {
+ auto shapedType = cast<ShapedType>(resultType);
+ auto element = opBuilder.getZeroAttr(shapedType.getElementType());
+ if (element)
+ attr = DenseElementsAttr::get(shapedType, element);
+ }
+
+ if (attr) {
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
constantMap.try_emplace(resultID, attr, resultType);
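As a standalone sketch (not part of the diff) of the attribute construction the deserializer now performs for OpConstantNull over a TensorARM type; a builtin RankedTensorType stands in for the shaped type purely for illustration.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

// Builds the splat-zero attribute recorded in constantMap for a null tensor
// constant: a zero of the element type, broadcast over the whole shape.
mlir::Attribute makeSplatZeroAttr(mlir::OpBuilder &builder,
                                  mlir::ShapedType shapedType) {
  // getZeroAttr returns a zero-valued attribute for int/float element types
  // and a null Attribute for element types it cannot handle.
  mlir::Attribute element = builder.getZeroAttr(shapedType.getElementType());
  if (!element)
    return {};
  // A single element value yields a splat DenseElementsAttr over the shape.
  return mlir::DenseElementsAttr::get(shapedType, element);
}

For example, calling this with RankedTensorType::get({2, 3}, builder.getI32Type()) produces dense<0> : tensor<2x3xi32>, the same splat form that the serializer's isNull helper below recognizes as a null constant.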
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 737f29662f64b..3ef9a89a3ca62 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,22 @@ static Block *getPhiIncomingBlock(Block *block) {
return block;
}
+static bool isNull(Attribute attr) {
+ if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+ return floatAttr.getValue().isZero();
+ }
+ if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+ return !boolAttr.getValue();
+ }
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ return intAttr.getValue().isZero();
+ }
+ if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+ return all_of(denseElemAttr.getValues<Attribute>(), isNull);
+ }
+ return false;
+}
+
namespace mlir {
namespace spirv {
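A small illustration of how the new isNull helper classifies attributes, assuming it is called from the same translation unit; the types and builders used below are only for the example.

#include <cassert>
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

void checkIsNullClassification(mlir::MLIRContext &context) {
  mlir::OpBuilder builder(&context);
  auto tensorType = mlir::RankedTensorType::get({2, 3}, builder.getI32Type());

  // A splat of zeros: every element is a zero IntegerAttr, so isNull returns
  // true and the serializer may emit OpConstantNull for it.
  mlir::Attribute zeroElement = builder.getZeroAttr(builder.getI32Type());
  auto zeros = mlir::DenseElementsAttr::get(tensorType, zeroElement);
  assert(isNull(zeros));

  // A splat of ones contains a non-zero element, so it is not null and keeps
  // the regular constant encoding.
  mlir::Attribute oneElement = builder.getI32IntegerAttr(1);
  auto ones = mlir::DenseElementsAttr::get(tensorType, oneElement);
  assert(!isNull(ones));
}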
@@ -959,6 +975,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
return 0;
}
} else if (isa<spirv::TensorArmType>(constType)) {
+ if (isNull(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
+ }
numberOfConstituents = shapedType.getNumElements();
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
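For reference, the OpConstantNull instruction emitted here is tiny: three 32-bit words, with the word count and opcode packed into the first word. A freestanding sketch of that encoding follows; the helper name and standalone form are illustrative, since the serializer uses its own encodeInstructionInto utility.

#include <cstdint>
#include <vector>

// Encodes OpConstantNull as raw SPIR-V words: word 0 holds the word count (3)
// in its high 16 bits and the opcode (OpConstantNull = 46 in the SPIR-V spec)
// in its low 16 bits, followed by the type <id> and result <id>.
std::vector<uint32_t> encodeOpConstantNull(uint32_t typeID, uint32_t resultID) {
  constexpr uint32_t kWordCount = 3;
  constexpr uint32_t kOpConstantNull = 46;
  return {(kWordCount << 16) | kOpConstantNull, typeID, resultID};
}

Per the SPIR-V specification, OpConstantNull of a composite type means every constituent is the null value of its type, which is why a splat-zero TensorARM constant is a legal candidate for this form.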
@@ -1202,11 +1223,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
}
uint32_t resultID = getNextID();
- uint32_t operands[] = {typeID, resultID, constandID};
-
- encodeInstructionInto(typesGlobalValues,
- spirv::Opcode::OpConstantCompositeReplicateEXT,
- operands);
+ if (dyn_cast<spirv::TensorArmType>(resultType) && isNull(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ } else {
+ encodeInstructionInto(typesGlobalValues,
+ spirv::Opcode::OpConstantCompositeReplicateEXT,
+ {typeID, resultID, constandID});
+ }
constCompositeReplicateIDMap[valueTypePair] = resultID;
return resultID;
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 1695d2a6a2eb4..3be49eefcaebf 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+ // CHECK-LABEL: @null_arm_tensor_of_i32
+ spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @null_arm_tensor_of_f32
+ spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
spirv.EntryPoint "GLCompute" @bool_const
}
@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // CHECK-LABEL: @null_cc_arm_tensor_of_i32
+ spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
// CHECK-LABEL: @splat_vector_f32
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+
+ // CHECK-LABEL: @null_cc_arm_tensor_of_f32
+ spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
}
@llvm/pr-subscribers-mlir-spirv
Author: Mohammadreza Ameri Mahabadian (mahabadm)
Changes: This patch enables (de)serialization to/from OpConstantNull for null TensorARM.
Full diff: https://github.com/llvm/llvm-project/pull/151485.diff
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@kuhar Thanks for your approval. I'd appreciate it if you could merge the PR, as I don't have permission to merge.
Fix forward a few issues instead of reverting recent PRs:
1. Remove split so that the output is properly serialized. This got added in #145687.
2. Add missing extensions and capabilities so that the spirv-val passes (tensors_arm, linkage).
3. Disable spirv-val test for arm tensor constants. These fail to verify. Added in #151485.
Issue: #152012
[mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull (llvm#151485)
This patch enables (de)serialization to/from OpConstantNull for null TensorARM.
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
[mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull
This patch enables (de)serialization to/from OpConstantNull for null TensorARM.