diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 81294f101e24f..ec0575ba3ecb3 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -741,6 +741,19 @@ class ASTContext final { unsigned previousGeneration, llvm::TinyPtrVector &methods); + /// Load derivative function configurations for the given + /// AbstractFunctionDecl. + /// + /// \param originalAFD The declaration whose derivative function + /// configurations should be loaded. + /// + /// \param previousGeneration The previous generation number. The AST already + /// contains derivative function configurations loaded from any generation up + /// to and including this one. + void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results); + /// Retrieve the Clang module loader for this ASTContext. /// /// If there is no Clang module loader, returns a null pointer. diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index e3f3df4db22d7..e62d699f86d02 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -5796,6 +5796,7 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { private: ParameterList *Params; +private: /// The generation at which we last loaded derivative function configurations. unsigned DerivativeFunctionConfigGeneration = 0; /// Prepare to traverse the list of derivative function configurations. @@ -5810,6 +5811,13 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { struct DerivativeFunctionConfigurationList; DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr; +public: + /// Get all derivative function configurations. + ArrayRef getDerivativeFunctionConfigurations(); + + /// Add the given derivative function configuration. + void addDerivativeFunctionConfiguration(AutoDiffConfig config); + protected: // If a function has a body at all, we have either a parsed body AST node or // we have saved the end location of the unparsed body. @@ -6129,12 +6137,6 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { /// constructor. bool hasDynamicSelfResult() const; - /// Get all derivative function configurations. - ArrayRef getDerivativeFunctionConfigurations(); - - /// Add the given derivative function configuration. - void addDerivativeFunctionConfiguration(AutoDiffConfig config); - using DeclContext::operator new; using Decl::getASTContext; }; diff --git a/include/swift/AST/ModuleLoader.h b/include/swift/AST/ModuleLoader.h index 55b6008cf68ce..8e13dae85fdaf 100644 --- a/include/swift/AST/ModuleLoader.h +++ b/include/swift/AST/ModuleLoader.h @@ -36,6 +36,7 @@ class DependencyCollector; namespace swift { class AbstractFunctionDecl; +struct AutoDiffConfig; class ClangImporterOptions; class ClassDecl; class FileUnit; @@ -153,6 +154,23 @@ class ModuleLoader { unsigned previousGeneration, llvm::TinyPtrVector &methods) = 0; + /// Load derivative function configurations for the given + /// AbstractFunctionDecl. + /// + /// \param originalAFD The declaration whose derivative function + /// configurations should be loaded. + /// + /// \param previousGeneration The previous generation number. The AST already + /// contains derivative function configurations loaded from any generation up + /// to and including this one. + /// + /// \param results The result list of derivative function configurations. + /// This list will be extended with any methods found in subsequent + /// generations. + virtual void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results) {}; + /// Verify all modules loaded by this loader. virtual void verifyAllModules() { } diff --git a/include/swift/Serialization/SerializedModuleLoader.h b/include/swift/Serialization/SerializedModuleLoader.h index b508bb08dfd56..d065e7236aa1d 100644 --- a/include/swift/Serialization/SerializedModuleLoader.h +++ b/include/swift/Serialization/SerializedModuleLoader.h @@ -184,6 +184,10 @@ class SerializedModuleLoaderBase : public ModuleLoader { unsigned previousGeneration, llvm::TinyPtrVector &methods) override; + virtual void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results) override; + virtual void verifyAllModules() override; }; diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 231eec4c32a5a..037e08199fedb 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -1471,6 +1471,17 @@ void ASTContext::loadObjCMethods( } } +void ASTContext::loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned previousGeneration, + llvm::SetVector &results) { + PrettyStackTraceDecl stackTrace( + "loading derivative function configurations for", originalAFD); + for (auto &loader : getImpl().ModuleLoaders) { + loader->loadDerivativeFunctionConfigurations(originalAFD, + previousGeneration, results); + } +} + void ASTContext::verifyAllLoadedModules() const { #ifndef NDEBUG FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules"); diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 597308fcdeb81..11c043dd6b7d8 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -7099,8 +7099,10 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() { prepareDerivativeFunctionConfigurations(); auto &ctx = getASTContext(); if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) { - // TODO(TF-1100): Upstream derivative function configuration serialization - // logic. + unsigned previousGeneration = DerivativeFunctionConfigGeneration; + DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration(); + ctx.loadDerivativeFunctionConfigurations(this, previousGeneration, + *DerivativeFunctionConfigs); } return DerivativeFunctionConfigs->getArrayRef(); } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 1aa1385fb6b1b..b447050014ea9 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3970,6 +3970,10 @@ llvm::Expected DifferentiableAttributeTypeCheckRequest::evaluate( return nullptr; } getterDecl->getAttrs().add(newAttr); + // Register derivative function configuration. + auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + getterDecl->addDerivativeFunctionConfiguration( + {resolvedDiffParamIndices, resultIndices, derivativeGenSig}); return resolvedDiffParamIndices; } // Reject duplicate `@differentiable` attributes. @@ -4341,6 +4345,12 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, return true; } + // Register derivative function configuration. + auto *resultIndices = IndexSubset::get(Ctx, 1, {0}); + originalAFD->addDerivativeFunctionConfiguration( + {resolvedDiffParamIndices, resultIndices, + derivative->getGenericSignature()}); + return false; } diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 175652ab3090a..059165ae6e5b8 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -438,6 +438,11 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, } else { witness->getAttrs().add(newAttr); success = true; + // Register derivative function configuration. + auto *resultIndices = IndexSubset::get(ctx, 1, {0}); + witnessAFD->addDerivativeFunctionConfiguration( + {newAttr->getParameterIndices(), resultIndices, + newAttr->getDerivativeGenericSignature()}); } } if (!success) { diff --git a/lib/Serialization/DeclTypeRecordNodes.def b/lib/Serialization/DeclTypeRecordNodes.def index 3bb2f0bc6fcee..dd014fdbfa085 100644 --- a/lib/Serialization/DeclTypeRecordNodes.def +++ b/lib/Serialization/DeclTypeRecordNodes.def @@ -192,6 +192,8 @@ OTHER(XREF_OPAQUE_RETURN_TYPE_PATH_PIECE, 252) OTHER(CLANG_TYPE, 253) +OTHER(DERIVATIVE_FUNCTION_CONFIGURATION, 254) + #undef RECORD #undef DECLTYPERECORDNODES_HAS_RECORD_VAL #undef RECORD_VAL diff --git a/lib/Serialization/ModuleFile.cpp b/lib/Serialization/ModuleFile.cpp index b66a76d892039..0202670f10895 100644 --- a/lib/Serialization/ModuleFile.cpp +++ b/lib/Serialization/ModuleFile.cpp @@ -910,6 +910,66 @@ ModuleFile::readObjCMethodTable(ArrayRef fields, StringRef blobData) { base + sizeof(uint32_t), base)); } +/// Used to deserialize entries in the on-disk derivative function configuration +/// table. +class ModuleFile::DerivativeFunctionConfigTableInfo { +public: + using internal_key_type = StringRef; + using external_key_type = internal_key_type; + using data_type = SmallVector, 8>; + using hash_value_type = uint32_t; + using offset_type = unsigned; + + external_key_type GetExternalKey(internal_key_type ID) { return ID; } + + internal_key_type GetInternalKey(external_key_type ID) { return ID; } + + hash_value_type ComputeHash(internal_key_type key) { + return llvm::djbHash(key, SWIFTMODULE_HASH_SEED); + } + + static bool EqualKey(internal_key_type lhs, internal_key_type rhs) { + return lhs == rhs; + } + + static std::pair ReadKeyDataLength(const uint8_t *&data) { + unsigned keyLength = endian::readNext(data); + unsigned dataLength = endian::readNext(data); + return {keyLength, dataLength}; + } + + static internal_key_type ReadKey(const uint8_t *data, unsigned length) { + return StringRef(reinterpret_cast(data), length); + } + + static data_type ReadData(internal_key_type key, const uint8_t *data, + unsigned length) { + data_type result; + const uint8_t *limit = data + length; + while (data < limit) { + DeclID genSigId = endian::readNext(data); + int32_t nameLength = endian::readNext(data); + StringRef mangledName(reinterpret_cast(data), nameLength); + data += nameLength; + result.push_back({mangledName, genSigId}); + } + return result; + } +}; + +std::unique_ptr +ModuleFile::readDerivativeFunctionConfigTable(ArrayRef fields, + StringRef blobData) { + uint32_t tableOffset; + index_block::DerivativeFunctionConfigTableLayout::readRecord(fields, + tableOffset); + auto base = reinterpret_cast(blobData.data()); + + using OwnedTable = std::unique_ptr; + return OwnedTable(SerializedDerivativeFunctionConfigTable::Create( + base + tableOffset, base + sizeof(uint32_t), base)); +} + bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) { if (llvm::Error Err = cursor.EnterSubBlock(INDEX_BLOCK_ID)) { // FIXME this drops the error on the floor. @@ -1015,6 +1075,10 @@ bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) { case index_block::OBJC_METHODS: ObjCMethods = readObjCMethodTable(scratch, blobData); break; + case index_block::DERIVATIVE_FUNCTION_CONFIGURATIONS: + DerivativeFunctionConfigurations = + readDerivativeFunctionConfigTable(scratch, blobData); + break; case index_block::ENTRY_POINT: assert(blobData.empty()); setEntryPointClassID(scratch.front()); @@ -2405,6 +2469,34 @@ void ModuleFile::loadObjCMethods( } } +void ModuleFile::loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, + llvm::SetVector &results) { + if (!DerivativeFunctionConfigurations) + return; + auto &ctx = originalAFD->getASTContext(); + Mangle::ASTMangler Mangler; + auto mangledName = Mangler.mangleDeclAsUSR(originalAFD, ""); + auto configs = DerivativeFunctionConfigurations->find(mangledName); + if (configs == DerivativeFunctionConfigurations->end()) + return; + for (auto entry : *configs) { + auto *parameterIndices = IndexSubset::getFromString(ctx, entry.first); + auto derivativeGenSigOrError = getGenericSignatureChecked(entry.second); + if (!derivativeGenSigOrError) { + if (!getContext().LangOpts.EnableDeserializationRecovery) + fatal(derivativeGenSigOrError.takeError()); + llvm::consumeError(derivativeGenSigOrError.takeError()); + } + auto derivativeGenSig = derivativeGenSigOrError.get(); + // NOTE(TF-1038): Result indices are currently unsupported in derivative + // registration attributes. In the meantime, always use `{0}` (wrt the + // first and only result). + auto resultIndices = IndexSubset::get(ctx, 1, {0}); + results.insert({parameterIndices, resultIndices, derivativeGenSig}); + } +} + TinyPtrVector ModuleFile::loadNamedMembers(const IterableDeclContext *IDC, DeclBaseName N, uint64_t contextData) { diff --git a/lib/Serialization/ModuleFile.h b/lib/Serialization/ModuleFile.h index c52d8452ba428..b594c524a7745 100644 --- a/lib/Serialization/ModuleFile.h +++ b/lib/Serialization/ModuleFile.h @@ -417,6 +417,12 @@ class ModuleFile llvm::OnDiskIterableChainedHashTable; std::unique_ptr DeclUSRsTable; + class DerivativeFunctionConfigTableInfo; + using SerializedDerivativeFunctionConfigTable = + llvm::OnDiskIterableChainedHashTable; + std::unique_ptr + DerivativeFunctionConfigurations; + /// A blob of 0 terminated string segments referenced in \c SourceLocsTextData StringRef SourceLocsTextData; @@ -550,6 +556,12 @@ class ModuleFile std::unique_ptr readDeclMembersTable(ArrayRef fields, StringRef blobData); + /// Read an on-disk derivative function configuration table stored in + /// index_block::DerivativeFunctionConfigTableLayout format. + std::unique_ptr + readDerivativeFunctionConfigTable(ArrayRef fields, + StringRef blobData); + /// Reads the index block, which contains global tables. /// /// Returns false if there was an error. @@ -774,6 +786,12 @@ class ModuleFile bool isInstanceMethod, llvm::TinyPtrVector &methods); + /// Loads all derivative function configurations for the given + /// AbstractFunctionDecl. + void loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, + llvm::SetVector &results); + /// Reports all class members in the module to the given consumer. /// /// This is intended for use with id-style lookup and code completion. diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index a26403ac9db14..14db3010597cc 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 550; // linear_function, linear_function_extract +const uint16_t SWIFTMODULE_VERSION_MINOR = 551; // derivative function configurations /// A standard hash seed used for all string hashes in a serialized module. /// @@ -1934,6 +1934,10 @@ namespace index_block { /// produce Objective-C methods. OBJC_METHODS, + /// The derivative function configuration table, which maps original + /// function declaration names to derivative function configurations. + DERIVATIVE_FUNCTION_CONFIGURATIONS, + ENTRY_POINT, LOCAL_DECL_CONTEXT_OFFSETS, LOCAL_TYPE_DECLS, @@ -1998,6 +2002,12 @@ namespace index_block { BCBlob // map from member DeclBaseNames to offsets of DECL_MEMBERS records >; + using DerivativeFunctionConfigTableLayout = BCRecordLayout< + DERIVATIVE_FUNCTION_CONFIGURATIONS, // record ID + BCVBR<16>, // table offset within the blob (see below) + BCBlob // map from original declaration names to derivative configs + >; + using EntryPointLayout = BCRecordLayout< ENTRY_POINT, DeclIDField // the ID of the main class; 0 if there was a main source file diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index a091fafb0ea35..a24206eaa28df 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -770,6 +770,7 @@ void Serializer::writeBlockInfoBlock() { BLOCK_RECORD(index_block, CLASS_MEMBERS_FOR_DYNAMIC_LOOKUP); BLOCK_RECORD(index_block, OPERATOR_METHODS); BLOCK_RECORD(index_block, OBJC_METHODS); + BLOCK_RECORD(index_block, DERIVATIVE_FUNCTION_CONFIGURATIONS); BLOCK_RECORD(index_block, ENTRY_POINT); BLOCK_RECORD(index_block, LOCAL_DECL_CONTEXT_OFFSETS); BLOCK_RECORD(index_block, GENERIC_SIGNATURE_OFFSETS); @@ -4781,6 +4782,98 @@ static void writeObjCMethodTable(const index_block::ObjCMethodTableLayout &out, out.emit(scratch, tableOffset, hashTableBlob); } +namespace { + /// Used to serialize derivative function configurations. + class DerivativeFunctionConfigTableInfo { + public: + using key_type = std::string; + using key_type_ref = StringRef; + using data_type = Serializer::DerivativeFunctionConfigTableData; + using data_type_ref = const data_type &; + using hash_value_type = uint32_t; + using offset_type = unsigned; + + hash_value_type ComputeHash(key_type_ref key) { + assert(!key.empty()); + return llvm::djbHash(key, SWIFTMODULE_HASH_SEED); + } + + std::pair EmitKeyDataLength(raw_ostream &out, + key_type_ref key, + data_type_ref data) { + uint32_t keyLength = key.str().size(); + assert(keyLength == static_cast(keyLength)); + uint32_t dataLength = (sizeof(uint32_t) * 2) * data.size(); + for (auto entry : data) + dataLength += entry.first.size(); + assert(dataLength == static_cast(dataLength)); + endian::Writer writer(out, little); + writer.write(keyLength); + writer.write(dataLength); + return { keyLength, dataLength }; + } + + void EmitKey(raw_ostream &out, key_type_ref key, unsigned len) { + out << key; + } + + void EmitData(raw_ostream &out, key_type_ref key, data_type_ref data, + unsigned len) { + static_assert(declIDFitsIn32Bits(), "DeclID too large"); + endian::Writer writer(out, little); + for (auto &entry : data) { + // Write `GenericSignatureID`. + writer.write(entry.second); + // Write parameter indices string size, followed by data. + writer.write(entry.first.size()); + out << entry.first; + } + } + }; +} // end anonymous namespace + +static void writeDerivativeFunctionConfigs( + Serializer &S, const index_block::DerivativeFunctionConfigTableLayout &out, + Serializer::DerivativeFunctionConfigTable &derivativeConfigs) { + // Create the on-disk hash table. + llvm::OnDiskChainedHashTableGenerator + generator; + llvm::SmallString<32> hashTableBlob; + uint32_t tableOffset; + { + llvm::raw_svector_ostream blobStream(hashTableBlob); + for (auto &entry : derivativeConfigs) + generator.insert(entry.first.get(), entry.second); + // Make sure that no bucket is at offset 0. + endian::write(blobStream, 0, little); + tableOffset = generator.Emit(blobStream); + } + SmallVector scratch; + out.emit(scratch, tableOffset, hashTableBlob); +} + +// Records derivative function configurations for the given AbstractFunctionDecl +// by visiting `@differentiable` and `@derivative` attributes. +static void recordDerivativeFunctionConfig( + Serializer &S, const AbstractFunctionDecl *AFD, + Serializer::UniquedDerivativeFunctionConfigTable &derivativeConfigs) { + auto &ctx = AFD->getASTContext(); + Mangle::ASTMangler Mangler; + for (auto *attr : AFD->getAttrs().getAttributes()) { + auto mangledName = ctx.getIdentifier(Mangler.mangleDeclAsUSR(AFD, "")); + derivativeConfigs[mangledName].insert( + {ctx.getIdentifier(attr->getParameterIndices()->getString()), + attr->getDerivativeGenericSignature()}); + } + for (auto *attr : AFD->getAttrs().getAttributes()) { + auto *origAFD = attr->getOriginalFunction(); + auto mangledName = ctx.getIdentifier(Mangler.mangleDeclAsUSR(origAFD, "")); + derivativeConfigs[mangledName].insert( + {ctx.getIdentifier(attr->getParameterIndices()->getString()), + AFD->getGenericSignature()}); + } +}; + /// Recursively walks the members and derived global decls of any nominal types /// to build up global tables. template @@ -4790,6 +4883,7 @@ static void collectInterestingNestedDeclarations( Serializer::DeclTable &operatorMethodDecls, Serializer::ObjCMethodTable &objcMethods, Serializer::NestedTypeDeclsTable &nestedTypeDecls, + Serializer::UniquedDerivativeFunctionConfigTable &derivativeConfigs, bool isLocal = false) { const NominalTypeDecl *nominalParent = nullptr; @@ -4826,14 +4920,17 @@ static void collectInterestingNestedDeclarations( } } - // Record Objective-C methods. - if (auto *func = dyn_cast(member)) + // Record Objective-C methods and derivative function configurations. + if (auto *func = dyn_cast(member)) { recordObjCMethod(func); + recordDerivativeFunctionConfig(S, func, derivativeConfigs); + } // Handle accessors. if (auto storage = dyn_cast(member)) { for (auto *accessor : storage->getAllAccessors()) { recordObjCMethod(accessor); + recordDerivativeFunctionConfig(S, accessor, derivativeConfigs); } } @@ -4856,6 +4953,7 @@ static void collectInterestingNestedDeclarations( collectInterestingNestedDeclarations(S, iterable->getMembers(), operatorMethodDecls, objcMethods, nestedTypeDecls, + derivativeConfigs, isLocal); } } @@ -4868,6 +4966,7 @@ void Serializer::writeAST(ModuleOrSourceFile DC) { NestedTypeDeclsTable nestedTypeDecls; LocalTypeHashTableGenerator localTypeGenerator, opaqueReturnTypeGenerator; ExtensionTable extensionDecls; + UniquedDerivativeFunctionConfigTable uniquedDerivativeConfigs; bool hasLocalTypes = false; bool hasOpaqueReturnTypes = false; @@ -4916,6 +5015,8 @@ void Serializer::writeAST(ModuleOrSourceFile DC) { } else { llvm_unreachable("all top-level declaration kinds accounted for"); } + if (auto *AFD = dyn_cast(D)) + recordDerivativeFunctionConfig(*this, AFD, uniquedDerivativeConfigs); orderedTopLevelDecls.push_back(addDeclRef(D)); @@ -4925,7 +5026,8 @@ void Serializer::writeAST(ModuleOrSourceFile DC) { if (auto IDC = dyn_cast(D)) { collectInterestingNestedDeclarations(*this, IDC->getMembers(), operatorMethodDecls, objcMethods, - nestedTypeDecls); + nestedTypeDecls, + uniquedDerivativeConfigs); } } @@ -4954,7 +5056,9 @@ void Serializer::writeAST(ModuleOrSourceFile DC) { if (auto IDC = dyn_cast(TD)) { collectInterestingNestedDeclarations(*this, IDC->getMembers(), operatorMethodDecls, objcMethods, - nestedTypeDecls, /*isLocal=*/true); + nestedTypeDecls, + uniquedDerivativeConfigs, + /*isLocal=*/true); } } @@ -5017,6 +5121,20 @@ void Serializer::writeAST(ModuleOrSourceFile DC) { writeNestedTypeDeclsTable(NestedTypeDeclsTable, nestedTypeDecls); } + // Convert uniqued derivative function config table to serialization- + // ready format: turn `GenericSignature` to `GenericSignatureID`. + DerivativeFunctionConfigTable derivativeConfigs; + for (auto entry : uniquedDerivativeConfigs) { + for (auto config : entry.second) { + auto paramIndices = config.first.str(); + auto genSigID = addGenericSignatureRef(config.second); + derivativeConfigs[entry.first].push_back({paramIndices, genSigID}); + } + } + index_block::DerivativeFunctionConfigTableLayout DerivativeConfigTable(Out); + writeDerivativeFunctionConfigs(*this, DerivativeConfigTable, + derivativeConfigs); + if (entryPointClassID.hasValue()) { index_block::EntryPointLayout EntryPoint(Out); EntryPoint.emit(ScratchRecord, entryPointClassID.getValue()); diff --git a/lib/Serialization/Serialization.h b/lib/Serialization/Serialization.h index a570ca379895a..8529620ae9957 100644 --- a/lib/Serialization/Serialization.h +++ b/lib/Serialization/Serialization.h @@ -256,6 +256,23 @@ class Serializer : public SerializerBase { SmallVector, 4>; using ExtensionTable = llvm::MapVector; + using DerivativeFunctionConfigTableData = + llvm::SmallVector, 4>; + // In-memory representation of what will eventually be an on-disk hash table + // mapping original declaration USRs to derivative function configurations. + using DerivativeFunctionConfigTable = + llvm::MapVector; + // Uniqued mapping from original declarations USRs to derivative function + // configurations. + // Note: this exists because `GenericSignature` can be used as a `DenseMap` + // key, while `GenericSignatureID` cannot + // (`DenseMapInfo::getEmptyKey()` crashes). To work + // around this, a `UniquedDerivativeFunctionConfigTable` is first + // constructed, and then converted to a `DerivativeFunctionConfigTableData`. + using UniquedDerivativeFunctionConfigTable = llvm::MapVector< + Identifier, + llvm::SmallSetVector, 4>>; + private: /// A map from identifiers to methods and properties with the given name. /// diff --git a/lib/Serialization/SerializedModuleLoader.cpp b/lib/Serialization/SerializedModuleLoader.cpp index 64081f34d4edf..33ca6fe7d5d9b 100644 --- a/lib/Serialization/SerializedModuleLoader.cpp +++ b/lib/Serialization/SerializedModuleLoader.cpp @@ -993,6 +993,17 @@ void SerializedModuleLoaderBase::loadObjCMethods( } } +void SerializedModuleLoaderBase::loadDerivativeFunctionConfigurations( + AbstractFunctionDecl *originalAFD, unsigned int previousGeneration, + llvm::SetVector &results) { + for (auto &modulePair : LoadedModuleFiles) { + if (modulePair.second <= previousGeneration) + continue; + modulePair.first->loadDerivativeFunctionConfigurations(originalAFD, + results); + } +} + std::error_code MemoryBufferSerializedModuleLoader::findModuleFilesInDirectory( AccessPathElem ModuleID, const SerializedModuleBaseName &BaseName,