Skip to content

[mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstan… #151485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 1, 2025

Conversation

mahabadm
Copy link
Contributor

…tNull

This patch enables (de)serialization to/from OpConstantNull for null TensorARM

…tNull

This patch enables (de)serialization to/from OpConstantNull for null TensorARM

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Jul 31, 2025

@llvm/pr-subscribers-mlir

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

…tNull

This patch enables (de)serialization to/from OpConstantNull for null TensorARM


Full diff: https://github.com/llvm/llvm-project/pull/151485.diff

3 Files Affected:

  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+11-2)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+29-5)
  • (modified) mlir/test/Target/SPIRV/constant.mlir (+28)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88931b53a6889..333046a8e5d6f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1779,7 +1779,7 @@ LogicalResult
 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
-                     "OpConstantNull must have type <id> and result <id>");
+                     "OpConstantNull must only have type <id> and result <id>");
   }
 
   Type resultType = getType(operands[0]);
@@ -1789,8 +1789,17 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
+  Attribute attr;
   if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
-    auto attr = opBuilder.getZeroAttr(resultType);
+    attr = opBuilder.getZeroAttr(resultType);
+  } else if (isa<TensorArmType>(resultType)) {
+    auto shapedType = cast<ShapedType>(resultType);
+    auto element = opBuilder.getZeroAttr(shapedType.getElementType());
+    if (element)
+      attr = DenseElementsAttr::get(shapedType, element);
+  }
+
+  if (attr) {
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
     constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 737f29662f64b..3ef9a89a3ca62 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,22 @@ static Block *getPhiIncomingBlock(Block *block) {
   return block;
 }
 
+static bool isNull(Attribute attr) {
+  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+    return floatAttr.getValue().isZero();
+  }
+  if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+    return !boolAttr.getValue();
+  }
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+    return intAttr.getValue().isZero();
+  }
+  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+    return all_of(denseElemAttr.getValues<Attribute>(), isNull);
+  }
+  return false;
+}
+
 namespace mlir {
 namespace spirv {
 
@@ -959,6 +975,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
       return 0;
     }
   } else if (isa<spirv::TensorArmType>(constType)) {
+    if (isNull(valueAttr)) {
+      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                            {typeID, resultID});
+      return resultID;
+    }
     numberOfConstituents = shapedType.getNumElements();
     operands.reserve(numberOfConstituents + 2);
     for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1223,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
   }
 
   uint32_t resultID = getNextID();
-  uint32_t operands[] = {typeID, resultID, constandID};
-
-  encodeInstructionInto(typesGlobalValues,
-                        spirv::Opcode::OpConstantCompositeReplicateEXT,
-                        operands);
+  if (dyn_cast<spirv::TensorArmType>(resultType) && isNull(valueAttr)) {
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                          {typeID, resultID});
+  } else {
+    encodeInstructionInto(typesGlobalValues,
+                          spirv::Opcode::OpConstantCompositeReplicateEXT,
+                          {typeID, resultID, constandID});
+  }
 
   constCompositeReplicateIDMap[valueTypePair] = resultID;
   return resultID;
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 1695d2a6a2eb4..3be49eefcaebf 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @null_arm_tensor_of_i32
+  spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
+  // CHECK-LABEL: @null_arm_tensor_of_f32
+  spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+
   spirv.EntryPoint "GLCompute" @bool_const
 }
 
@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @null_cc_arm_tensor_of_i32
+  spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
   // CHECK-LABEL: @splat_vector_f32
   spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
+
+  // CHECK-LABEL: @null_cc_arm_tensor_of_f32
+  spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
 }

@llvmbot
Copy link
Member

llvmbot commented Jul 31, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

…tNull

This patch enables (de)serialization to/from OpConstantNull for null TensorARM


Full diff: https://github.com/llvm/llvm-project/pull/151485.diff

3 Files Affected:

  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+11-2)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+29-5)
  • (modified) mlir/test/Target/SPIRV/constant.mlir (+28)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88931b53a6889..333046a8e5d6f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1779,7 +1779,7 @@ LogicalResult
 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
-                     "OpConstantNull must have type <id> and result <id>");
+                     "OpConstantNull must only have type <id> and result <id>");
   }
 
   Type resultType = getType(operands[0]);
