-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[DirectX] Moving Root Signature Metadata Parsing in to Shared Root Signature Metadata lib #149221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[DirectX] Moving Root Signature Metadata Parsing in to Shared Root Signature Metadata lib #149221
Conversation
@llvm/pr-subscribers-hlsl Author: None (joaosaffran) ChangesThis PR, moves the existing Root Signature Metadata Parsing logic used in Patch is 51.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149221.diff 11 Files Affected:
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 0aa122f668ef1..729ea22d3c8ab 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -15,6 +15,8 @@
#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/MC/DXContainerRootSignature.h"
namespace llvm {
class LLVMContext;
@@ -24,6 +26,96 @@ class Metadata;
namespace hlsl {
namespace rootsig {
+inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
+inline std::optional<float> extractMdFloatValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
+ return CI->getValueAPF().convertToFloat();
+ return std::nullopt;
+}
+
+inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
+template <typename T>
+class RootSignatureValidationError
+ : public ErrorInfo<RootSignatureValidationError<T>> {
+public:
+ static char ID;
+ std::string ParamName;
+ T Value;
+
+ RootSignatureValidationError(StringRef ParamName, T Value)
+ : ParamName(ParamName.str()), Value(Value) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName << ": " << Value;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
+public:
+ static char ID;
+ std::string Message;
+
+ GenericRSMetadataError(Twine Message) : Message(Message.str()) {}
+
+ void log(raw_ostream &OS) const override { OS << Message; }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
+public:
+ static char ID;
+ std::string ElementName;
+
+ InvalidRSMetadataFormat(StringRef ElementName)
+ : ElementName(ElementName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid format for " << ElementName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
+public:
+ static char ID;
+ std::string ParamName;
+
+ InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
class MetadataBuilder {
public:
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
@@ -49,6 +141,47 @@ class MetadataBuilder {
SmallVector<Metadata *> GeneratedMetadata;
};
+enum class RootSignatureElementKind {
+ Error = 0,
+ RootFlags = 1,
+ RootConstants = 2,
+ SRV = 3,
+ UAV = 4,
+ CBV = 5,
+ DescriptorTable = 6,
+ StaticSamplers = 7
+};
+
+class MetadataParser {
+public:
+ MetadataParser(MDNode *Root) : Root(Root) {}
+
+ /// Iterates through root signature and converts them into MapT
+ LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
+ ParseRootSignature(uint32_t Version);
+
+private:
+ llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode);
+ llvm::Error parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode);
+ llvm::Error parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind);
+ llvm::Error parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode);
+ llvm::Error parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode);
+ llvm::Error parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element);
+ llvm::Error parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode);
+
+ llvm::Error validateRootSignature(const llvm::mcdxbc::RootSignatureDesc &RSD);
+
+ MDNode *Root;
+};
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h
index 4b6b42f7d74f7..14a2429ffcc78 100644
--- a/llvm/include/llvm/MC/DXContainerRootSignature.h
+++ b/llvm/include/llvm/MC/DXContainerRootSignature.h
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#ifndef LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+#define LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+
#include "llvm/BinaryFormat/DXContainer.h"
#include <cstdint>
#include <limits>
@@ -116,3 +119,5 @@ struct RootSignatureDesc {
};
} // namespace mcdxbc
} // namespace llvm
+
+#endif // LLVM_MC_DXCONTAINERROOTSIGNATURE_H
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index f7669f09dcecc..23c1815d438ad 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/ScopedPrinter.h"
@@ -20,6 +21,12 @@ namespace llvm {
namespace hlsl {
namespace rootsig {
+char GenericRSMetadataError::ID;
+char InvalidRSMetadataFormat::ID;
+char InvalidRSMetadataValue::ID;
+
+template <typename T> char RootSignatureValidationError<T>::ID;
+
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
{"CBV", dxil::ResourceClass::CBuffer},
{"SRV", dxil::ResourceClass::SRV},
@@ -189,6 +196,514 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
return MDNode::get(Ctx, Operands);
}
+llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode) {
+ if (RootFlagNode->getNumOperands() != 2)
+ return make_error<InvalidRSMetadataFormat>("RootFlag Element");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
+ RSD.Flags = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RootFlag");
+
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode) {
+ if (RootConstantNode->getNumOperands() != 5)
+ return make_error<InvalidRSMetadataFormat>("RootConstants Element");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ // The parameter offset doesn't matter here - we recalculate it during
+ // serialization Header.ParameterOffset = 0;
+ Header.ParameterType =
+ llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ dxbc::RTS0::v1::RootConstants Constants;
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
+ Constants.ShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
+ Constants.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
+ Constants.Num32BitValues = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Num32BitValues");
+
+ RSD.ParametersContainer.addParameter(Header, Constants);
+
+ return llvm::Error::success();
+}
+
+llvm::Error
+MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind) {
+ assert(ElementKind == RootSignatureElementKind::SRV ||
+ ElementKind == RootSignatureElementKind::UAV ||
+ ElementKind == RootSignatureElementKind::CBV &&
+ "parseRootDescriptors should only be called with RootDescriptor"
+ "element kind.");
+ if (RootDescriptorNode->getNumOperands() != 5)
+ return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ switch (ElementKind) {
+ case RootSignatureElementKind::SRV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
+ break;
+ case RootSignatureElementKind::UAV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
+ break;
+ case RootSignatureElementKind::CBV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
+ break;
+ default:
+ llvm_unreachable("invalid Root Descriptor kind");
+ break;
+ }
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ dxbc::RTS0::v2::RootDescriptor Descriptor;
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
+ Descriptor.ShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
+ Descriptor.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (RSD.Version == 1) {
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return llvm::Error::success();
+ }
+ assert(RSD.Version > 1);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
+ Descriptor.Flags = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
+
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode) {
+ if (RangeDescriptorNode->getNumOperands() != 6)
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+
+ dxbc::RTS0::v2::DescriptorRange Range;
+
+ std::optional<StringRef> ElementText =
+ extractMdStringValue(RangeDescriptorNode, 0);
+
+ if (!ElementText.has_value())
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+
+ Range.RangeType =
+ StringSwitch<uint32_t>(*ElementText)
+ .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
+ .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
+ .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
+ .Case("Sampler",
+ llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
+ .Default(~0U);
+
+ if (Range.RangeType == ~0U)
+ return make_error<GenericRSMetadataError>("Invalid Descriptor Range type:" +
+ *ElementText);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
+ Range.NumDescriptors = *Val;
+ else
+ return make_error<GenericRSMetadataError>("Number of Descriptor in Range");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
+ Range.BaseShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
+ Range.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
+ Range.OffsetInDescriptorsFromTableStart = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>(
+ "OffsetInDescriptorsFromTableStart");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
+ Range.Flags = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
+
+ Table.Ranges.push_back(Range);
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode) {
+ const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
+ if (NumOperands < 2)
+ return make_error<InvalidRSMetadataFormat>("Descriptor Table");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ mcdxbc::DescriptorTable Table;
+ Header.ParameterType =
+ llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
+
+ for (unsigned int I = 2; I < NumOperands; I++) {
+ MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
+ if (Element == nullptr)
+ return make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node.");
+
+ if (auto Err = parseDescriptorRange(Table, Element))
+ return Err;
+ }
+
+ RSD.ParametersContainer.addParameter(Header, Table);
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode) {
+ if (StaticSamplerNode->getNumOperands() != 14)
+ return make_error<InvalidRSMetadataFormat>("Static Sampler");
+
+ dxbc::RTS0::v1::StaticSampler Sampler;
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
+ Sampler.Filter = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Filter");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
+ Sampler.AddressU = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("AddressU");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
+ Sampler.AddressV = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("AddressV");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
+ Sampler.AddressW = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("AddressW");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
+ Sampler.MipLODBias = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MipLODBias");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
+ Sampler.MaxAnisotropy = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
+ Sampler.ComparisonFunc = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
+ Sampler.BorderColor = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
+ Sampler.MinLOD = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MinLOD");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
+ Sampler.MaxLOD = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MaxLOD");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
+ Sampler.ShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
+ Sampler.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
+ Sampler.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ RSD.StaticSamplers.push_back(Sampler);
+ return llvm::Error::success();
+}
+
+llvm::Error
+MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element) {
+ std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
+ if (!ElementText.has_value())
+ return make_error<InvalidRSMetadataFormat>("Root Element");
+
+ RootSignatureElementKind ElementKind =
+ StringSwitch<RootSignatureElementKind>(*ElementText)
+ .Case("RootFlags", RootSignatureElementKind::RootFlags)
+ .Case("RootConstants", RootSignatureElementKind::RootConstants)
+ .Case("RootCBV", RootSignatureElementKind::CBV)
+ .Case("RootSRV", RootSignatureElementKind::SRV)
+ .Case("RootUAV", RootSignatureElementKind::UAV)
+ .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
+ .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
+ .Default(RootSignatureElementKind::Error);
+
+ switch (ElementKind) {
+
+ case RootSignatureElementKind::RootFlags:
+ return parseRootFlags(RSD, Element);
+ case RootSignatureElementKind::RootConstants:
+ return parseRootConstants(RSD, Element);
+ case RootSignatureElementKind::CBV:
+ case RootSignatureElementKind::SRV:
+ case RootSignatureElementKind::UAV:
+ return parseRootDescriptors(RSD, Element, ElementKind);
+ case RootSignatureElementKind::DescriptorTable:
+ return parseDescriptorTable(RSD, Element);
+ case RootSignatureElementKind::StaticSamplers:
+ return parseStaticSampler(RSD, Element);
+ case RootSignatureElementKind::Error:
+ return make_error<GenericRSMetadataError>(
+ "Invalid Root Signature Element:" + *ElementText);
+ }
+
+ llvm_unreachable("Unhandled RootSignatureElementKind enum.");
+}
+
+llvm::Error MetadataParser::validateRootSignature(
+ const llvm::mcdxbc::RootSignatureDesc &RSD) {
+ Error DeferredErrs = Error::success();
+ if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "Version", RSD.Version));
+ }
+
+ if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RootFlags", RSD.Flags));
+ }
+
+ for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
+ if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderVisibility", Info.Header.ShaderVisibility));
+
+ assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
+ "Invalid value for ParameterType");
+
+ switch (Info.Header.ParameterType) {
+
+ case llvm::to_underlying(dxbc::RootParameterType::CBV):
+ case llvm::to_underlying(dxbc::RootParameterType::UAV):
+ case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+ RSD.ParametersContainer.getRootDescriptor(In...
[truncated]
|
@llvm/pr-subscribers-backend-directx Author: None (joaosaffran) ChangesThis PR, moves the existing Root Signature Metadata Parsing logic used in Patch is 51.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149221.diff 11 Files Affected:
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index 0aa122f668ef1..729ea22d3c8ab 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -15,6 +15,8 @@
#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/MC/DXContainerRootSignature.h"
namespace llvm {
class LLVMContext;
@@ -24,6 +26,96 @@ class Metadata;
namespace hlsl {
namespace rootsig {
+inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI =
+ mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
+ return CI->getZExtValue();
+ return std::nullopt;
+}
+
+inline std::optional<float> extractMdFloatValue(MDNode *Node,
+ unsigned int OpId) {
+ if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
+ return CI->getValueAPF().convertToFloat();
+ return std::nullopt;
+}
+
+inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
+ unsigned int OpId) {
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
+ if (NodeText == nullptr)
+ return std::nullopt;
+ return NodeText->getString();
+}
+
+template <typename T>
+class RootSignatureValidationError
+ : public ErrorInfo<RootSignatureValidationError<T>> {
+public:
+ static char ID;
+ std::string ParamName;
+ T Value;
+
+ RootSignatureValidationError(StringRef ParamName, T Value)
+ : ParamName(ParamName.str()), Value(Value) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName << ": " << Value;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
+public:
+ static char ID;
+ std::string Message;
+
+ GenericRSMetadataError(Twine Message) : Message(Message.str()) {}
+
+ void log(raw_ostream &OS) const override { OS << Message; }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
+public:
+ static char ID;
+ std::string ElementName;
+
+ InvalidRSMetadataFormat(StringRef ElementName)
+ : ElementName(ElementName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid format for " << ElementName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
+class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
+public:
+ static char ID;
+ std::string ParamName;
+
+ InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName.str()) {}
+
+ void log(raw_ostream &OS) const override {
+ OS << "Invalid value for " << ParamName;
+ }
+
+ std::error_code convertToErrorCode() const override {
+ return llvm::inconvertibleErrorCode();
+ }
+};
+
class MetadataBuilder {
public:
MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
@@ -49,6 +141,47 @@ class MetadataBuilder {
SmallVector<Metadata *> GeneratedMetadata;
};
+enum class RootSignatureElementKind {
+ Error = 0,
+ RootFlags = 1,
+ RootConstants = 2,
+ SRV = 3,
+ UAV = 4,
+ CBV = 5,
+ DescriptorTable = 6,
+ StaticSamplers = 7
+};
+
+class MetadataParser {
+public:
+ MetadataParser(MDNode *Root) : Root(Root) {}
+
+ /// Iterates through root signature and converts them into MapT
+ LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
+ ParseRootSignature(uint32_t Version);
+
+private:
+ llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode);
+ llvm::Error parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode);
+ llvm::Error parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind);
+ llvm::Error parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode);
+ llvm::Error parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode);
+ llvm::Error parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element);
+ llvm::Error parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode);
+
+ llvm::Error validateRootSignature(const llvm::mcdxbc::RootSignatureDesc &RSD);
+
+ MDNode *Root;
+};
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h
index 4b6b42f7d74f7..14a2429ffcc78 100644
--- a/llvm/include/llvm/MC/DXContainerRootSignature.h
+++ b/llvm/include/llvm/MC/DXContainerRootSignature.h
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#ifndef LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+#define LLVM_MC_DXCONTAINERROOTSIGNATURE_H
+
#include "llvm/BinaryFormat/DXContainer.h"
#include <cstdint>
#include <limits>
@@ -116,3 +119,5 @@ struct RootSignatureDesc {
};
} // namespace mcdxbc
} // namespace llvm
+
+#endif // LLVM_MC_DXCONTAINERROOTSIGNATURE_H
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index f7669f09dcecc..23c1815d438ad 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
+#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/ScopedPrinter.h"
@@ -20,6 +21,12 @@ namespace llvm {
namespace hlsl {
namespace rootsig {
+char GenericRSMetadataError::ID;
+char InvalidRSMetadataFormat::ID;
+char InvalidRSMetadataValue::ID;
+
+template <typename T> char RootSignatureValidationError<T>::ID;
+
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
{"CBV", dxil::ResourceClass::CBuffer},
{"SRV", dxil::ResourceClass::SRV},
@@ -189,6 +196,514 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
return MDNode::get(Ctx, Operands);
}
+llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode) {
+ if (RootFlagNode->getNumOperands() != 2)
+ return make_error<InvalidRSMetadataFormat>("RootFlag Element");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
+ RSD.Flags = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RootFlag");
+
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode) {
+ if (RootConstantNode->getNumOperands() != 5)
+ return make_error<InvalidRSMetadataFormat>("RootConstants Element");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ // The parameter offset doesn't matter here - we recalculate it during
+ // serialization Header.ParameterOffset = 0;
+ Header.ParameterType =
+ llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ dxbc::RTS0::v1::RootConstants Constants;
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
+ Constants.ShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
+ Constants.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
+ Constants.Num32BitValues = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Num32BitValues");
+
+ RSD.ParametersContainer.addParameter(Header, Constants);
+
+ return llvm::Error::success();
+}
+
+llvm::Error
+MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind) {
+ assert(ElementKind == RootSignatureElementKind::SRV ||
+ ElementKind == RootSignatureElementKind::UAV ||
+ ElementKind == RootSignatureElementKind::CBV &&
+ "parseRootDescriptors should only be called with RootDescriptor"
+ "element kind.");
+ if (RootDescriptorNode->getNumOperands() != 5)
+ return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ switch (ElementKind) {
+ case RootSignatureElementKind::SRV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
+ break;
+ case RootSignatureElementKind::UAV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
+ break;
+ case RootSignatureElementKind::CBV:
+ Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
+ break;
+ default:
+ llvm_unreachable("invalid Root Descriptor kind");
+ break;
+ }
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ dxbc::RTS0::v2::RootDescriptor Descriptor;
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
+ Descriptor.ShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
+ Descriptor.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (RSD.Version == 1) {
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return llvm::Error::success();
+ }
+ assert(RSD.Version > 1);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
+ Descriptor.Flags = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
+
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode) {
+ if (RangeDescriptorNode->getNumOperands() != 6)
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+
+ dxbc::RTS0::v2::DescriptorRange Range;
+
+ std::optional<StringRef> ElementText =
+ extractMdStringValue(RangeDescriptorNode, 0);
+
+ if (!ElementText.has_value())
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+
+ Range.RangeType =
+ StringSwitch<uint32_t>(*ElementText)
+ .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
+ .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
+ .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
+ .Case("Sampler",
+ llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
+ .Default(~0U);
+
+ if (Range.RangeType == ~0U)
+ return make_error<GenericRSMetadataError>("Invalid Descriptor Range type:" +
+ *ElementText);
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
+ Range.NumDescriptors = *Val;
+ else
+ return make_error<GenericRSMetadataError>("Number of Descriptor in Range");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
+ Range.BaseShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
+ Range.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
+ Range.OffsetInDescriptorsFromTableStart = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>(
+ "OffsetInDescriptorsFromTableStart");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
+ Range.Flags = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
+
+ Table.Ranges.push_back(Range);
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode) {
+ const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
+ if (NumOperands < 2)
+ return make_error<InvalidRSMetadataFormat>("Descriptor Table");
+
+ dxbc::RTS0::v1::RootParameterHeader Header;
+ if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
+ Header.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ mcdxbc::DescriptorTable Table;
+ Header.ParameterType =
+ llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
+
+ for (unsigned int I = 2; I < NumOperands; I++) {
+ MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
+ if (Element == nullptr)
+ return make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node.");
+
+ if (auto Err = parseDescriptorRange(Table, Element))
+ return Err;
+ }
+
+ RSD.ParametersContainer.addParameter(Header, Table);
+ return llvm::Error::success();
+}
+
+llvm::Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode) {
+ if (StaticSamplerNode->getNumOperands() != 14)
+ return make_error<InvalidRSMetadataFormat>("Static Sampler");
+
+ dxbc::RTS0::v1::StaticSampler Sampler;
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
+ Sampler.Filter = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("Filter");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
+ Sampler.AddressU = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("AddressU");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
+ Sampler.AddressV = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("AddressV");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
+ Sampler.AddressW = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("AddressW");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
+ Sampler.MipLODBias = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MipLODBias");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
+ Sampler.MaxAnisotropy = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
+ Sampler.ComparisonFunc = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
+ Sampler.BorderColor = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
+ Sampler.MinLOD = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MinLOD");
+
+ if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
+ Sampler.MaxLOD = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("MaxLOD");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
+ Sampler.ShaderRegister = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
+ Sampler.RegisterSpace = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
+
+ if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
+ Sampler.ShaderVisibility = *Val;
+ else
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+
+ RSD.StaticSamplers.push_back(Sampler);
+ return llvm::Error::success();
+}
+
+llvm::Error
+MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element) {
+ std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
+ if (!ElementText.has_value())
+ return make_error<InvalidRSMetadataFormat>("Root Element");
+
+ RootSignatureElementKind ElementKind =
+ StringSwitch<RootSignatureElementKind>(*ElementText)
+ .Case("RootFlags", RootSignatureElementKind::RootFlags)
+ .Case("RootConstants", RootSignatureElementKind::RootConstants)
+ .Case("RootCBV", RootSignatureElementKind::CBV)
+ .Case("RootSRV", RootSignatureElementKind::SRV)
+ .Case("RootUAV", RootSignatureElementKind::UAV)
+ .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
+ .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
+ .Default(RootSignatureElementKind::Error);
+
+ switch (ElementKind) {
+
+ case RootSignatureElementKind::RootFlags:
+ return parseRootFlags(RSD, Element);
+ case RootSignatureElementKind::RootConstants:
+ return parseRootConstants(RSD, Element);
+ case RootSignatureElementKind::CBV:
+ case RootSignatureElementKind::SRV:
+ case RootSignatureElementKind::UAV:
+ return parseRootDescriptors(RSD, Element, ElementKind);
+ case RootSignatureElementKind::DescriptorTable:
+ return parseDescriptorTable(RSD, Element);
+ case RootSignatureElementKind::StaticSamplers:
+ return parseStaticSampler(RSD, Element);
+ case RootSignatureElementKind::Error:
+ return make_error<GenericRSMetadataError>(
+ "Invalid Root Signature Element:" + *ElementText);
+ }
+
+ llvm_unreachable("Unhandled RootSignatureElementKind enum.");
+}
+
+llvm::Error MetadataParser::validateRootSignature(
+ const llvm::mcdxbc::RootSignatureDesc &RSD) {
+ Error DeferredErrs = Error::success();
+ if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "Version", RSD.Version));
+ }
+
+ if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "RootFlags", RSD.Flags));
+ }
+
+ for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
+ if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ llvm::make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderVisibility", Info.Header.ShaderVisibility));
+
+ assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
+ "Invalid value for ParameterType");
+
+ switch (Info.Header.ParameterType) {
+
+ case llvm::to_underlying(dxbc::RootParameterType::CBV):
+ case llvm::to_underlying(dxbc::RootParameterType::UAV):
+ case llvm::to_underlying(dxbc::RootParameterType::SRV): {
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+ RSD.ParametersContainer.getRootDescriptor(In...
[truncated]
|
; CHECK: error: Invalid Descriptor Range type: Invalid | ||
; CHECK: error: Invalid Descriptor Range type:Invalid |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we want to keep this as before
; CHECK: error: Invalid value for MaxLOD: 0 | ||
; CHECK: error: Invalid value for MaxLOD: nan |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also seem strange
@@ -24,6 +26,96 @@ class Metadata; | |||
namespace hlsl { | |||
namespace rootsig { | |||
|
|||
inline std::optional<uint32_t> extractMdIntValue(MDNode *Node, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there anything to prevent this from being static functions in the .cpp
file?
I don't think we want to expose this api.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is used to extract some metadata in DXILRootSignature
} | ||
|
||
template <typename T> | ||
class RootSignatureValidationError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This all seems like new code?
It might be a little easier to review if we have just an NFC code move and then a separate change to error handling
public: | ||
MetadataParser(MDNode *Root) : Root(Root) {} | ||
|
||
/// Iterates through root signature and converts them into MapT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Can we specify what MapT is
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, this was a previous version, that I tried during development.
This PR, moves the existing Root Signature Metadata Parsing logic used in
DXILRootSignature
to the common library used by both frontend and backend. Closes: #145942