diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 8bd30e6f404a..3aacd6659b79 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1135,6 +1135,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case elsep: untpd.If => isIncomplete(elsep) case _ => false + // Insert a GADT cast if the type of the branch does not conform + // to the type assigned to the whole if tree. + // This happens when the computation of the type of the if tree + // uses GADT constraints. See #15646. + def gadtAdaptBranch(tree: Tree, branchPt: Type): Tree = + TypeComparer.testSubType(tree.tpe.widenExpr, branchPt) match { + case CompareResult.OKwithGADTUsed => + insertGadtCast(tree, tree.tpe.widen, branchPt) + case _ => tree + } + val branchPt = if isIncomplete(tree) then defn.UnitType else pt.dropIfProto val result = @@ -1148,7 +1159,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val elsep0 = typed(tree.elsep, branchPt)(using cond1.nullableContextIf(false)) thenp0 :: elsep0 :: Nil }: @unchecked - assignType(cpy.If(tree)(cond1, thenp1, elsep1), thenp1, elsep1) + + val resType = thenp1.tpe | elsep1.tpe + val thenp2 :: elsep2 :: Nil = + (thenp1 :: elsep1 :: Nil) map { t => + // Adapt each branch to ensure that their types conforms to the + // type assigned to the if tree by inserting GADT casts. + gadtAdaptBranch(t, resType) + }: @unchecked + + cpy.If(tree)(cond1, thenp2, elsep2).withType(resType) def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo) def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo) @@ -3763,20 +3783,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer gadts.println(i"unnecessary GADTused for $tree: ${tree.tpe.widenExpr} vs $pt in ${ctx.source}") res } => - // Insert an explicit cast, so that -Ycheck in later phases succeeds. - // The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts. - val target = - if tree.tpe.isSingleton then - val conj = AndType(tree.tpe, pt) - if tree.tpe.isStable && !conj.isStable then - // this is needed for -Ycheck. Without the annotation Ycheck will - // skolemize the result type which will lead to different types before - // and after checking. See i11955.scala. - AnnotatedType(conj, Annotation(defn.UncheckedStableAnnot)) - else conj - else pt - gadts.println(i"insert GADT cast from $tree to $target") - tree.cast(target) + insertGadtCast(tree, wtp, pt) case _ => //typr.println(i"OK ${tree.tpe}\n${TypeComparer.explained(_.isSubType(tree.tpe, pt))}") // uncomment for unexpected successes tree @@ -4207,4 +4214,36 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer EmptyTree else typedExpr(call, defn.AnyType) + /** Insert GADT cast to target type `pt` on the `tree` + * so that -Ycheck in later phases succeeds. + * The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts. + */ + private def insertGadtCast(tree: Tree, wtp: Type, pt: Type)(using Context): Tree = + val target = + if tree.tpe.isSingleton then + // In the target type, when the singleton type is intersected, we also intersect + // the GADT-approximated type of the singleton to avoid the loss of + // information. See #15646. + val gadtApprox = Inferencing.approximateGADT(wtp) + gadts.println(i"gadt approx $wtp ~~~ $gadtApprox") + val conj = + TypeComparer.testSubType(gadtApprox, pt) match { + case CompareResult.OK => + // GADT approximation of the tree type is a subtype of expected type under empty GADT + // constraints, so it is enough to only have the GADT approximation. + AndType(tree.tpe, gadtApprox) + case _ => + // In other cases, we intersect both the approximated type and the expected type. + AndType(AndType(tree.tpe, gadtApprox), pt) + } + if tree.tpe.isStable && !conj.isStable then + // this is needed for -Ycheck. Without the annotation Ycheck will + // skolemize the result type which will lead to different types before + // and after checking. See i11955.scala. + AnnotatedType(conj, Annotation(defn.UncheckedStableAnnot)) + else conj + else pt + gadts.println(i"insert GADT cast from $tree to $target") + tree.cast(target) + end insertGadtCast } diff --git a/tests/pos/gadt-cast-if.scala b/tests/pos/gadt-cast-if.scala new file mode 100644 index 000000000000..02a02f040cc1 --- /dev/null +++ b/tests/pos/gadt-cast-if.scala @@ -0,0 +1,12 @@ +trait Expr[T] + case class IntExpr() extends Expr[Int] + + def flag: Boolean = ??? + + def foo[T](ev: Expr[T]): Int | T = ev match + case IntExpr() => + if flag then + val i: T = ??? + i + else + (??? : Int) diff --git a/tests/pos/gadt-cast-singleton.scala b/tests/pos/gadt-cast-singleton.scala new file mode 100644 index 000000000000..57cd2bc8f578 --- /dev/null +++ b/tests/pos/gadt-cast-singleton.scala @@ -0,0 +1,13 @@ +enum SUB[-A, +B]: + case Refl[S]() extends SUB[S, S] + +trait R { + type Data +} +trait L extends R + +def f(x: L): x.Data = ??? + +def g[T <: R](x: T, ev: T SUB L): x.Data = ev match + case SUB.Refl() => + f(x) diff --git a/tests/pos/i14776-patmat.scala b/tests/pos/i14776-patmat.scala new file mode 100644 index 000000000000..570e8cc64bac --- /dev/null +++ b/tests/pos/i14776-patmat.scala @@ -0,0 +1,15 @@ +trait T1 +trait T2 extends T1 + +trait Expr[T] { val data: T = ??? } +case class Tag2() extends Expr[T2] + +def flag: Boolean = ??? + +def foo[T](e: Expr[T]): T1 = e match { + case Tag2() => + flag match + case true => new T2 {} + case false => e.data +} + diff --git a/tests/pos/i14776.scala b/tests/pos/i14776.scala new file mode 100644 index 000000000000..262a3750ff73 --- /dev/null +++ b/tests/pos/i14776.scala @@ -0,0 +1,16 @@ +trait T1 +trait T2 extends T1 + +trait Expr[T] { val data: T = ??? } +case class Tag2() extends Expr[T2] + +def flag: Boolean = ??? + +def foo[T](e: Expr[T]): T1 = e match { + case Tag2() => + if flag then + new T2 {} + else + e.data +} +