Skip to content

Partial function synthesis changesOwner of selector #23337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
/** An extractor for def of a closure contained the block of the closure. */
object closureDef {
def unapply(tree: Tree)(using Context): Option[DefDef] = tree match {
case Block((meth : DefDef) :: Nil, closure: Closure) if meth.symbol == closure.meth.symbol =>
case Block((meth: DefDef) :: Nil, closure: Closure) if meth.symbol == closure.meth.symbol =>
Some(meth)
case Block(Nil, expr) =>
unapply(expr)
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* new parents { termForwarders; typeAliases }
*
* @param parents a non-empty list of class types
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
* @param termForwarders a non-empty list of forwarding definitions specified by their name
* and the definition they forward to.
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
* @param adaptVarargs if true, allow matching a vararg superclass constructor
* with a missing argument in superArgs, and synthesize an
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6053,14 +6053,14 @@ object Types extends TypeUtils {
def takesNoArgs(tp: Type) =
!tp.classSymbol.primaryConstructor.exists
// e.g. `ContextFunctionN` does not have constructors
|| tp.applicableConstructors(Nil, adaptVarargs = true).lengthCompare(1) == 0
|| tp.applicableConstructors(argTypes = Nil, adaptVarargs = true).lengthCompare(1) == 0
// we require a unique constructor so that SAM expansion is deterministic
val noArgsNeeded: Boolean =
takesNoArgs(tp)
&& (!tp.cls.is(Trait) || takesNoArgs(tp.parents.head))
&& (!cls.is(Trait) || takesNoArgs(tp.parents.head))
def isInstantiable =
!tp.cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if noArgsNeeded && isInstantiable then tp.cls
!cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
if noArgsNeeded && isInstantiable then cls
else NoSymbol
case tp: AppliedType =>
samClass(tp.superType)
Expand Down
12 changes: 6 additions & 6 deletions compiler/src/dotty/tools/dotc/transform/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co

/** A map from local methods and classes to the owners to which they will be lifted as members.
* For methods and classes that do not have any dependencies this will be the enclosing package.
* symbols with packages as lifted owners will subsequently represented as static
* Symbols with packages as lifted owners will be subsequently represented as static
* members of their toplevel class, unless their enclosing class was already static.
* Note: During tree transform (which runs at phase LambdaLift + 1), liftedOwner
* Note: During tree transform (which runs at phase LambdaLift + 1), logicOwner
* is also used to decide whether a method had a term owner before.
*/
private val logicOwner = new LinkedHashMap[Symbol, Symbol]
Expand Down Expand Up @@ -75,8 +75,8 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co
|| owner.is(Trait) && isLocal(owner)
|| sym.isConstructor && isLocal(owner)

/** Set `liftedOwner(sym)` to `owner` if `owner` is more deeply nested
* than the previous value of `liftedowner(sym)`.
/** Set `logicOwner(sym)` to `owner` if `owner` is more deeply nested
* than the previous value of `logicOwner(sym)`.
*/
private def narrowLogicOwner(sym: Symbol, owner: Symbol)(using Context): Unit =
if sym.maybeOwner.isTerm
Expand All @@ -89,7 +89,7 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co

/** Mark symbol `sym` as being free in `enclosure`, unless `sym` is defined
* in `enclosure` or there is an intermediate class properly containing `enclosure`
* in which `sym` is also free. Also, update `liftedOwner` of `enclosure` so
* in which `sym` is also free. Also, update `logicOwner` of `enclosure` so
* that `enclosure` can access `sym`, or its proxy in an intermediate class.
* This means:
*
Expand Down Expand Up @@ -284,7 +284,7 @@ abstract class Dependencies(root: ast.tpd.Tree, @constructorOnly rootContext: Co
changedFreeVars
do ()

/** Compute final liftedOwner map by closing over caller dependencies */
/** Compute final logicOwner map by closing over caller dependencies */
private def computeLogicOwners()(using Context): Unit =
while
changedLogicOwner = false
Expand Down
54 changes: 29 additions & 25 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import Names.TypeName

import NullOpsDecorator.*
import ast.untpd
import scala.collection.mutable.ListBuffer

/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
* These fall into five categories
*
* 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
* 2. Closures implementing non-trait classes
* 3. Closures implementing classes that inherit from a class other than Object
* (a lambda cannot not be a run-time subtype of such a class)
* (a lambda cannot be a run-time subtype of such a class)
* 4. Closures that implement traits which run initialization code.
* 5. Closures that get synthesized abstract methods in the transformation pipeline. These methods can be
* (1) superaccessors, (2) outer references, (3) accessors for fields.
Expand Down Expand Up @@ -59,7 +60,7 @@ class ExpandSAMs extends MiniPhase:
// A SAM type is allowed to have type aliases refinements (see
// SAMType#samParent) which must be converted into type members if
// the closure is desugared into a class.
val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]()
val refinements = ListBuffer.empty[(TypeName, TypeAlias)]
def collectAndStripRefinements(tp: Type): Type = tp match
case RefinedType(parent, name, info: TypeAlias) =>
val res = collectAndStripRefinements(parent)
Expand All @@ -81,34 +82,40 @@ class ExpandSAMs extends MiniPhase:
tree
}

/** A partial function literal:
/** A pattern-matching anonymous function:
*
* ```
* val x: PartialFunction[A, B] = { case C1 => E1; ...; case Cn => En }
* ```
* or
* ```
* x => e(x) { case C1 => E1; ...; case Cn => En }
* ```
* where the expression `e(x)` may be trivially `x`
*
* which desugars to:
*
* ```
* val x: PartialFunction[A, B] = {
* def $anonfun(x: A): B = x match { case C1 => E1; ...; case Cn => En }
* def $anonfun(x: A): B = e(x) match { case C1 => E1; ...; case Cn => En }
* closure($anonfun: PartialFunction[A, B])
* }
* ```
* where the expression `e(x)` defaults to `x` for a simple block of cases
*
* is expanded to an anonymous class:
*
* ```
* val x: PartialFunction[A, B] = {
* class $anon extends AbstractPartialFunction[A, B] {
* final def isDefinedAt(x: A): Boolean = x match {
* final def isDefinedAt(x: A): Boolean = e(x) match {
* case C1 => true
* ...
* case Cn => true
* case _ => false
* }
*
* final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = x match {
* final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = e(x) match {
* case C1 => E1
* ...
* case Cn => En
Expand All @@ -120,7 +127,7 @@ class ExpandSAMs extends MiniPhase:
* }
* ```
*/
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = {
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree =
val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree: @unchecked

// The right hand side from which to construct the partial function. This is always a Match.
Expand All @@ -146,7 +153,7 @@ class ExpandSAMs extends MiniPhase:
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
defn.SerializableType)

AnonClass(anonSym.owner, parents, tree.span) { pfSym =>
AnonClass(anonSym.owner, parents, tree.span): pfSym =>
def overrideSym(sym: Symbol) = sym.copy(
owner = pfSym,
flags = Synthetic | Method | Final | Override,
Expand All @@ -155,7 +162,8 @@ class ExpandSAMs extends MiniPhase:
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)

def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
def translateMatch(owner: Symbol)(pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) =
val tree: Match = pfRHS
val selector = tree.selector
val cases1 = if cases.exists(isDefaultCase) then cases
else
Expand All @@ -165,31 +173,27 @@ class ExpandSAMs extends MiniPhase:
cases :+ defaultCase
cpy.Match(tree)(selector, cases1)
.subst(param.symbol :: Nil, pfParam :: Nil)
// Needed because a partial function can be written as:
// Needed because a partial function can be written as:
// param => param match { case "foo" if foo(param) => param }
// And we need to update all references to 'param'
}
.changeOwner(anonSym, owner)

def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) =
val tru = Literal(Constant(true))
def translateCase(cdef: CaseDef) =
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
def translateCase(cdef: CaseDef) = cpy.CaseDef(cdef)(body = tru)
val paramRef = paramRefss.head.head
val defaultValue = Literal(Constant(false))
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
}
translateMatch(isDefinedAtFn)(paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)