@@ -1789,8 +1789,17 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
+  Attribute attr;
   if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
-    auto attr = opBuilder.getZeroAttr(resultType);
+    attr = opBuilder.getZeroAttr(resultType);
+  } else if (isa<TensorArmType>(resultType)) {
+    auto shapedType = cast<ShapedType>(resultType);
+    auto element = opBuilder.getZeroAttr(shapedType.getElementType());
+    if (element)
+      attr = DenseElementsAttr::get(shapedType, element);
+  }
+
+  if (attr) {
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
     constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 737f29662f64b..3ef9a89a3ca62 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,22 @@ static Block *getPhiIncomingBlock(Block *block) {
   return block;
 }
 
+static bool isNull(Attribute attr) {
+  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+    return floatAttr.getValue().isZero();
+  }
+  if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+    return !boolAttr.getValue();
+  }
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+    return intAttr.getValue().isZero();
+  }
+  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+    return all_of(denseElemAttr.getValues<Attribute>(), isNull);
+  }
+  return false;
+}
+
 namespace mlir {
 namespace spirv {
 
@@ -959,6 +975,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
       return 0;
     }
   } else if (isa<spirv::TensorArmType>(constType)) {
+    if (isNull(valueAttr)) {
+      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                            {typeID, resultID});
+      return resultID;
+    }
     numberOfConstituents = shapedType.getNumElements();
     operands.reserve(numberOfConstituents + 2);
     for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1223,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
   }
 
   uint32_t resultID = getNextID();
-  uint32_t operands[] = {typeID, resultID, constandID};
-
-  encodeInstructionInto(typesGlobalValues,
-                        spirv::Opcode::OpConstantCompositeReplicateEXT,
-                        operands);
+  if (dyn_cast<spirv::TensorArmType>(resultType) && isNull(valueAttr)) {
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                          {typeID, resultID});
+  } else {
+    encodeInstructionInto(typesGlobalValues,
+                          spirv::Opcode::OpConstantCompositeReplicateEXT,
+                          {typeID, resultID, constandID});
+  }
 
   constCompositeReplicateIDMap[valueTypePair] = resultID;
   return resultID;
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 1695d2a6a2eb4..3be49eefcaebf 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @null_arm_tensor_of_i32
+  spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
+  // CHECK-LABEL: @null_arm_tensor_of_f32
+  spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+
   spirv.EntryPoint "GLCompute" @bool_const
 }
 
@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @null_cc_arm_tensor_of_i32
+  spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
   // CHECK-LABEL: @splat_vector_f32
   spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
+
+  // CHECK-LABEL: @null_cc_arm_tensor_of_f32
+  spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
 }

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@mahabadm
Copy link
Contributor Author

mahabadm commented Aug 1, 2025

@kuhar Thanks for your approval. I appreciate it if you please merge the PR as I don't have permission to merge.

@kuhar kuhar merged commit 48ea3e4 into llvm:main Aug 1, 2025
9 checks passed
@mahabadm mahabadm deleted the tensorArm_serialization_to_OpConstantNull branch August 1, 2025 12:37
kuhar added a commit that referenced this pull request Aug 4, 2025
Fix forward a few issues instead of reverting recent PRs....

1. Remove split so that the output is properly serialized. This got
added in #145687.
2. Add missing extensions and capabilities so that the spirv-val passes
(tensors_arm, linkage).
3. Disable spirv-val test for arm tensor constants. These fail to
verify. Added in #151485.

Issue: #152012
krishna2803 pushed a commit to krishna2803/llvm-project that referenced this pull request Aug 12, 2025
llvm#151485)

…tNull

This patch enables (de)serialization to/from OpConstantNull for null
TensorARM

---------

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
krishna2803 pushed a commit to krishna2803/llvm-project that referenced this pull request Aug 12, 2025
Fix forward a few issues instead of reverting recent PRs....

1. Remove split so that the output is properly serialized. This got
added in llvm#145687.
2. Add missing extensions and capabilities so that the spirv-val passes
(tensors_arm, linkage).
3. Disable spirv-val test for arm tensor constants. These fail to
verify. Added in llvm#151485.

Issue: llvm#152012
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants