diff --git a/include/swift/AST/GenericParamList.h b/include/swift/AST/GenericParamList.h index ac9730cdeb07a..48a529101cacb 100644 --- a/include/swift/AST/GenericParamList.h +++ b/include/swift/AST/GenericParamList.h @@ -203,8 +203,6 @@ class RequirementRepr { void print(ASTPrinter &Printer) const; }; -using GenericParamSource = PointerUnion; - /// GenericParamList - A list of generic parameters that is part of a generic /// function or type, along with extra requirements placed on those generic /// parameters and types derived from them. diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index ad7576b502a49..b346f1515d57d 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -459,6 +459,8 @@ struct WhereClauseOwner { SpecializeAttr *, DifferentiableAttr *> source; + WhereClauseOwner() : dc(nullptr) {} + WhereClauseOwner(GenericContext *genCtx); WhereClauseOwner(AssociatedTypeDecl *atd); @@ -480,6 +482,10 @@ struct WhereClauseOwner { return llvm::hash_value(owner.source.getOpaqueValue()); } + operator bool() const { + return dc != nullptr; + } + friend bool operator==(const WhereClauseOwner &lhs, const WhereClauseOwner &rhs) { return lhs.source.getOpaqueValue() == rhs.source.getOpaqueValue(); @@ -1437,11 +1443,12 @@ class AbstractGenericSignatureRequest : class InferredGenericSignatureRequest : public SimpleRequest, - SmallVector, - bool), + const GenericSignatureImpl *, + GenericParamList *, + WhereClauseOwner, + SmallVector, + SmallVector, + bool), RequestFlags::Cached> { public: using SimpleRequest::SimpleRequest; @@ -1452,9 +1459,10 @@ class InferredGenericSignatureRequest : // Evaluation. GenericSignature evaluate(Evaluator &evaluator, - ModuleDecl *module, + ModuleDecl *parentModule, const GenericSignatureImpl *baseSignature, - GenericParamSource paramSource, + GenericParamList *genericParams, + WhereClauseOwner whereClause, SmallVector addedRequirements, SmallVector inferenceSources, bool allowConcreteGenericParams) const; diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index fa78ec69af88f..d15a453b96797 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -131,10 +131,12 @@ SWIFT_REQUEST(TypeChecker, HasImplementationOnlyImportsRequest, SWIFT_REQUEST(TypeChecker, ModuleLibraryLevelRequest, LibraryLevel(ModuleDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest, - GenericSignature (ModuleDecl *, const GenericSignatureImpl *, - GenericParamSource, - SmallVector, - SmallVector, bool), + GenericSignature (ModuleDecl *, + const GenericSignatureImpl *, + GenericParamList *, + WhereClauseOwner, + SmallVector, + SmallVector, bool), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, DistributedModuleIsAvailableRequest, bool(ModuleDecl *), Cached, NoLocationInfo) diff --git a/include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h b/include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h index de08010d64561..209aec0fdc252 100644 --- a/include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h +++ b/include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h @@ -49,7 +49,6 @@ #define SWIFT_SILOPTIMIZER_ANALYSIS_DIFFERENTIABLEACTIVITYANALYSIS_H_ #include "swift/AST/GenericEnvironment.h" -#include "swift/AST/GenericSignatureBuilder.h" #include "swift/SIL/SILFunction.h" #include "swift/SIL/SILModule.h" #include "swift/SIL/SILValue.h" diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp index ecd154a7fcf16..bd8ad3a2b24be 100644 --- a/lib/AST/GenericSignatureBuilder.cpp +++ b/lib/AST/GenericSignatureBuilder.cpp @@ -8679,9 +8679,11 @@ AbstractGenericSignatureRequest::evaluate( GenericSignature InferredGenericSignatureRequest::evaluate( - Evaluator &evaluator, ModuleDecl *parentModule, + Evaluator &evaluator, + ModuleDecl *parentModule, const GenericSignatureImpl *parentSig, - GenericParamSource paramSource, + GenericParamList *genericParams, + WhereClauseOwner whereClause, SmallVector addedRequirements, SmallVector inferenceSources, bool allowConcreteGenericParams) const { @@ -8729,12 +8731,6 @@ InferredGenericSignatureRequest::evaluate( return false; }; - GenericParamList *genericParams = nullptr; - if (auto params = paramSource.dyn_cast()) - genericParams = params; - else - genericParams = paramSource.get()->getGenericParams(); - if (genericParams) { // Extensions never have a parent signature. if (genericParams->getOuterParameters()) @@ -8777,15 +8773,10 @@ InferredGenericSignatureRequest::evaluate( } } - if (auto *ctx = paramSource.dyn_cast()) { - // The declaration might have a trailing where clause. - if (auto *where = ctx->getTrailingWhereClause()) { - // Determine where and how to perform name lookup. - lookupDC = ctx; - - WhereClauseOwner(lookupDC, where).visitRequirements( + if (whereClause) { + lookupDC = whereClause.dc; + std::move(whereClause).visitRequirements( TypeResolutionStage::Structural, visitRequirement); - } } /// Perform any remaining requirement inference. diff --git a/lib/IRGen/GenType.h b/lib/IRGen/GenType.h index 1b485ba477d3d..e070dcef0b432 100644 --- a/lib/IRGen/GenType.h +++ b/lib/IRGen/GenType.h @@ -34,7 +34,6 @@ namespace llvm { } namespace swift { - class GenericSignatureBuilder; class ArchetypeType; class CanType; class ClassDecl; diff --git a/lib/IRGen/IRGenModule.h b/lib/IRGen/IRGenModule.h index 520ccd16c8b0c..48c1dab99bc13 100644 --- a/lib/IRGen/IRGenModule.h +++ b/lib/IRGen/IRGenModule.h @@ -85,7 +85,6 @@ namespace clang { namespace swift { class GenericSignature; - class GenericSignatureBuilder; class AssociatedConformance; class AssociatedType; class ASTContext; diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 41b6d4bbb61aa..cce750f6089d3 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -20,9 +20,9 @@ #include "swift/SILOptimizer/Differentiation/Common.h" #include "swift/AST/AnyFunctionRef.h" -#include "swift/AST/GenericSignatureBuilder.h" #include "swift/AST/Requirement.h" #include "swift/AST/SubstitutionMap.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" #include "swift/SILOptimizer/Utils/DifferentiationMangler.h" @@ -53,30 +53,26 @@ CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig, } auto &ctx = fn->getASTContext(); - GenericSignatureBuilder builder(ctx); // Add the existing generic signature. + GenericSignature baseGenericSig; int depth = 0; if (inheritGenericSig) { - if (auto genericSig = - fn->getLoweredFunctionType()->getSubstGenericSignature()) { - builder.addGenericSignature(genericSig); - depth = genericSig.getGenericParams().back()->getDepth() + 1; - } + baseGenericSig = fn->getLoweredFunctionType()->getSubstGenericSignature(); + if (baseGenericSig) + depth = baseGenericSig.getGenericParams().back()->getDepth() + 1; } // Add a new generic parameter to replace the opened existential. auto *newGenericParam = GenericTypeParamType::get(depth, 0, ctx); - - builder.addGenericParameter(newGenericParam); Requirement newRequirement(RequirementKind::Conformance, newGenericParam, openedExistential->getOpenedExistentialType()); - auto source = - GenericSignatureBuilder::FloatingRequirementSource::forAbstract(); - builder.addRequirement(newRequirement, source, nullptr); - auto genericSig = std::move(builder).computeGenericSignature( - /*allowConcreteGenericParams=*/true); + auto genericSig = evaluateOrDefault( + ctx.evaluator, + AbstractGenericSignatureRequest{ + baseGenericSig.getPointer(), { newGenericParam }, { newRequirement }}, + GenericSignature()); genericEnv = genericSig.getGenericEnvironment(); newArchetype = diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 4a357e07f755c..abb901f8e6fb7 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -25,7 +25,6 @@ #include "swift/AST/DiagnosticsParse.h" #include "swift/AST/Effects.h" #include "swift/AST/GenericEnvironment.h" -#include "swift/AST/GenericSignatureBuilder.h" #include "swift/AST/ImportCache.h" #include "swift/AST/ModuleNameLookup.h" #include "swift/AST/NameLookup.h" @@ -2231,28 +2230,17 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) { return; } - // Form a new generic signature based on the old one. - GenericSignatureBuilder Builder(D->getASTContext()); + InferredGenericSignatureRequest request{ + DC->getParentModule(), + genericSig.getPointer(), + /*genericParams=*/nullptr, + WhereClauseOwner(FD, attr), + /*addedRequirements=*/{}, + /*inferenceSources=*/{}, + /*allowConcreteGenericParams=*/true}; - // First, add the old generic signature. - Builder.addGenericSignature(genericSig); - - // Go over the set of requirements, adding them to the builder. - WhereClauseOwner(FD, attr).visitRequirements(TypeResolutionStage::Interface, - [&](const Requirement &req, RequirementRepr *reqRepr) { - // Add the requirement to the generic signature builder. - using FloatingRequirementSource = - GenericSignatureBuilder::FloatingRequirementSource; - Builder.addRequirement(req, reqRepr, - FloatingRequirementSource::forExplicit( - reqRepr->getSeparatorLoc()), - nullptr, DC->getParentModule()); - return false; - }); - - // Check the result. - auto specializedSig = std::move(Builder).computeGenericSignature( - /*allowConcreteGenericParams=*/true); + auto specializedSig = evaluateOrDefault(Ctx.evaluator, request, + GenericSignature()); // Check the validity of provided requirements. checkSpecializeAttrRequirements(attr, genericSig, specializedSig, Ctx); @@ -4266,7 +4254,8 @@ bool resolveDifferentiableAttrDerivativeGenericSignature( // - If the `@differentiable` attribute has a `where` clause, use it to // compute the derivative generic signature. // - Otherwise, use the original function's generic signature by default. - derivativeGenSig = original->getGenericSignature(); + auto originalGenSig = original->getGenericSignature(); + derivativeGenSig = originalGenSig; // Handle the `where` clause, if it exists. // - Resolve attribute where clause requirements and store in the attribute @@ -4291,7 +4280,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature( return true; } - auto originalGenSig = original->getGenericSignature(); if (!originalGenSig) { // `where` clauses are valid only when the original function is generic. diags @@ -4304,51 +4292,34 @@ bool resolveDifferentiableAttrDerivativeGenericSignature( return true; } - // Build a new generic signature for autodiff derivative functions. - GenericSignatureBuilder builder(ctx); - // Add the original function's generic signature. - builder.addGenericSignature(originalGenSig); - - using FloatingRequirementSource = - GenericSignatureBuilder::FloatingRequirementSource; - - bool errorOccurred = false; - WhereClauseOwner(original, attr) - .visitRequirements( - TypeResolutionStage::Structural, - [&](const Requirement &req, RequirementRepr *reqRepr) { - switch (req.getKind()) { - case RequirementKind::SameType: - case RequirementKind::Superclass: - case RequirementKind::Conformance: - break; - - // Layout requirements are not supported. - case RequirementKind::Layout: - diags - .diagnose(attr->getLocation(), - diag::differentiable_attr_layout_req_unsupported) - .highlight(reqRepr->getSourceRange()); - errorOccurred = true; - return false; - } + InferredGenericSignatureRequest request{ + original->getParentModule(), + originalGenSig.getPointer(), + /*genericParams=*/nullptr, + WhereClauseOwner(original, attr), + /*addedRequirements=*/{}, + /*inferenceSources=*/{}, + /*allowConcreteParams=*/true}; + + // Compute generic signature for derivative functions. + derivativeGenSig = evaluateOrDefault(ctx.evaluator, request, + GenericSignature()); - // Add requirement to generic signature builder. - builder.addRequirement( - req, reqRepr, FloatingRequirementSource::forExplicit( - reqRepr->getSeparatorLoc()), - nullptr, original->getModuleContext()); - return false; - }); + bool hadInvalidRequirements = false; + for (auto req : derivativeGenSig.requirementsNotSatisfiedBy(originalGenSig)) { + if (req.getKind() == RequirementKind::Layout) { + // Layout requirements are not supported. + diags + .diagnose(attr->getLocation(), + diag::differentiable_attr_layout_req_unsupported); + hadInvalidRequirements = true; + } + } - if (errorOccurred) { + if (hadInvalidRequirements) { attr->setInvalid(); return true; } - - // Compute generic signature for derivative functions. - derivativeGenSig = std::move(builder).computeGenericSignature( - /*allowConcreteGenericParams=*/true); } attr->setDerivativeGenericSignature(derivativeGenSig); diff --git a/lib/Sema/TypeCheckGeneric.cpp b/lib/Sema/TypeCheckGeneric.cpp index 19f034be6693a..9ce364424d8eb 100644 --- a/lib/Sema/TypeCheckGeneric.cpp +++ b/lib/Sema/TypeCheckGeneric.cpp @@ -20,7 +20,6 @@ #include "swift/AST/DiagnosticsSema.h" #include "swift/AST/ExistentialLayout.h" #include "swift/AST/GenericEnvironment.h" -#include "swift/AST/GenericSignatureBuilder.h" #include "swift/AST/ParameterList.h" #include "swift/AST/ProtocolConformance.h" #include "swift/AST/TypeCheckRequests.h" @@ -517,44 +516,6 @@ void TypeChecker::checkReferencedGenericParams(GenericContext *dc) { /// Generic types /// -GenericSignature TypeChecker::checkGenericSignature( - GenericParamSource paramSource, - DeclContext *dc, - GenericSignature parentSig, - bool allowConcreteGenericParams, - SmallVector additionalRequirements, - SmallVector inferenceSources) { - if (auto genericParamList = paramSource.dyn_cast()) - assert(genericParamList && "Missing generic parameters?"); - - auto request = InferredGenericSignatureRequest{ - dc->getParentModule(), parentSig.getPointer(), paramSource, - additionalRequirements, inferenceSources, - allowConcreteGenericParams}; - auto sig = evaluateOrDefault(dc->getASTContext().evaluator, - request, nullptr); - - // Debugging of the generic signature builder and generic signature - // generation. - if (dc->getASTContext().TypeCheckerOpts.DebugGenericSignatures) { - llvm::errs() << "\n"; - if (auto *VD = dyn_cast_or_null(dc->getAsDecl())) { - VD->dumpRef(llvm::errs()); - llvm::errs() << "\n"; - } else { - dc->printContext(llvm::errs()); - } - llvm::errs() << "Generic signature: "; - sig->print(llvm::errs()); - llvm::errs() << "\n"; - llvm::errs() << "Canonical generic signature: "; - sig.getCanonicalSignature()->print(llvm::errs()); - llvm::errs() << "\n"; - } - - return sig; -} - /// Form the interface type of an extension from the raw type and the /// extension's list of generic parameters. static Type formExtensionInterfaceType( @@ -649,6 +610,8 @@ static unsigned getExtendedTypeGenericDepth(ExtensionDecl *ext) { GenericSignature GenericSignatureRequest::evaluate(Evaluator &evaluator, GenericContext *GC) const { + auto &ctx = GC->getASTContext(); + // The signature of a Protocol is trivial (Self: TheProtocol) so let's compute // it. if (auto PD = dyn_cast(GC)) { @@ -660,7 +623,7 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator, // Debugging of the generic signature builder and generic signature // generation. - if (GC->getASTContext().TypeCheckerOpts.DebugGenericSignatures) { + if (ctx.TypeCheckerOpts.DebugGenericSignatures) { llvm::errs() << "\n"; PD->printContext(llvm::errs()); llvm::errs() << "Generic signature: "; @@ -698,10 +661,10 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator, // If there is no generic context for the where clause to // rely on, diagnose that now and bail out. if (!GC->isGenericContext()) { - GC->getASTContext().Diags.diagnose(where->getWhereLoc(), - GC->getParent()->isModuleScopeContext() - ? diag::where_nongeneric_toplevel - : diag::where_nongeneric_ctx); + ctx.Diags.diagnose(where->getWhereLoc(), + GC->getParent()->isModuleScopeContext() + ? diag::where_nongeneric_toplevel + : diag::where_nongeneric_ctx); return nullptr; } } @@ -806,10 +769,33 @@ GenericSignatureRequest::evaluate(Evaluator &evaluator, inferenceSources.emplace_back(nullptr, extInterfaceType); } - return TypeChecker::checkGenericSignature( - GC, GC, parentSig, - allowConcreteGenericParams, - sameTypeReqs, inferenceSources); + auto request = InferredGenericSignatureRequest{ + GC->getParentModule(), parentSig.getPointer(), + GC->getGenericParams(), WhereClauseOwner(GC), + sameTypeReqs, inferenceSources, + allowConcreteGenericParams}; + auto sig = evaluateOrDefault(ctx.evaluator, + request, nullptr); + + // Debugging of the generic signature builder and generic signature + // generation. + if (ctx.TypeCheckerOpts.DebugGenericSignatures) { + llvm::errs() << "\n"; + if (auto *VD = dyn_cast_or_null(GC->getAsDecl())) { + VD->dumpRef(llvm::errs()); + llvm::errs() << "\n"; + } else { + GC->printContext(llvm::errs()); + } + llvm::errs() << "Generic signature: "; + sig->print(llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "Canonical generic signature: "; + sig.getCanonicalSignature()->print(llvm::errs()); + llvm::errs() << "\n"; + } + + return sig; } /// diff --git a/lib/Sema/TypeChecker.cpp b/lib/Sema/TypeChecker.cpp index 878c036f9be3d..01c3d3d062aef 100644 --- a/lib/Sema/TypeChecker.cpp +++ b/lib/Sema/TypeChecker.cpp @@ -437,10 +437,14 @@ swift::handleSILGenericParams(GenericParamList *genericParams, genericParams->walk(walker); } - return TypeChecker::checkGenericSignature(nestedList.back(), DC, - /*parentSig=*/nullptr, - /*allowConcreteGenericParams=*/true) - .getGenericEnvironment(); + auto request = InferredGenericSignatureRequest{ + DC->getParentModule(), /*parentSig=*/nullptr, + nestedList.back(), WhereClauseOwner(), + {}, {}, /*allowConcreteGenericParams=*/true}; + auto sig = evaluateOrDefault(DC->getASTContext().evaluator, + request, GenericSignature()); + + return sig.getGenericEnvironment(); } void swift::typeCheckPatternBinding(PatternBindingDecl *PBD, diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 043bf2abf4c87..f5af2994a762c 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -422,34 +422,6 @@ void checkProtocolSelfRequirements(ValueDecl *decl); /// declaration's type, otherwise we have no way to infer them. void checkReferencedGenericParams(GenericContext *dc); -/// Construct a new generic environment for the given declaration context. -/// -/// \param paramSource The source of generic info: either a generic parameter -/// list or a generic context with a \c where clause dependent on outer -/// generic parameters. -/// -/// \param dc The declaration context in which to perform the validation. -/// -/// \param outerSignature The generic signature of the outer -/// context, if not available as part of the \c dc argument (used -/// for SIL parsing). -/// -/// \param allowConcreteGenericParams Whether or not to allow -/// same-type constraints between generic parameters and concrete types. -/// -/// \param additionalRequirements Additional requirements to add -/// directly to the GSB. -/// -/// \param inferenceSources Additional types to infer requirements from. -/// -/// \returns the resulting generic signature. -GenericSignature -checkGenericSignature(GenericParamSource paramSource, DeclContext *dc, - GenericSignature outerSignature, - bool allowConcreteGenericParams, - SmallVector additionalRequirements = {}, - SmallVector inferenceSources = {}); - /// Create a text string that describes the bindings of generic parameters /// that are relevant to the given set of types, e.g., /// "[with T = Bar, U = Wibble]". diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 0ea9361946418..5aed01126a500 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -191,6 +191,7 @@ func invalidRequirementConformance(x: Scalar) -> Scalar { return x } +// expected-error @+1 {{'@differentiable' attribute does not yet support layout requirements}} @differentiable(reverse where T: AnyObject) func invalidAnyObjectRequirement(x: T) -> T { return x