diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index b0a085d596a0..2a9e643aa5b0 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -865,7 +865,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => /** An extractor for def of a closure contained the block of the closure. */ object closureDef { def unapply(tree: Tree)(using Context): Option[DefDef] = tree match { - case Block((meth : DefDef) :: Nil, closure: Closure) if meth.symbol == closure.meth.symbol => + case Block((meth: DefDef) :: Nil, closure: Closure) if meth.symbol == closure.meth.symbol => Some(meth) case Block(Nil, expr) => unapply(expr) diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 2abae103780f..991309293c0c 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -375,7 +375,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { * new parents { termForwarders; typeAliases } * * @param parents a non-empty list of class types - * @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to. + * @param termForwarders a non-empty list of forwarding definitions specified by their name + * and the definition they forward to. * @param typeMembers a possibly-empty list of type members specified by their name and their right hand side. * @param adaptVarargs if true, allow matching a vararg superclass constructor * with a missing argument in superArgs, and synthesize an diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index ea4626ac684e..124d96e3c733 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -6053,14 +6053,14 @@ object Types extends TypeUtils { def takesNoArgs(tp: Type) = !tp.classSymbol.primaryConstructor.exists // e.g. `ContextFunctionN` does not have constructors - || tp.applicableConstructors(Nil, adaptVarargs = true).lengthCompare(1) == 0 + || tp.applicableConstructors(argTypes = Nil, adaptVarargs = true).lengthCompare(1) == 0 // we require a unique constructor so that SAM expansion is deterministic val noArgsNeeded: Boolean = takesNoArgs(tp) - && (!tp.cls.is(Trait) || takesNoArgs(tp.parents.head)) + && (!cls.is(Trait) || takesNoArgs(tp.parents.head)) def isInstantiable = - !tp.cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType) - if noArgsNeeded && isInstantiable then tp.cls + !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType) + if noArgsNeeded && isInstantiable then cls else NoSymbol case tp: AppliedType => samClass(tp.superType) diff --git a/compiler/src/dotty/tools/dotc/transform/Dependencies.scala b/compiler/src/dotty/tools/dotc/transform/Dependencies.scala index 9084930b6815..a4c3550441f5 100644 --- a/compiler/src/dotty/tools/dotc/transform/Dependencies.scala +++ b/compiler/src/dotty/tools/dotc/transform/Dependencies.scala @@ -45,9 +45,9 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co /** A map from local methods and classes to the owners to which they will be lifted as members. * For methods and classes that do not have any dependencies this will be the enclosing package. - * symbols with packages as lifted owners will subsequently represented as static + * Symbols with packages as lifted owners will be subsequently represented as static * members of their toplevel class, unless their enclosing class was already static. - * Note: During tree transform (which runs at phase LambdaLift + 1), liftedOwner + * Note: During tree transform (which runs at phase LambdaLift + 1), logicOwner * is also used to decide whether a method had a term owner before. */ private val logicOwner = new LinkedHashMap[Symbol, Symbol] @@ -75,8 +75,8 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co || owner.is(Trait) && isLocal(owner) || sym.isConstructor && isLocal(owner) - /** Set `liftedOwner(sym)` to `owner` if `owner` is more deeply nested - * than the previous value of `liftedowner(sym)`. + /** Set `logicOwner(sym)` to `owner` if `owner` is more deeply nested + * than the previous value of `logicOwner(sym)`. */ private def narrowLogicOwner(sym: Symbol, owner: Symbol)(using Context): Unit = if sym.maybeOwner.isTerm @@ -89,7 +89,7 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co /** Mark symbol `sym` as being free in `enclosure`, unless `sym` is defined * in `enclosure` or there is an intermediate class properly containing `enclosure` - * in which `sym` is also free. Also, update `liftedOwner` of `enclosure` so + * in which `sym` is also free. Also, update `logicOwner` of `enclosure` so * that `enclosure` can access `sym`, or its proxy in an intermediate class. * This means: * @@ -284,7 +284,7 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co changedFreeVars do () - /** Compute final liftedOwner map by closing over caller dependencies */ + /** Compute final logicOwner map by closing over caller dependencies */ private def computeLogicOwners()(using Context): Unit = while changedLogicOwner = false diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index 68f911f06963..4288a82e0191 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -10,6 +10,7 @@ import Names.TypeName import NullOpsDecorator.* import ast.untpd +import scala.collection.mutable.ListBuffer /** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes. * These fall into five categories @@ -17,7 +18,7 @@ import ast.untpd * 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these. * 2. Closures implementing non-trait classes * 3. Closures implementing classes that inherit from a class other than Object - * (a lambda cannot not be a run-time subtype of such a class) + * (a lambda cannot be a run-time subtype of such a class) * 4. Closures that implement traits which run initialization code. * 5. Closures that get synthesized abstract methods in the transformation pipeline. These methods can be * (1) superaccessors, (2) outer references, (3) accessors for fields. @@ -59,7 +60,7 @@ class ExpandSAMs extends MiniPhase: // A SAM type is allowed to have type aliases refinements (see // SAMType#samParent) which must be converted into type members if // the closure is desugared into a class. - val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]() + val refinements = ListBuffer.empty[(TypeName, TypeAlias)] def collectAndStripRefinements(tp: Type): Type = tp match case RefinedType(parent, name, info: TypeAlias) => val res = collectAndStripRefinements(parent) @@ -81,34 +82,40 @@ class ExpandSAMs extends MiniPhase: tree } - /** A partial function literal: + /** A pattern-matching anonymous function: * * ``` * val x: PartialFunction[A, B] = { case C1 => E1; ...; case Cn => En } * ``` + * or + * ``` + * x => e(x) { case C1 => E1; ...; case Cn => En } + * ``` + * where the expression `e(x)` may be trivially `x` * * which desugars to: * * ``` * val x: PartialFunction[A, B] = { - * def $anonfun(x: A): B = x match { case C1 => E1; ...; case Cn => En } + * def $anonfun(x: A): B = e(x) match { case C1 => E1; ...; case Cn => En } * closure($anonfun: PartialFunction[A, B]) * } * ``` + * where the expression `e(x)` defaults to `x` for a simple block of cases * * is expanded to an anonymous class: * * ``` * val x: PartialFunction[A, B] = { * class $anon extends AbstractPartialFunction[A, B] { - * final def isDefinedAt(x: A): Boolean = x match { + * final def isDefinedAt(x: A): Boolean = e(x) match { * case C1 => true * ... * case Cn => true * case _ => false * } * - * final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = x match { + * final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = e(x) match { * case C1 => E1 * ... * case Cn => En @@ -120,7 +127,7 @@ class ExpandSAMs extends MiniPhase: * } * ``` */ - private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = { + private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree: @unchecked // The right hand side from which to construct the partial function. This is always a Match. @@ -146,7 +153,7 @@ class ExpandSAMs extends MiniPhase: defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType), defn.SerializableType) - AnonClass(anonSym.owner, parents, tree.span) { pfSym => + AnonClass(anonSym.owner, parents, tree.span): pfSym => def overrideSym(sym: Symbol) = sym.copy( owner = pfSym, flags = Synthetic | Method | Final | Override, @@ -155,7 +162,8 @@ class ExpandSAMs extends MiniPhase: val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) - def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = { + def translateMatch(owner: Symbol)(pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = + val tree: Match = pfRHS val selector = tree.selector val cases1 = if cases.exists(isDefaultCase) then cases else @@ -165,31 +173,27 @@ class ExpandSAMs extends MiniPhase: cases :+ defaultCase cpy.Match(tree)(selector, cases1) .subst(param.symbol :: Nil, pfParam :: Nil) - // Needed because a partial function can be written as: + // Needed because a partial function can be written as: // param => param match { case "foo" if foo(param) => param } // And we need to update all references to 'param' - } + .changeOwner(anonSym, owner) - def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = { + def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = val tru = Literal(Constant(true)) - def translateCase(cdef: CaseDef) = - cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn) + def translateCase(cdef: CaseDef) = cpy.CaseDef(cdef)(body = tru) val paramRef = paramRefss.head.head val defaultValue = Literal(Constant(false)) - translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue) - } + translateMatch(isDefinedAtFn)(paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue) - def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = { + def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = val List(paramRef, defaultRef) = paramRefss(1) - def translateCase(cdef: CaseDef) = - cdef.changeOwner(anonSym, applyOrElseFn) val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef) - translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue) - } + translateMatch(applyOrElseFn)(paramRef.symbol, pfRHS.cases, defaultValue) - val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn)))) - val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn)))) + val isDefinedAtDef = transformFollowingDeep: + DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))) + val applyOrElseDef = transformFollowingDeep: + DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))) List(isDefinedAtDef, applyOrElseDef) - } - } + end toPartialFunction end ExpandSAMs diff --git a/docs/_spec/08-pattern-matching.md b/docs/_spec/08-pattern-matching.md index 4a34ae8631c4..c13aa5d44b89 100644 --- a/docs/_spec/08-pattern-matching.md +++ b/docs/_spec/08-pattern-matching.md @@ -616,7 +616,7 @@ new scala.PartialFunction[´S´, ´T´] { def apply(´x´: ´S´): ´T´ = x match { case ´p_1´ => ´b_1´ ... case ´p_n´ => ´b_n´ } - def isDefinedAt(´x´: ´S´): Boolean = { + def isDefinedAt(´x´: ´S´): Boolean = x match { case ´p_1´ => true ... case ´p_n´ => true case _ => false } @@ -626,6 +626,22 @@ new scala.PartialFunction[´S´, ´T´] { Here, ´x´ is a fresh name and ´T´ is the least upper bound of the types of all ´b_i´. The final default case in the `isDefinedAt` method is omitted if one of the patterns ´p_1, ..., p_n´ is already a variable or wildcard pattern. +As a convenience, the partial function may be written using function literal notation: + +```scala +(´x: S´) => e(´x´) match { + case ´p_1´ => ´b_1´ ... case ´p_n´ => ´b_n´ +} +``` +where the selector expression is used for matches in the expansion. +The body of the function must consist solely of the match expression. + +This syntax permits annotating the selector: + +```scala +(´x: S´) => (e(´x´): @unchecked) match { ... } +``` + ###### Example Here's an example which uses `foldLeft` to compute the scalar product of two vectors: diff --git a/tests/pos/i23025.scala b/tests/pos/i23025.scala new file mode 100644 index 000000000000..23a2eb78613d --- /dev/null +++ b/tests/pos/i23025.scala @@ -0,0 +1,5 @@ + +class A { + def f: PartialFunction[Int, Int] = + a => { (try a catch { case e : Throwable => throw e}) match { case n => n } } +} diff --git a/tests/pos/i23054.scala b/tests/pos/i23054.scala new file mode 100644 index 000000000000..8d5bf9de12fb --- /dev/null +++ b/tests/pos/i23054.scala @@ -0,0 +1,15 @@ + +object Bug: + + def m0(f: PartialFunction[Char, Unit]): Unit = () + + def m1(): Unit = + m0: x => + "abc".filter(_ == x) match + case _ => () + + def m2(): Unit = + m0: x => + x match + case _ => () + diff --git a/tests/pos/i23310.scala b/tests/pos/i23310.scala new file mode 100644 index 000000000000..3b2c7eda9141 --- /dev/null +++ b/tests/pos/i23310.scala @@ -0,0 +1,16 @@ + +object Example { + val pf: PartialFunction[Unit, Unit] = s => (s match { + case a => a + }) match { + case a => () + } +} + +object ExampleB: + def test = + List(42).collect: + _.match + case x => x + .match + case y => y + 27