From ba9bfaf5d55cef90db981d23de07b69e45baf3fb Mon Sep 17 00:00:00 2001 From: Erik Eckstein Date: Mon, 12 Aug 2024 18:03:49 +0200 Subject: [PATCH] VTableSpecializer: fix a crash for methods which have their own generic parameters rdar://133334324 --- .../Transforms/VTableSpecializer.cpp | 34 ++++++++++++- .../classes-non-final-method-no-stdlib.swift | 20 ++++++++ test/embedded/generic-classes.swift | 51 +++++++++++++++++++ 3 files changed, 104 insertions(+), 1 deletion(-) diff --git a/lib/SILOptimizer/Transforms/VTableSpecializer.cpp b/lib/SILOptimizer/Transforms/VTableSpecializer.cpp index e2913b4d6f055..d5db8c63ed0a6 100644 --- a/lib/SILOptimizer/Transforms/VTableSpecializer.cpp +++ b/lib/SILOptimizer/Transforms/VTableSpecializer.cpp @@ -150,6 +150,26 @@ static bool specializeVTablesOfSuperclasses(SILModule &module, return changed; } +static SubstitutionMap getMethodSubs(SILFunction *method, SubstitutionMap classContextSubs) { + GenericSignature genericSig = + method->getLoweredFunctionType()->getInvocationGenericSignature(); + + if (!genericSig || genericSig->areAllParamsConcrete()) + return SubstitutionMap(); + + return SubstitutionMap::get(genericSig, + QuerySubstitutionMap{classContextSubs}, + LookUpConformanceInModule()); +} + +static bool hasInvalidConformance(SubstitutionMap subs) { + for (auto substConf : subs.getConformances()) { + if (substConf.isInvalid()) + return true; + } + return false; +} + SILVTable *swift::specializeVTableForType(SILType classTy, SILModule &module, SILTransform *transform) { CanType astType = classTy.getASTType(); @@ -182,8 +202,20 @@ SILVTable *swift::specializeVTableForType(SILType classTy, SILModule &module, for (const SILVTableEntry &entry : origVtable->getEntries()) { SILFunction *origMethod = entry.getImplementation(); + + auto methodSubs = getMethodSubs(origMethod, subs); + + // If the resulting substitution map is not valid this means that the method + // itself has generic parameters. + if (hasInvalidConformance(methodSubs)) { + module.getASTContext().Diags.diagnose( + entry.getMethod().getDecl()->getLoc(), diag::non_final_generic_class_function); + continue; + } + SILFunction *specializedMethod = - specializeVTableMethod(origMethod, subs, module, transform); + specializeVTableMethod(origMethod, methodSubs, module, transform); + newEntries.push_back(SILVTableEntry(entry.getMethod(), specializedMethod, entry.getKind(), entry.isNonOverridden())); diff --git a/test/embedded/classes-non-final-method-no-stdlib.swift b/test/embedded/classes-non-final-method-no-stdlib.swift index 96a0ac1512f9c..aed763d70844b 100644 --- a/test/embedded/classes-non-final-method-no-stdlib.swift +++ b/test/embedded/classes-non-final-method-no-stdlib.swift @@ -6,3 +6,23 @@ public class MyClass { func foo(t: T) { } // expected-error {{classes cannot have non-final generic fuctions in embedded Swift}} func bar() { } } + +final class C2 { + // TODO: this shouldn't be a problem because the class is final + init(x: T) { } // expected-error {{classes cannot have non-final generic fuctions in embedded Swift}} +} + +struct S {} + +func testit2() -> C2 { + return C2(x: S()) +} + +open class C3 { + public func foo(t: T) {} // expected-error {{classes cannot have non-final generic fuctions in embedded Swift}} +} + +func testit3() -> C3 { + return C3() +} + diff --git a/test/embedded/generic-classes.swift b/test/embedded/generic-classes.swift index 742050abe6282..8bfc1e31ae3e5 100644 --- a/test/embedded/generic-classes.swift +++ b/test/embedded/generic-classes.swift @@ -46,6 +46,52 @@ public func makeInner() -> Outer.Inner { return Outer.Inner() } +final class List where Element: ~Copyable { + init(x: Element) where Element: Copyable { } +} + +func testList() -> List { + return List(x: 0) +} + +open class OpenClass where Element: ~Copyable { + public func foo(x: Element) where Element: Copyable { } +} + +func testOpenClass() -> OpenClass { + return OpenClass() +} + + +class Base { + func foo(_: T) {} +} + +class Derived: Base> {} + +func testBaseDerived() -> Derived { + return Derived() +} + +class Base2 { + func foo(_: T) {} +} + +class Derived2: Base2<(T, T)> {} + +func testBaseDerived2() -> Derived2 { + return Derived2() +} + +class Base3 { + func foo(_: T) {} +} +class Derived3: Base3<(T, U)> {} + +func testBaseDerived3() -> Derived3 { + return Derived3() +} + @main struct Main { static func main() { @@ -56,6 +102,11 @@ struct Main { let x = SubClass2() x.test() makeInner().foo() + testList() + testOpenClass() + testBaseDerived() + testBaseDerived2() + testBaseDerived3() } }