diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h index 0aa122f668ef1..6fa51eded52f0 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h @@ -15,6 +15,8 @@ #define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H #include "llvm/Frontend/HLSL/HLSLRootSignature.h" +#include "llvm/IR/Constants.h" +#include "llvm/MC/DXContainerRootSignature.h" namespace llvm { class LLVMContext; @@ -49,6 +51,48 @@ class MetadataBuilder { SmallVector GeneratedMetadata; }; +enum class RootSignatureElementKind { + Error = 0, + RootFlags = 1, + RootConstants = 2, + SRV = 3, + UAV = 4, + CBV = 5, + DescriptorTable = 6, + StaticSamplers = 7 +}; + +class MetadataParser { +public: + MetadataParser(MDNode *Root) : Root(Root) {} + + LLVM_ABI bool ParseRootSignature(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD); + +private: + bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, + MDNode *RootFlagNode); + bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, + MDNode *RootConstantNode); + bool parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, + MDNode *RootDescriptorNode, + RootSignatureElementKind ElementKind); + bool parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table, + MDNode *RangeDescriptorNode); + bool parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, + MDNode *DescriptorTableNode); + bool parseRootSignatureElement(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD, + MDNode *Element); + bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, + MDNode *StaticSamplerNode); + + bool validateRootSignature(LLVMContext *Ctx, + const llvm::mcdxbc::RootSignatureDesc &RSD); + + MDNode *Root; +}; + } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h index 4b6b42f7d74f7..14a2429ffcc78 100644 --- a/llvm/include/llvm/MC/DXContainerRootSignature.h +++ b/llvm/include/llvm/MC/DXContainerRootSignature.h @@ -6,6 +6,9 @@ // //===----------------------------------------------------------------------===// +#ifndef LLVM_MC_DXCONTAINERROOTSIGNATURE_H +#define LLVM_MC_DXCONTAINERROOTSIGNATURE_H + #include "llvm/BinaryFormat/DXContainer.h" #include #include @@ -116,3 +119,5 @@ struct RootSignatureDesc { }; } // namespace mcdxbc } // namespace llvm + +#endif // LLVM_MC_DXCONTAINERROOTSIGNATURE_H diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index f7669f09dcecc..53f59349ae029 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Frontend/HLSL/RootSignatureMetadata.h" +#include "llvm/Frontend/HLSL/RootSignatureValidations.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/ScopedPrinter.h" @@ -20,6 +22,42 @@ namespace llvm { namespace hlsl { namespace rootsig { +static std::optional extractMdIntValue(MDNode *Node, + unsigned int OpId) { + if (auto *CI = + mdconst::dyn_extract(Node->getOperand(OpId).get())) + return CI->getZExtValue(); + return std::nullopt; +} + +static std::optional extractMdFloatValue(MDNode *Node, + unsigned int OpId) { + if (auto *CI = mdconst::dyn_extract(Node->getOperand(OpId).get())) + return CI->getValueAPF().convertToFloat(); + return std::nullopt; +} + +static std::optional extractMdStringValue(MDNode *Node, + unsigned int OpId) { + MDString *NodeText = dyn_cast(Node->getOperand(OpId)); + if (NodeText == nullptr) + return std::nullopt; + return NodeText->getString(); +} + +static bool reportError(LLVMContext *Ctx, Twine Message, + DiagnosticSeverity Severity = DS_Error) { + Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity)); + return true; +} + +static bool reportValueError(LLVMContext *Ctx, Twine ParamName, + uint32_t Value) { + Ctx->diagnose(DiagnosticInfoGeneric( + "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error)); + return true; +} + static const EnumEntry ResourceClassNames[] = { {"CBV", dxil::ResourceClass::CBuffer}, {"SRV", dxil::ResourceClass::SRV}, @@ -189,6 +227,442 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { return MDNode::get(Ctx, Operands); } +bool MetadataParser::parseRootFlags(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD, + MDNode *RootFlagNode) { + + if (RootFlagNode->getNumOperands() != 2) + return reportError(Ctx, "Invalid format for RootFlag Element"); + + if (std::optional Val = extractMdIntValue(RootFlagNode, 1)) + RSD.Flags = *Val; + else + return reportError(Ctx, "Invalid value for RootFlag"); + + return false; +} + +bool MetadataParser::parseRootConstants(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD, + MDNode *RootConstantNode) { + + if (RootConstantNode->getNumOperands() != 5) + return reportError(Ctx, "Invalid format for RootConstants Element"); + + dxbc::RTS0::v1::RootParameterHeader Header; + // The parameter offset doesn't matter here - we recalculate it during + // serialization Header.ParameterOffset = 0; + Header.ParameterType = + llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); + + if (std::optional Val = extractMdIntValue(RootConstantNode, 1)) + Header.ShaderVisibility = *Val; + else + return reportError(Ctx, "Invalid value for ShaderVisibility"); + + dxbc::RTS0::v1::RootConstants Constants; + if (std::optional Val = extractMdIntValue(RootConstantNode, 2)) + Constants.ShaderRegister = *Val; + else + return reportError(Ctx, "Invalid value for ShaderRegister"); + + if (std::optional Val = extractMdIntValue(RootConstantNode, 3)) + Constants.RegisterSpace = *Val; + else + return reportError(Ctx, "Invalid value for RegisterSpace"); + + if (std::optional Val = extractMdIntValue(RootConstantNode, 4)) + Constants.Num32BitValues = *Val; + else + return reportError(Ctx, "Invalid value for Num32BitValues"); + + RSD.ParametersContainer.addParameter(Header, Constants); + + return false; +} + +bool MetadataParser::parseRootDescriptors( + LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, + MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) { + assert(ElementKind == RootSignatureElementKind::SRV || + ElementKind == RootSignatureElementKind::UAV || + ElementKind == RootSignatureElementKind::CBV && + "parseRootDescriptors should only be called with RootDescriptor " + "element kind."); + if (RootDescriptorNode->getNumOperands() != 5) + return reportError(Ctx, "Invalid format for Root Descriptor Element"); + + dxbc::RTS0::v1::RootParameterHeader Header; + switch (ElementKind) { + case RootSignatureElementKind::SRV: + Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV); + break; + case RootSignatureElementKind::UAV: + Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV); + break; + case RootSignatureElementKind::CBV: + Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV); + break; + default: + llvm_unreachable("invalid Root Descriptor kind"); + break; + } + + if (std::optional Val = extractMdIntValue(RootDescriptorNode, 1)) + Header.ShaderVisibility = *Val; + else + return reportError(Ctx, "Invalid value for ShaderVisibility"); + + dxbc::RTS0::v2::RootDescriptor Descriptor; + if (std::optional Val = extractMdIntValue(RootDescriptorNode, 2)) + Descriptor.ShaderRegister = *Val; + else + return reportError(Ctx, "Invalid value for ShaderRegister"); + + if (std::optional Val = extractMdIntValue(RootDescriptorNode, 3)) + Descriptor.RegisterSpace = *Val; + else + return reportError(Ctx, "Invalid value for RegisterSpace"); + + if (RSD.Version == 1) { + RSD.ParametersContainer.addParameter(Header, Descriptor); + return false; + } + assert(RSD.Version > 1); + + if (std::optional Val = extractMdIntValue(RootDescriptorNode, 4)) + Descriptor.Flags = *Val; + else + return reportError(Ctx, "Invalid value for Root Descriptor Flags"); + + RSD.ParametersContainer.addParameter(Header, Descriptor); + return false; +} + +bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx, + mcdxbc::DescriptorTable &Table, + MDNode *RangeDescriptorNode) { + + if (RangeDescriptorNode->getNumOperands() != 6) + return reportError(Ctx, "Invalid format for Descriptor Range"); + + dxbc::RTS0::v2::DescriptorRange Range; + + std::optional ElementText = + extractMdStringValue(RangeDescriptorNode, 0); + + if (!ElementText.has_value()) + return reportError(Ctx, "Descriptor Range, first element is not a string."); + + Range.RangeType = + StringSwitch(*ElementText) + .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV)) + .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV)) + .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV)) + .Case("Sampler", + llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) + .Default(~0U); + + if (Range.RangeType == ~0U) + return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText); + + if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 1)) + Range.NumDescriptors = *Val; + else + return reportError(Ctx, "Invalid value for Number of Descriptor in Range"); + + if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 2)) + Range.BaseShaderRegister = *Val; + else + return reportError(Ctx, "Invalid value for BaseShaderRegister"); + + if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 3)) + Range.RegisterSpace = *Val; + else + return reportError(Ctx, "Invalid value for RegisterSpace"); + + if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 4)) + Range.OffsetInDescriptorsFromTableStart = *Val; + else + return reportError(Ctx, + "Invalid value for OffsetInDescriptorsFromTableStart"); + + if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 5)) + Range.Flags = *Val; + else + return reportError(Ctx, "Invalid value for Descriptor Range Flags"); + + Table.Ranges.push_back(Range); + return false; +} + +bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD, + MDNode *DescriptorTableNode) { + const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); + if (NumOperands < 2) + return reportError(Ctx, "Invalid format for Descriptor Table"); + + dxbc::RTS0::v1::RootParameterHeader Header; + if (std::optional Val = extractMdIntValue(DescriptorTableNode, 1)) + Header.ShaderVisibility = *Val; + else + return reportError(Ctx, "Invalid value for ShaderVisibility"); + + mcdxbc::DescriptorTable Table; + Header.ParameterType = + llvm::to_underlying(dxbc::RootParameterType::DescriptorTable); + + for (unsigned int I = 2; I < NumOperands; I++) { + MDNode *Element = dyn_cast(DescriptorTableNode->getOperand(I)); + if (Element == nullptr) + return reportError(Ctx, "Missing Root Element Metadata Node."); + + if (parseDescriptorRange(Ctx, Table, Element)) + return true; + } + + RSD.ParametersContainer.addParameter(Header, Table); + return false; +} + +bool MetadataParser::parseStaticSampler(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD, + MDNode *StaticSamplerNode) { + if (StaticSamplerNode->getNumOperands() != 14) + return reportError(Ctx, "Invalid format for Static Sampler"); + + dxbc::RTS0::v1::StaticSampler Sampler; + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 1)) + Sampler.Filter = *Val; + else + return reportError(Ctx, "Invalid value for Filter"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 2)) + Sampler.AddressU = *Val; + else + return reportError(Ctx, "Invalid value for AddressU"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 3)) + Sampler.AddressV = *Val; + else + return reportError(Ctx, "Invalid value for AddressV"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 4)) + Sampler.AddressW = *Val; + else + return reportError(Ctx, "Invalid value for AddressW"); + + if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 5)) + Sampler.MipLODBias = *Val; + else + return reportError(Ctx, "Invalid value for MipLODBias"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 6)) + Sampler.MaxAnisotropy = *Val; + else + return reportError(Ctx, "Invalid value for MaxAnisotropy"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 7)) + Sampler.ComparisonFunc = *Val; + else + return reportError(Ctx, "Invalid value for ComparisonFunc "); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 8)) + Sampler.BorderColor = *Val; + else + return reportError(Ctx, "Invalid value for ComparisonFunc "); + + if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 9)) + Sampler.MinLOD = *Val; + else + return reportError(Ctx, "Invalid value for MinLOD"); + + if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 10)) + Sampler.MaxLOD = *Val; + else + return reportError(Ctx, "Invalid value for MaxLOD"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 11)) + Sampler.ShaderRegister = *Val; + else + return reportError(Ctx, "Invalid value for ShaderRegister"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 12)) + Sampler.RegisterSpace = *Val; + else + return reportError(Ctx, "Invalid value for RegisterSpace"); + + if (std::optional Val = extractMdIntValue(StaticSamplerNode, 13)) + Sampler.ShaderVisibility = *Val; + else + return reportError(Ctx, "Invalid value for ShaderVisibility"); + + RSD.StaticSamplers.push_back(Sampler); + return false; +} + +bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD, + MDNode *Element) { + std::optional ElementText = extractMdStringValue(Element, 0); + if (!ElementText.has_value()) + return reportError(Ctx, "Invalid format for Root Element"); + + RootSignatureElementKind ElementKind = + StringSwitch(*ElementText) + .Case("RootFlags", RootSignatureElementKind::RootFlags) + .Case("RootConstants", RootSignatureElementKind::RootConstants) + .Case("RootCBV", RootSignatureElementKind::CBV) + .Case("RootSRV", RootSignatureElementKind::SRV) + .Case("RootUAV", RootSignatureElementKind::UAV) + .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable) + .Case("StaticSampler", RootSignatureElementKind::StaticSamplers) + .Default(RootSignatureElementKind::Error); + + switch (ElementKind) { + + case RootSignatureElementKind::RootFlags: + return parseRootFlags(Ctx, RSD, Element); + case RootSignatureElementKind::RootConstants: + return parseRootConstants(Ctx, RSD, Element); + case RootSignatureElementKind::CBV: + case RootSignatureElementKind::SRV: + case RootSignatureElementKind::UAV: + return parseRootDescriptors(Ctx, RSD, Element, ElementKind); + case RootSignatureElementKind::DescriptorTable: + return parseDescriptorTable(Ctx, RSD, Element); + case RootSignatureElementKind::StaticSamplers: + return parseStaticSampler(Ctx, RSD, Element); + case RootSignatureElementKind::Error: + return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText); + } + + llvm_unreachable("Unhandled RootSignatureElementKind enum."); +} + +bool MetadataParser::validateRootSignature( + LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) { + if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) { + return reportValueError(Ctx, "Version", RSD.Version); + } + + if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) { + return reportValueError(Ctx, "RootFlags", RSD.Flags); + } + + for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { + if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) + return reportValueError(Ctx, "ShaderVisibility", + Info.Header.ShaderVisibility); + + assert(dxbc::isValidParameterType(Info.Header.ParameterType) && + "Invalid value for ParameterType"); + + switch (Info.Header.ParameterType) { + + case llvm::to_underlying(dxbc::RootParameterType::CBV): + case llvm::to_underlying(dxbc::RootParameterType::UAV): + case llvm::to_underlying(dxbc::RootParameterType::SRV): { + const dxbc::RTS0::v2::RootDescriptor &Descriptor = + RSD.ParametersContainer.getRootDescriptor(Info.Location); + if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) + return reportValueError(Ctx, "ShaderRegister", + Descriptor.ShaderRegister); + + if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) + return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); + + if (RSD.Version > 1) { + if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, + Descriptor.Flags)) + return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags); + } + break; + } + case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + const mcdxbc::DescriptorTable &Table = + RSD.ParametersContainer.getDescriptorTable(Info.Location); + for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { + if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType)) + return reportValueError(Ctx, "RangeType", Range.RangeType); + + if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) + return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace); + + if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) + return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors); + + if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, Range.Flags)) + return reportValueError(Ctx, "DescriptorFlag", Range.Flags); + } + break; + } + } + } + + for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { + if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) + return reportValueError(Ctx, "Filter", Sampler.Filter); + + if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU)) + return reportValueError(Ctx, "AddressU", Sampler.AddressU); + + if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV)) + return reportValueError(Ctx, "AddressV", Sampler.AddressV); + + if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW)) + return reportValueError(Ctx, "AddressW", Sampler.AddressW); + + if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) + return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias); + + if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) + return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy); + + if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) + return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc); + + if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) + return reportValueError(Ctx, "BorderColor", Sampler.BorderColor); + + if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD)) + return reportValueError(Ctx, "MinLOD", Sampler.MinLOD); + + if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) + return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD); + + if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) + return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister); + + if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) + return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace); + + if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) + return reportValueError(Ctx, "ShaderVisibility", + Sampler.ShaderVisibility); + } + + return false; +} + +bool MetadataParser::ParseRootSignature(LLVMContext *Ctx, + mcdxbc::RootSignatureDesc &RSD) { + bool HasError = false; + + // Loop through the Root Elements of the root signature. + for (const auto &Operand : Root->operands()) { + MDNode *Element = dyn_cast(Operand); + if (Element == nullptr) + return reportError(Ctx, "Missing Root Element Metadata Node."); + + HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) || + validateRootSignature(Ctx, RSD); + } + + return HasError; +} } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index dfc81626da01f..ebdfcaa566b51 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/BinaryFormat/DXContainer.h" +#include "llvm/Frontend/HLSL/RootSignatureMetadata.h" #include "llvm/Frontend/HLSL/RootSignatureValidations.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" @@ -29,25 +30,10 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include -#include -#include using namespace llvm; using namespace llvm::dxil; -static bool reportError(LLVMContext *Ctx, Twine Message, - DiagnosticSeverity Severity = DS_Error) { - Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity)); - return true; -} - -static bool reportValueError(LLVMContext *Ctx, Twine ParamName, - uint32_t Value) { - Ctx->diagnose(DiagnosticInfoGeneric( - "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error)); - return true; -} - static std::optional extractMdIntValue(MDNode *Node, unsigned int OpId) { if (auto *CI = @@ -56,453 +42,10 @@ static std::optional extractMdIntValue(MDNode *Node, return std::nullopt; } -static std::optional extractMdFloatValue(MDNode *Node, - unsigned int OpId) { - if (auto *CI = mdconst::dyn_extract(Node->getOperand(OpId).get())) - return CI->getValueAPF().convertToFloat(); - return std::nullopt; -} - -static std::optional extractMdStringValue(MDNode *Node, - unsigned int OpId) { - MDString *NodeText = dyn_cast(Node->getOperand(OpId)); - if (NodeText == nullptr) - return std::nullopt; - return NodeText->getString(); -} - -static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, - MDNode *RootFlagNode) { - - if (RootFlagNode->getNumOperands() != 2) - return reportError(Ctx, "Invalid format for RootFlag Element"); - - if (std::optional Val = extractMdIntValue(RootFlagNode, 1)) - RSD.Flags = *Val; - else - return reportError(Ctx, "Invalid value for RootFlag"); - - return false; -} - -static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, - MDNode *RootConstantNode) { - - if (RootConstantNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for RootConstants Element"); - - dxbc::RTS0::v1::RootParameterHeader Header; - // The parameter offset doesn't matter here - we recalculate it during - // serialization Header.ParameterOffset = 0; - Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); - - if (std::optional Val = extractMdIntValue(RootConstantNode, 1)) - Header.ShaderVisibility = *Val; - else - return reportError(Ctx, "Invalid value for ShaderVisibility"); - - dxbc::RTS0::v1::RootConstants Constants; - if (std::optional Val = extractMdIntValue(RootConstantNode, 2)) - Constants.ShaderRegister = *Val; - else - return reportError(Ctx, "Invalid value for ShaderRegister"); - - if (std::optional Val = extractMdIntValue(RootConstantNode, 3)) - Constants.RegisterSpace = *Val; - else - return reportError(Ctx, "Invalid value for RegisterSpace"); - - if (std::optional Val = extractMdIntValue(RootConstantNode, 4)) - Constants.Num32BitValues = *Val; - else - return reportError(Ctx, "Invalid value for Num32BitValues"); - - RSD.ParametersContainer.addParameter(Header, Constants); - - return false; -} - -static bool parseRootDescriptors(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *RootDescriptorNode, - RootSignatureElementKind ElementKind) { - assert(ElementKind == RootSignatureElementKind::SRV || - ElementKind == RootSignatureElementKind::UAV || - ElementKind == RootSignatureElementKind::CBV && - "parseRootDescriptors should only be called with RootDescriptor " - "element kind."); - if (RootDescriptorNode->getNumOperands() != 5) - return reportError(Ctx, "Invalid format for Root Descriptor Element"); - - dxbc::RTS0::v1::RootParameterHeader Header; - switch (ElementKind) { - case RootSignatureElementKind::SRV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV); - break; - case RootSignatureElementKind::UAV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV); - break; - case RootSignatureElementKind::CBV: - Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV); - break; - default: - llvm_unreachable("invalid Root Descriptor kind"); - break; - } - - if (std::optional Val = extractMdIntValue(RootDescriptorNode, 1)) - Header.ShaderVisibility = *Val; - else - return reportError(Ctx, "Invalid value for ShaderVisibility"); - - dxbc::RTS0::v2::RootDescriptor Descriptor; - if (std::optional Val = extractMdIntValue(RootDescriptorNode, 2)) - Descriptor.ShaderRegister = *Val; - else - return reportError(Ctx, "Invalid value for ShaderRegister"); - - if (std::optional Val = extractMdIntValue(RootDescriptorNode, 3)) - Descriptor.RegisterSpace = *Val; - else - return reportError(Ctx, "Invalid value for RegisterSpace"); - - if (RSD.Version == 1) { - RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; - } - assert(RSD.Version > 1); - - if (std::optional Val = extractMdIntValue(RootDescriptorNode, 4)) - Descriptor.Flags = *Val; - else - return reportError(Ctx, "Invalid value for Root Descriptor Flags"); - - RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; -} - -static bool parseDescriptorRange(LLVMContext *Ctx, - mcdxbc::DescriptorTable &Table, - MDNode *RangeDescriptorNode) { - - if (RangeDescriptorNode->getNumOperands() != 6) - return reportError(Ctx, "Invalid format for Descriptor Range"); - - dxbc::RTS0::v2::DescriptorRange Range; - - std::optional ElementText = - extractMdStringValue(RangeDescriptorNode, 0); - - if (!ElementText.has_value()) - return reportError(Ctx, "Descriptor Range, first element is not a string."); - - Range.RangeType = - StringSwitch(*ElementText) - .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV)) - .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV)) - .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV)) - .Case("Sampler", - llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) - .Default(~0U); - - if (Range.RangeType == ~0U) - return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText); - - if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 1)) - Range.NumDescriptors = *Val; - else - return reportError(Ctx, "Invalid value for Number of Descriptor in Range"); - - if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 2)) - Range.BaseShaderRegister = *Val; - else - return reportError(Ctx, "Invalid value for BaseShaderRegister"); - - if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 3)) - Range.RegisterSpace = *Val; - else - return reportError(Ctx, "Invalid value for RegisterSpace"); - - if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 4)) - Range.OffsetInDescriptorsFromTableStart = *Val; - else - return reportError(Ctx, - "Invalid value for OffsetInDescriptorsFromTableStart"); - - if (std::optional Val = extractMdIntValue(RangeDescriptorNode, 5)) - Range.Flags = *Val; - else - return reportError(Ctx, "Invalid value for Descriptor Range Flags"); - - Table.Ranges.push_back(Range); - return false; -} - -static bool parseDescriptorTable(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *DescriptorTableNode) { - const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); - if (NumOperands < 2) - return reportError(Ctx, "Invalid format for Descriptor Table"); - - dxbc::RTS0::v1::RootParameterHeader Header; - if (std::optional Val = extractMdIntValue(DescriptorTableNode, 1)) - Header.ShaderVisibility = *Val; - else - return reportError(Ctx, "Invalid value for ShaderVisibility"); - - mcdxbc::DescriptorTable Table; - Header.ParameterType = - llvm::to_underlying(dxbc::RootParameterType::DescriptorTable); - - for (unsigned int I = 2; I < NumOperands; I++) { - MDNode *Element = dyn_cast(DescriptorTableNode->getOperand(I)); - if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); - - if (parseDescriptorRange(Ctx, Table, Element)) - return true; - } - - RSD.ParametersContainer.addParameter(Header, Table); - return false; -} - -static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, - MDNode *StaticSamplerNode) { - if (StaticSamplerNode->getNumOperands() != 14) - return reportError(Ctx, "Invalid format for Static Sampler"); - - dxbc::RTS0::v1::StaticSampler Sampler; - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 1)) - Sampler.Filter = *Val; - else - return reportError(Ctx, "Invalid value for Filter"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 2)) - Sampler.AddressU = *Val; - else - return reportError(Ctx, "Invalid value for AddressU"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 3)) - Sampler.AddressV = *Val; - else - return reportError(Ctx, "Invalid value for AddressV"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 4)) - Sampler.AddressW = *Val; - else - return reportError(Ctx, "Invalid value for AddressW"); - - if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 5)) - Sampler.MipLODBias = *Val; - else - return reportError(Ctx, "Invalid value for MipLODBias"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 6)) - Sampler.MaxAnisotropy = *Val; - else - return reportError(Ctx, "Invalid value for MaxAnisotropy"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 7)) - Sampler.ComparisonFunc = *Val; - else - return reportError(Ctx, "Invalid value for ComparisonFunc "); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 8)) - Sampler.BorderColor = *Val; - else - return reportError(Ctx, "Invalid value for ComparisonFunc "); - - if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 9)) - Sampler.MinLOD = *Val; - else - return reportError(Ctx, "Invalid value for MinLOD"); - - if (std::optional Val = extractMdFloatValue(StaticSamplerNode, 10)) - Sampler.MaxLOD = *Val; - else - return reportError(Ctx, "Invalid value for MaxLOD"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 11)) - Sampler.ShaderRegister = *Val; - else - return reportError(Ctx, "Invalid value for ShaderRegister"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 12)) - Sampler.RegisterSpace = *Val; - else - return reportError(Ctx, "Invalid value for RegisterSpace"); - - if (std::optional Val = extractMdIntValue(StaticSamplerNode, 13)) - Sampler.ShaderVisibility = *Val; - else - return reportError(Ctx, "Invalid value for ShaderVisibility"); - - RSD.StaticSamplers.push_back(Sampler); - return false; -} - -static bool parseRootSignatureElement(LLVMContext *Ctx, - mcdxbc::RootSignatureDesc &RSD, - MDNode *Element) { - std::optional ElementText = extractMdStringValue(Element, 0); - if (!ElementText.has_value()) - return reportError(Ctx, "Invalid format for Root Element"); - - RootSignatureElementKind ElementKind = - StringSwitch(*ElementText) - .Case("RootFlags", RootSignatureElementKind::RootFlags) - .Case("RootConstants", RootSignatureElementKind::RootConstants) - .Case("RootCBV", RootSignatureElementKind::CBV) - .Case("RootSRV", RootSignatureElementKind::SRV) - .Case("RootUAV", RootSignatureElementKind::UAV) - .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable) - .Case("StaticSampler", RootSignatureElementKind::StaticSamplers) - .Default(RootSignatureElementKind::Error); - - switch (ElementKind) { - - case RootSignatureElementKind::RootFlags: - return parseRootFlags(Ctx, RSD, Element); - case RootSignatureElementKind::RootConstants: - return parseRootConstants(Ctx, RSD, Element); - case RootSignatureElementKind::CBV: - case RootSignatureElementKind::SRV: - case RootSignatureElementKind::UAV: - return parseRootDescriptors(Ctx, RSD, Element, ElementKind); - case RootSignatureElementKind::DescriptorTable: - return parseDescriptorTable(Ctx, RSD, Element); - case RootSignatureElementKind::StaticSamplers: - return parseStaticSampler(Ctx, RSD, Element); - case RootSignatureElementKind::Error: - return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText); - } - - llvm_unreachable("Unhandled RootSignatureElementKind enum."); -} - -static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, - MDNode *Node) { - bool HasError = false; - - // Loop through the Root Elements of the root signature. - for (const auto &Operand : Node->operands()) { - MDNode *Element = dyn_cast(Operand); - if (Element == nullptr) - return reportError(Ctx, "Missing Root Element Metadata Node."); - - HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element); - } - - return HasError; -} - -static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { - - if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) { - return reportValueError(Ctx, "Version", RSD.Version); - } - - if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) { - return reportValueError(Ctx, "RootFlags", RSD.Flags); - } - - for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { - if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Info.Header.ShaderVisibility); - - assert(dxbc::isValidParameterType(Info.Header.ParameterType) && - "Invalid value for ParameterType"); - - switch (Info.Header.ParameterType) { - - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): { - const dxbc::RTS0::v2::RootDescriptor &Descriptor = - RSD.ParametersContainer.getRootDescriptor(Info.Location); - if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", - Descriptor.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); - - if (RSD.Version > 1) { - if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) - return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags); - } - break; - } - case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { - const mcdxbc::DescriptorTable &Table = - RSD.ParametersContainer.getDescriptorTable(Info.Location); - for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) { - if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType)) - return reportValueError(Ctx, "RangeType", Range.RangeType); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace); - - if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors)) - return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors); - - if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, Range.Flags)) - return reportValueError(Ctx, "DescriptorFlag", Range.Flags); - } - break; - } - } - } - - for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) { - if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter)) - return reportValueError(Ctx, "Filter", Sampler.Filter); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU)) - return reportValueError(Ctx, "AddressU", Sampler.AddressU); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV)) - return reportValueError(Ctx, "AddressV", Sampler.AddressV); - - if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW)) - return reportValueError(Ctx, "AddressW", Sampler.AddressW); - - if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias)) - return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias); - - if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy)) - return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy); - - if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc)) - return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc); - - if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor)) - return reportValueError(Ctx, "BorderColor", Sampler.BorderColor); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD)) - return reportValueError(Ctx, "MinLOD", Sampler.MinLOD); - - if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD)) - return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD); - - if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister); - - if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace); - - if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Sampler.ShaderVisibility); - } - - return false; +static bool reportError(LLVMContext *Ctx, Twine Message, + DiagnosticSeverity Severity = DS_Error) { + Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity)); + return true; } static SmallDenseMap @@ -584,7 +127,9 @@ analyzeModule(Module &M) { // static sampler offset is calculated when writting dxcontainer. RSD.StaticSamplersOffset = 0u; - if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) { + hlsl::rootsig::MetadataParser MDParser(RootElementListNode); + + if (MDParser.ParseRootSignature(Ctx, RSD)) { return RSDMap; } diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index fc39b38258df8..254b7ff504633 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -26,17 +26,6 @@ namespace llvm { namespace dxil { -enum class RootSignatureElementKind { - Error = 0, - RootFlags = 1, - RootConstants = 2, - SRV = 3, - UAV = 4, - CBV = 5, - DescriptorTable = 6, - StaticSamplers = 7 -}; - class RootSignatureBindingInfo { private: SmallDenseMap FuncToRsMap;