def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) =
val List(paramRef, defaultRef) = paramRefss(1)
def translateCase(cdef: CaseDef) =
cdef.changeOwner(anonSym, applyOrElseFn)
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
}
translateMatch(applyOrElseFn)(paramRef.symbol, pfRHS.cases, defaultValue)

val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
val isDefinedAtDef = transformFollowingDeep:
DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn)))
val applyOrElseDef = transformFollowingDeep:
DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn)))
List(isDefinedAtDef, applyOrElseDef)
}
}
end toPartialFunction
end ExpandSAMs
18 changes: 17 additions & 1 deletion docs/_spec/08-pattern-matching.md
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ new scala.PartialFunction[´S´, ´T´] {
def apply(´x´: ´S´): ´T´ = x match {
case ´p_1´ => ´b_1´ ... case ´p_n´ => ´b_n´
}
def isDefinedAt(´x´: ´S´): Boolean = {
def isDefinedAt(´x´: ´S´): Boolean = x match {
case ´p_1´ => true ... case ´p_n´ => true
case _ => false
}
Expand All @@ -626,6 +626,22 @@ new scala.PartialFunction[´S´, ´T´] {
Here, ´x´ is a fresh name and ´T´ is the least upper bound of the types of all ´b_i´.
The final default case in the `isDefinedAt` method is omitted if one of the patterns ´p_1, ..., p_n´ is already a variable or wildcard pattern.

As a convenience, the partial function may be written using function literal notation:

```scala
(´x: S´) => e(´x´) match {
case ´p_1´ => ´b_1´ ... case ´p_n´ => ´b_n´
}
```
where the selector expression is used for matches in the expansion.
The body of the function must consist solely of the match expression.

This syntax permits annotating the selector:

```scala
(´x: S´) => (e(´x´): @unchecked) match { ... }
```

###### Example
Here's an example which uses `foldLeft` to compute the scalar product of two vectors:

Expand Down
5 changes: 5 additions & 0 deletions tests/pos/i23025.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

class A {
def f: PartialFunction[Int, Int] =
a => { (try a catch { case e : Throwable => throw e}) match { case n => n } }
}
15 changes: 15 additions & 0 deletions tests/pos/i23054.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

object Bug:

def m0(f: PartialFunction[Char, Unit]): Unit = ()

def m1(): Unit =
m0: x =>
"abc".filter(_ == x) match
case _ => ()

def m2(): Unit =
m0: x =>
x match
case _ => ()

16 changes: 16 additions & 0 deletions tests/pos/i23310.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

object Example {
val pf: PartialFunction[Unit, Unit] = s => (s match {
case a => a
}) match {
case a => ()
}
}

object ExampleB:
def test =
List(42).collect:
_.match
case x => x
.match
case y => y + 27
Loading