diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index d6c6dd9ec2c0..d9f277ebfc6d 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -10,11 +10,16 @@ import config.SourceVersion import config.Printers.capt import util.Property.Key import tpd.* +import StdNames.nme import config.Feature +import collection.mutable private val Captures: Key[CaptureSet] = Key() private val BoxedType: Key[BoxedTypeCache] = Key() +/** Attachment key for the nesting level cache */ +val ccState: Key[CCState] = Key() + /** Switch whether unpickled function types and byname types should be mapped to * impure types. With the new gradual typing using Fluid capture sets, this should * be no longer needed. Also, it has bad interactions with pickling tests. @@ -32,6 +37,40 @@ def allowUniversalInBoxed(using Context) = /** An exception thrown if a @retains argument is not syntactically a CaptureRef */ class IllegalCaptureRef(tpe: Type) extends Exception +class CCState: + val nestingLevels: mutable.HashMap[Symbol, Int] = new mutable.HashMap + val localRoots: mutable.HashMap[Symbol, CaptureRef] = new mutable.HashMap + var levelError: Option[(CaptureRef, CaptureSet)] = None + +class mapRoots(lowner: Symbol)(using Context) extends BiTypeMap: + thisMap => + + def apply(t: Type): Type = t.dealiasKeepAnnots match + case t1: CaptureRef if t1.isGenericRootCapability => + assert(lowner.exists, "cannot map global root") + lowner.localRoot + case _: MethodOrPoly => + t + case t1 if defn.isFunctionType(t1) => + t + case t1 => + val t2 = mapOver(t1) + if t2 ne t1 then t2 else t + + def inverse = new BiTypeMap: + def apply(t: Type): Type = t.dealiasKeepAnnots match + case t1: CaptureRef if t1.localRootOwner == lowner => + defn.captureRoot.termRef + case _: MethodOrPoly => + t + case t1 if defn.isFunctionType(t1) => + t + case t1 => + val t2 = mapOver(t1) + if t2 ne t1 then t2 else t + def inverse = thisMap +end mapRoots + extension (tree: Tree) /** Map tree with CaptureRef type to its type, throw IllegalCaptureRef otherwise */ @@ -164,7 +203,7 @@ extension (tp: Type) * a by name parameter type, turning the latter into an impure by name parameter type. */ def adaptByNameArgUnderPureFuns(using Context): Type = - if Feature.pureFunsEnabledSomewhere then + if adaptUnpickledFunctionTypes && Feature.pureFunsEnabledSomewhere then AnnotatedType(tp, CaptureAnnotation(CaptureSet.universal, boxed = false)(defn.RetainsByNameAnnot)) else @@ -199,6 +238,13 @@ extension (tp: Type) case _ => false + def capturedLocalRoot(using Context): Symbol = + tp.captureSet.elems.toList + .filter(_.isLocalRootCapability) + .map(_.termSymbol) + .maxByOption(_.ccNestingLevel) + .getOrElse(NoSymbol) + extension (cls: ClassSymbol) def pureBaseClass(using Context): Option[Symbol] = @@ -253,6 +299,46 @@ extension (sym: Symbol) && sym != defn.Caps_unsafeBox && sym != defn.Caps_unsafeUnbox + /** The owner of the current level. Qualifying owners are + * - methods other than constructors + * - classes, if they are not staticOwners + * - _root_ + */ + def levelOwner(using Context): Symbol = + if !sym.exists || sym.isRoot || sym.isStaticOwner then defn.RootClass + else if sym.isClass || sym.is(Method) && !sym.isConstructor then sym + else sym.owner.levelOwner + + /** The nesting level of `sym` for the purposes of `cc`, + * -1 for NoSymbol + */ + def ccNestingLevel(using Context): Int = + if sym.exists then + val lowner = sym.levelOwner + val cache = ctx.property(ccState).get.nestingLevels + cache.getOrElseUpdate(lowner, + if lowner.isRoot then 0 else lowner.owner.ccNestingLevel + 1) + else -1 + + /** Optionally, the nesting level of `sym` for the purposes of `cc`, provided + * a capture checker is running. + */ + def ccNestingLevelOpt(using Context): Option[Int] = + if ctx.property(ccState).isDefined then + Some(ccNestingLevel) + else None + + def localRoot(using Context): CaptureRef = + assert(sym.exists && sym.levelOwner == sym, sym) + ctx.property(ccState).get.localRoots.getOrElseUpdate(sym, + newSymbol(sym, nme.LOCAL_CAPTURE_ROOT, Synthetic, defn.AnyType, nestingLevel = sym.ccNestingLevel).termRef) + + def maxNested(other: Symbol)(using Context): Symbol = + if sym.ccNestingLevel < other.ccNestingLevel then other else sym + + def minNested(other: Symbol)(using Context): Symbol = + if sym.ccNestingLevel > other.ccNestingLevel then other else sym + extension (tp: AnnotatedType) /** Is this a boxed capturing type? */ def isBoxed(using Context): Boolean = tp.annot match diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 3f2beaa3ff55..f0204f893f1d 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -12,7 +12,8 @@ import annotation.internal.sharable import reporting.trace import printing.{Showable, Printer} import printing.Texts.* -import util.{SimpleIdentitySet, Property} +import util.{SimpleIdentitySet, Property, optional}, optional.{break, ?} +import typer.ErrorReporting.Addenda import util.common.alwaysTrue import scala.collection.mutable import config.Config.ccAllowUnsoundMaps @@ -55,6 +56,11 @@ sealed abstract class CaptureSet extends Showable: */ def isAlwaysEmpty: Boolean + /** The level owner in which the set is defined. Sets can only take + * elements with nesting level up to the cc-nestinglevel of owner. + */ + def owner: Symbol + /** Is this capture set definitely non-empty? */ final def isNotEmpty: Boolean = !elems.isEmpty @@ -113,20 +119,31 @@ sealed abstract class CaptureSet extends Showable: else addNewElems(elem.singletonCaptureSet.elems, origin) /* x subsumes y if x is the same as y, or x is a this reference and y refers to a field of x */ - extension (x: CaptureRef) private def subsumes(y: CaptureRef) = - (x eq y) - || y.match - case y: TermRef => y.prefix eq x - case _ => false + extension (x: CaptureRef)(using Context) + private def subsumes(y: CaptureRef) = + (x eq y) + || x.isGenericRootCapability + || y.match + case y: TermRef => (y.prefix eq x) || x.isRootIncluding(y) + case _ => false + + private def isRootIncluding(y: CaptureRef) = + x.isLocalRootCapability && y.isLocalRootCapability + && x.termSymbol.nestingLevel >= y.termSymbol.nestingLevel + end extension /** {x} <:< this where <:< is subcapturing, but treating all variables * as frozen. */ def accountsFor(x: CaptureRef)(using Context): Boolean = - reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) { - elems.exists(_.subsumes(x)) - || !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK - } + if comparer.isInstanceOf[ExplainingTypeComparer] then // !!! DEBUG + reporting.trace.force(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true): + elems.exists(_.subsumes(x)) + || !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK + else + reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true): + elems.exists(_.subsumes(x)) + || !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK /** A more optimistic version of accountsFor, which does not take variable supersets * of the `x` reference into account. A set might account for `x` if it accounts @@ -191,7 +208,8 @@ sealed abstract class CaptureSet extends Showable: if this.subCaptures(that, frozen = true).isOK then that else if that.subCaptures(this, frozen = true).isOK then this else if this.isConst && that.isConst then Const(this.elems ++ that.elems) - else Var(this.elems ++ that.elems).addAsDependentTo(this).addAsDependentTo(that) + else Var(this.owner.maxNested(that.owner), this.elems ++ that.elems) + .addAsDependentTo(this).addAsDependentTo(that) /** The smallest superset (via <:<) of this capture set that also contains `ref`. */ @@ -276,7 +294,9 @@ sealed abstract class CaptureSet extends Showable: if isUniversal then handler() this - /** Invoke handler on the elements to check wellformedness of the capture set */ + /** Invoke handler on the elements to ensure wellformedness of the capture set. + * The handler might add additional elements to the capture set. + */ def ensureWellformed(handler: List[CaptureRef] => Context ?=> Unit)(using Context): this.type = handler(elems.toList) this @@ -308,7 +328,7 @@ sealed abstract class CaptureSet extends Showable: Annotation(CaptureAnnotation(this, boxed = false)(cls).tree) override def toText(printer: Printer): Text = - Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}") ~~ description + printer.toTextCaptureSet(this) object CaptureSet: type Refs = SimpleIdentitySet[CaptureRef] @@ -353,6 +373,8 @@ object CaptureSet: def withDescription(description: String): Const = Const(elems, description) + def owner = NoSymbol + override def toString = elems.toString end Const @@ -371,16 +393,23 @@ object CaptureSet: end Fluid /** The subclass of captureset variables with given initial elements */ - class Var(initialElems: Refs = emptySet) extends CaptureSet: + class Var(directOwner: Symbol, initialElems: Refs = emptySet)(using @constructorOnly ictx: Context) extends CaptureSet: /** A unique identification number for diagnostics */ val id = varId += 1 varId + override val owner = directOwner.levelOwner + /** A variable is solved if it is aproximated to a from-then-on constant set. */ private var isSolved: Boolean = false + private var ownLevelCache = -1 + private def ownLevel(using Context) = + if ownLevelCache == -1 then ownLevelCache = owner.ccNestingLevel + ownLevelCache + /** The elements currently known to be in the set */ var elems: Refs = initialElems @@ -400,6 +429,8 @@ object CaptureSet: var description: String = "" + private var triedElem: Option[CaptureRef] = None + /** Record current elements in given VarState provided it does not yet * contain an entry for this variable. */ @@ -425,7 +456,10 @@ object CaptureSet: deps = state.deps(this) def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = - if !isConst && recordElemsState() then + if isConst || !recordElemsState() then + CompareResult.fail(this) // fail if variable is solved or given VarState is frozen + else if levelsOK(newElems) then + //assert(id != 2, newElems) elems ++= newElems if isUniversal then rootAddedHandler() newElemAddedHandler(newElems.toList) @@ -433,8 +467,36 @@ object CaptureSet: (CompareResult.OK /: deps) { (r, dep) => r.andAlso(dep.tryInclude(newElems, this)) } - else // fail if variable is solved or given VarState is frozen - CompareResult.fail(this) + else + val res = widenCaptures(newElems) match + case Some(newElems1) => tryInclude(newElems1, origin) + case None => CompareResult.fail(this) + if !res.isOK then recordLevelError() + res + + private def recordLevelError()(using Context): Unit = + for elem <- triedElem do + ctx.property(ccState).get.levelError = Some((elem, this)) + + private def levelsOK(elems: Refs)(using Context): Boolean = + !elems.exists(_.ccNestingLevel > ownLevel) + + private def widenCaptures(elems: Refs)(using Context): Option[Refs] = + val res = optional: + (SimpleIdentitySet[CaptureRef]() /: elems): (acc, elem) => + if elem.ccNestingLevel <= ownLevel then acc + elem + else if elem.isRootCapability then break() + else + val saved = triedElem + triedElem = triedElem.orElse(Some(elem)) + val res = acc ++ widenCaptures(elem.captureSetOfInfo.elems).? + triedElem = saved // reset only in case of success, leave as is on error + res + def resStr = res match + case Some(refs) => i"${refs.toList}" + case None => "FAIL" + capt.println(i"widen captures ${elems.toList} for $this at $owner = $resStr") + res def addDependent(cs: CaptureSet)(using Context, VarState): CompareResult = if (cs eq this) || cs.isUniversal || isConst then @@ -519,8 +581,8 @@ object CaptureSet: end Var /** A variable that is derived from some other variable via a map or filter. */ - abstract class DerivedVar(initialElems: Refs)(using @constructorOnly ctx: Context) - extends Var(initialElems): + abstract class DerivedVar(owner: Symbol, initialElems: Refs)(using @constructorOnly ctx: Context) + extends Var(owner, initialElems): // For debugging: A trace where a set was created. Note that logically it would make more // sense to place this variable in Mapped, but that runs afoul of the initializatuon checker. @@ -546,7 +608,7 @@ object CaptureSet: */ class Mapped private[CaptureSet] (val source: Var, tm: TypeMap, variance: Int, initial: CaptureSet)(using @constructorOnly ctx: Context) - extends DerivedVar(initial.elems): + extends DerivedVar(source.owner, initial.elems): addAsDependentTo(initial) // initial mappings could change by propagation private def mapIsIdempotent = tm.isInstanceOf[IdempotentCaptRefMap] @@ -612,7 +674,7 @@ object CaptureSet: */ final class BiMapped private[CaptureSet] (val source: Var, bimap: BiTypeMap, initialElems: Refs)(using @constructorOnly ctx: Context) - extends DerivedVar(initialElems): + extends DerivedVar(source.owner, initialElems): override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = if origin eq source then @@ -633,7 +695,7 @@ object CaptureSet: */ override def computeApprox(origin: CaptureSet)(using Context): CaptureSet = val supApprox = super.computeApprox(this) - if source eq origin then supApprox.map(bimap.inverseTypeMap) + if source eq origin then supApprox.map(bimap.inverse) else source.upperApprox(this).map(bimap) ** supApprox override def toString = s"BiMapped$id($source, elems = $elems)" @@ -642,7 +704,7 @@ object CaptureSet: /** A variable with elements given at any time as { x <- source.elems | p(x) } */ class Filtered private[CaptureSet] (val source: Var, p: Context ?=> CaptureRef => Boolean)(using @constructorOnly ctx: Context) - extends DerivedVar(source.elems.filter(p)): + extends DerivedVar(source.owner, source.elems.filter(p)): override def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult = val filtered = newElems.filter(p) @@ -673,7 +735,7 @@ object CaptureSet: extends Filtered(source, !other.accountsFor(_)) class Intersected(cs1: CaptureSet, cs2: CaptureSet)(using Context) - extends Var(elemIntersection(cs1, cs2)): + extends Var(cs1.owner.minNested(cs2.owner), elemIntersection(cs1, cs2)): addAsDependentTo(cs1) addAsDependentTo(cs2) deps += cs1 @@ -934,4 +996,17 @@ object CaptureSet: println(i" ${cv.show.padTo(20, ' ')} :: ${cv.deps.toList}%, %") } else op + + def levelErrors: Addenda = new Addenda: + override def toAdd(using Context): List[String] = + for + state <- ctx.property(ccState).toList + (ref, cs) <- state.levelError + yield + val level = ref.ccNestingLevel + i""" + | + |Note that reference ${ref}, defined at level $level + |cannot be included in outer capture set $cs, defined at level ${cs.owner.nestingLevel} in ${cs.owner}""" + end CaptureSet diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index b6b5d569677c..fe13bd18b213 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -12,6 +12,7 @@ import ast.{tpd, untpd, Trees} import Trees.* import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents, OverridingPairsChecker} import typer.Checking.{checkBounds, checkAppliedTypesIn} +import typer.ErrorReporting.Addenda import util.{SimpleIdentitySet, EqHashMap, SrcPos, Property} import transform.SymUtils.* import transform.{Recheck, PreRecheck} @@ -73,12 +74,11 @@ object CheckCaptures: /** Similar normal substParams, but this is an approximating type map that * maps parameters in contravariant capture sets to the empty set. - * TODO: check what happens with non-variant. */ final class SubstParamsMap(from: BindingType, to: List[Type])(using Context) extends ApproximatingTypeMap, IdempotentCaptRefMap: /** This SubstParamsMap is exact if `to` only contains `CaptureRef`s. */ - private val isExactSubstitution: Boolean = to.forall(_.isInstanceOf[CaptureRef]) + private val isExactSubstitution: Boolean = to.forall(_.isTrackableRef) /** As long as this substitution is exact, there is no need to create `Range`s when mapping invariant positions. */ override protected def needsRangeIfInvariant(refs: CaptureSet): Boolean = !isExactSubstitution @@ -96,6 +96,39 @@ object CheckCaptures: mapOver(tp) end SubstParamsMap + final class SubstParamsBiMap(from: LambdaType, to: List[Type])(using Context) + extends BiTypeMap: + thisMap => + + def apply(tp: Type): Type = tp match + case tp: ParamRef => + if tp.binder == from then to(tp.paramNum) else tp + case tp: NamedType => + if tp.prefix `eq` NoPrefix then tp + else tp.derivedSelect(apply(tp.prefix)) + case _: ThisType => + tp + case _ => + mapOver(tp) + + lazy val inverse = new BiTypeMap: + def apply(tp: Type): Type = tp match + case tp: NamedType => + var idx = 0 + var to1 = to + while idx < to.length && (tp ne to(idx)) do + idx += 1 + to1 = to1.tail + if idx < to.length then from.paramRefs(idx) + else if tp.prefix `eq` NoPrefix then tp + else tp.derivedSelect(apply(tp.prefix)) + case _: ThisType => + tp + case _ => + mapOver(tp) + def inverse = thisMap + end SubstParamsBiMap + /** Check that a @retains annotation only mentions references that can be tracked. * This check is performed at Typer. */ @@ -107,7 +140,7 @@ object CheckCaptures: for elem <- retainedElems(ann) do elem.tpe match case ref: CaptureRef => - if !ref.canBeTracked then + if !ref.isTrackableRef then report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos) case tpe => report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos) @@ -228,10 +261,12 @@ class CheckCaptures extends Recheck, SymTransformer: def header = if cs1.elems.size == 1 then i"reference ${cs1.elems.toList}%, % is not" else i"references $cs1 are not all" - report.error(em"$header included in allowed capture set ${res.blocking}", pos) + def toAdd: String = CaptureSet.levelErrors.toAdd.mkString + report.error(em"$header included in allowed capture set ${res.blocking}$toAdd", pos) /** The current environment */ - private var curEnv: Env = Env(NoSymbol, EnvKind.Regular, CaptureSet.empty, null) + private var curEnv: Env = inContext(ictx): + Env(defn.RootClass, EnvKind.Regular, CaptureSet.empty, null) private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap() @@ -240,7 +275,8 @@ class CheckCaptures extends Recheck, SymTransformer: */ def capturedVars(sym: Symbol)(using Context) = myCapturedVars.getOrElseUpdate(sym, - if sym.ownersIterator.exists(_.isTerm) then CaptureSet.Var() + if sym.ownersIterator.exists(_.isTerm) then + CaptureSet.Var(if sym.isConstructor then sym.owner.owner else sym.owner) else CaptureSet.empty) /** For all nested environments up to `limit` or a closed environment perform `op`, @@ -408,10 +444,16 @@ class CheckCaptures extends Recheck, SymTransformer: else if meth == defn.Caps_unsafeUnbox then mapArgUsing(_.forceBoxStatus(false)) else if meth == defn.Caps_unsafeBoxFunArg then - mapArgUsing: + def forceBox(tp: Type): Type = tp match case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) => defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual) - + case tp @ RefinedType(parent, rname, rinfo: MethodType) => + tp.derivedRefinedType(parent, rname, + rinfo.derivedLambdaType( + paramInfos = rinfo.paramInfos.map(_.forceBoxStatus(true)))) + case tp @ CapturingType(parent, refs) => + tp.derivedCapturingType(forceBox(parent), refs) + mapArgUsing(forceBox) else super.recheckApply(tree, pt) match case appType @ CapturingType(appType1, refs) => @@ -431,6 +473,10 @@ class CheckCaptures extends Recheck, SymTransformer: case appType => appType end recheckApply + private def isDistinct(xs: List[Type]): Boolean = xs match + case x :: xs1 => xs1.isEmpty || !xs1.contains(x) && isDistinct(xs1) + case Nil => true + /** Handle an application of method `sym` with type `mt` to arguments of types `argTypes`. * This means: * - Instantiate result type with actual arguments @@ -438,11 +484,19 @@ class CheckCaptures extends Recheck, SymTransformer: * - remember types of arguments corresponding to tracked * parameters in refinements. * - add capture set of instantiated class to capture set of result type. + * If all argument types are mutually disfferent trackable capture references, use a BiTypeMap, + * since that is more precise. Otherwise use a normal idempotent map, which might lose information + * in the case where the result type contains captureset variables that are further + * constrained afterwards. */ override def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = val ownType = - if mt.isResultDependent then SubstParamsMap(mt, argTypes)(mt.resType) - else mt.resType + if !mt.isResultDependent then + mt.resType + else if argTypes.forall(_.isTrackableRef) && isDistinct(argTypes) then + SubstParamsBiMap(mt, argTypes)(mt.resType) + else + SubstParamsMap(mt, argTypes)(mt.resType) if sym.isConstructor then val cls = sym.owner.asClass @@ -485,63 +539,28 @@ class CheckCaptures extends Recheck, SymTransformer: else ownType end instantiate - override def recheckClosure(tree: Closure, pt: Type)(using Context): Type = + override def recheckClosure(tree: Closure, pt: Type, forceDependent: Boolean)(using Context): Type = val cs = capturedVars(tree.meth.symbol) capt.println(i"typing closure $tree with cvs $cs") - super.recheckClosure(tree, pt).capturing(cs) - .showing(i"rechecked $tree / $pt = $result", capt) - - /** Additionally to normal processing, update types of closures if the expected type - * is a function with only pure parameters. In that case, make the anonymous function - * also have the same parameters as the prototype. - * TODO: Develop a clearer rationale for this. - * TODO: Can we generalize this to arbitrary parameters? - * Currently some tests fail if we do this. (e.g. neg.../stackAlloc.scala, others) - */ - override def recheckBlock(block: Block, pt: Type)(using Context): Type = - block match - case closureDef(mdef) => - pt.dealias match - case defn.FunctionOf(ptformals, _, _) - if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) => - // Redo setup of the anonymous function so that formal parameters don't - // get capture sets. This is important to avoid false widenings to `cap` - // when taking the base type of the actual closures's dependent function - // type so that it conforms to the expected non-dependent function type. - // See withLogFile.scala for a test case. - val meth = mdef.symbol - // First, undo the previous setup which installed a completer for `meth`. - atPhase(preRecheckPhase.prev)(meth.denot.copySymDenotation()) - .installAfter(preRecheckPhase) - - // Next, update all parameter symbols to match expected formals - meth.paramSymss.head.lazyZip(ptformals).foreach: (psym, pformal) => - psym.updateInfoBetween(preRecheckPhase, thisPhase, pformal.mapExprType) - - // Next, update types of parameter ValDefs - mdef.paramss.head.lazyZip(ptformals).foreach: (param, pformal) => - val ValDef(_, tpt, _) = param: @unchecked - tpt.rememberTypeAlways(pformal) - - // Next, install a new completer reflecting the new parameters for the anonymous method - val mt = meth.info.asInstanceOf[MethodType] - val completer = new LazyType: - def complete(denot: SymDenotation)(using Context) = - denot.info = mt.companion(ptformals, mdef.tpt.knownType) - .showing(i"simplify info of $meth to $result", capt) - recheckDef(mdef, meth) - meth.updateInfoBetween(preRecheckPhase, thisPhase, completer) - case _ => - mdef.rhs match - case rhs @ closure(_, _, _) => - // In a curried closure `x => y => e` don't leak capabilities retained by - // the second closure `y => e` into the first one. This is an approximation - // of the CC rule which says that a closure contributes captures to its - // environment only if a let-bound reference to the closure is used. - mdef.rhs.putAttachment(ClosureBodyValue, ()) - case _ => + super.recheckClosure(tree, pt, forceDependent).capturing(cs) + .showing(i"rechecked closure $tree / $pt = $result", capt) + + override def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type = + mdef.rhs match + case rhs @ closure(_, _, _) => + // In a curried closure `x => y => e` don't leak capabilities retained by + // the second closure `y => e` into the first one. This is an approximation + // of the CC rule which says that a closure contributes captures to its + // environment only if a let-bound reference to the closure is used. + mdef.rhs.putAttachment(ClosureBodyValue, ()) case _ => - super.recheckBlock(block, pt) + + // Constrain closure's parameters and result from the expected type before + // rechecking the body. + val res = recheckClosure(expr, pt, forceDependent = true) + recheckDef(mdef, mdef.symbol) + res + end recheckClosureBlock override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit = try @@ -640,9 +659,9 @@ class CheckCaptures extends Recheck, SymTransformer: val saved = curEnv tree match case _: RefTree | closureDef(_) if pt.isBoxedCapturing => - curEnv = Env(curEnv.owner, EnvKind.Boxed, CaptureSet.Var(), curEnv) + curEnv = Env(curEnv.owner, EnvKind.Boxed, CaptureSet.Var(curEnv.owner), curEnv) case _ if tree.hasAttachment(ClosureBodyValue) => - curEnv = Env(curEnv.owner, EnvKind.ClosureResult, CaptureSet.Var(), curEnv) + curEnv = Env(curEnv.owner, EnvKind.ClosureResult, CaptureSet.Var(curEnv.owner), curEnv) case _ => val res = try super.recheck(tree, pt) @@ -659,13 +678,10 @@ class CheckCaptures extends Recheck, SymTransformer: * of simulated boxing and unboxing. */ override def recheckFinish(tpe: Type, tree: Tree, pt: Type)(using Context): Type = - val typeToCheck = tree match - case _: Ident | _: Select | _: Apply | _: TypeApply if tree.symbol.unboxesResult => - tpe - case _: Try => - tpe - case _ => - NoType + def needsUniversalCheck = tree match + case _: RefTree | _: Apply | _: TypeApply => tree.symbol.unboxesResult + case _: Try => true + case _ => false def checkNotUniversal(tp: Type): Unit = tp.widenDealias match case wtp @ CapturingType(parent, refs) => refs.disallowRootCapability { () => @@ -676,8 +692,10 @@ class CheckCaptures extends Recheck, SymTransformer: } checkNotUniversal(parent) case _ => - if !allowUniversalInBoxed then checkNotUniversal(typeToCheck) + if !allowUniversalInBoxed && needsUniversalCheck then + checkNotUniversal(tpe) super.recheckFinish(tpe, tree, pt) + end recheckFinish // ------------------ Adaptation ------------------------------------- // @@ -690,11 +708,11 @@ class CheckCaptures extends Recheck, SymTransformer: // - Adapt box status and environment capture sets by simulating box/unbox operations. /** Massage `actual` and `expected` types using the methods below before checking conformance */ - override def checkConformsExpr(actual: Type, expected: Type, tree: Tree)(using Context): Unit = + override def checkConformsExpr(actual: Type, expected: Type, tree: Tree, addenda: Addenda)(using Context): Unit = val expected1 = alignDependentFunction(addOuterRefs(expected, actual), actual.stripCapturing) val actual1 = adaptBoxed(actual, expected1, tree.srcPos) //println(i"check conforms $actual1 <<< $expected1") - super.checkConformsExpr(actual1, expected1, tree) + super.checkConformsExpr(actual1, expected1, tree, addenda ++ CaptureSet.levelErrors) private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type = MethodType.companion(isContextual = isContextual)(args, resultType) @@ -767,6 +785,12 @@ class CheckCaptures extends Recheck, SymTransformer: */ def adaptBoxed(actual: Type, expected: Type, pos: SrcPos, alwaysConst: Boolean = false)(using Context): Type = + inline def inNestedEnv[T](boxed: Boolean)(op: => T): T = + val saved = curEnv + curEnv = Env(curEnv.owner, EnvKind.NestedInOwner, CaptureSet.Var(curEnv.owner), if boxed then null else curEnv) + try op + finally curEnv = saved + /** Adapt function type `actual`, which is `aargs -> ares` (possibly with dependencies) * to `expected` type. * It returns the adapted type along with a capture set consisting of the references @@ -776,10 +800,7 @@ class CheckCaptures extends Recheck, SymTransformer: def adaptFun(actual: Type, aargs: List[Type], ares: Type, expected: Type, covariant: Boolean, boxed: Boolean, reconstruct: (List[Type], Type) => Type): (Type, CaptureSet) = - val saved = curEnv - curEnv = Env(curEnv.owner, EnvKind.NestedInOwner, CaptureSet.Var(), if boxed then null else curEnv) - - try + inNestedEnv(boxed): val (eargs, eres) = expected.dealias.stripCapturing match case defn.FunctionOf(eargs, eres, _) => (eargs, eres) case expected: MethodType => (expected.paramInfos, expected.resType) @@ -793,8 +814,6 @@ class CheckCaptures extends Recheck, SymTransformer: else reconstruct(aargs1, ares1) (resTp, curEnv.captured) - finally - curEnv = saved /** Adapt type function type `actual` to the expected type. * @see [[adaptFun]] @@ -803,10 +822,7 @@ class CheckCaptures extends Recheck, SymTransformer: actual: Type, ares: Type, expected: Type, covariant: Boolean, boxed: Boolean, reconstruct: Type => Type): (Type, CaptureSet) = - val saved = curEnv - curEnv = Env(curEnv.owner, EnvKind.NestedInOwner, CaptureSet.Var(), if boxed then null else curEnv) - - try + inNestedEnv(boxed): val eres = expected.dealias.stripCapturing match case RefinedType(_, _, rinfo: PolyType) => rinfo.resType case expected: PolyType => expected.resType @@ -819,8 +835,6 @@ class CheckCaptures extends Recheck, SymTransformer: else reconstruct(ares1) (resTp, curEnv.captured) - finally - curEnv = saved end adaptTypeFun def adaptInfo(actual: Type, expected: Type, covariant: Boolean): String = @@ -962,16 +976,16 @@ class CheckCaptures extends Recheck, SymTransformer: traverseChildren(t) override def checkUnit(unit: CompilationUnit)(using Context): Unit = - Setup(preRecheckPhase, thisPhase, recheckDef)(ctx.compilationUnit.tpdTree) - //println(i"SETUP:\n${Recheck.addRecheckedTypes.transform(ctx.compilationUnit.tpdTree)}") - withCaptureSetsExplained { - super.checkUnit(unit) - checkOverrides.traverse(unit.tpdTree) - checkSelfTypes(unit.tpdTree) - postCheck(unit.tpdTree) - if ctx.settings.YccDebug.value then - show(unit.tpdTree) // this does not print tree, but makes its variables visible for dependency printing - } + inContext(ctx.withProperty(ccState, Some(new CCState))): + Setup(preRecheckPhase, thisPhase, this)(ctx.compilationUnit.tpdTree) + //println(i"SETUP:\n${Recheck.addRecheckedTypes.transform(ctx.compilationUnit.tpdTree)}") + withCaptureSetsExplained: + super.checkUnit(unit) + checkOverrides.traverse(unit.tpdTree) + checkSelfTypes(unit.tpdTree) + postCheck(unit.tpdTree) + if ctx.settings.YccDebug.value then + show(unit.tpdTree) // this does not print tree, but makes its variables visible for dependency printing /** Check that self types of subclasses conform to self types of super classes. * (See comment below how this is achieved). The check assumes that classes @@ -1032,9 +1046,9 @@ class CheckCaptures extends Recheck, SymTransformer: * that this type parameter can't see. * For example, when capture checking the following expression: * - * def usingLogFile[T](op: (f: {cap} File) => T): T = ... + * def usingLogFile[T](op: File^ => T): T = ... * - * usingLogFile[box ?1 () -> Unit] { (f: {cap} File) => () => { f.write(0) } } + * usingLogFile[box ?1 () -> Unit] { (f: File^) => () => { f.write(0) } } * * We may propagate `f` into ?1, making ?1 ill-formed. * This also causes soundness issues, since `f` in ?1 should be widened to `cap`, @@ -1046,34 +1060,26 @@ class CheckCaptures extends Recheck, SymTransformer: */ private def healTypeParam(tree: Tree)(using Context): Unit = val checker = new TypeTraverser: + private var allowed: SimpleIdentitySet[TermParamRef] = SimpleIdentitySet.empty + private def isAllowed(ref: CaptureRef): Boolean = ref match case ref: TermParamRef => allowed.contains(ref) case _ => true - // Widen the given term parameter refs x₁ : C₁ S₁ , ⋯ , xₙ : Cₙ Sₙ to their capture sets C₁ , ⋯ , Cₙ. - // - // If in these capture sets there are any capture references that are term parameter references we should avoid, - // we will widen them recursively. - private def widenParamRefs(refs: List[TermParamRef]): List[CaptureSet] = - @scala.annotation.tailrec - def recur(todos: List[TermParamRef], acc: List[CaptureSet]): List[CaptureSet] = - todos match - case Nil => acc - case ref :: rem => - val cs = ref.captureSetOfInfo - val nextAcc = cs.filter(isAllowed(_)) :: acc - val nextRem: List[TermParamRef] = (cs.elems.toList.filter(!isAllowed(_)) ++ rem).asInstanceOf - recur(nextRem, nextAcc) - recur(refs, Nil) - private def healCaptureSet(cs: CaptureSet): Unit = - def avoidance(elems: List[CaptureRef])(using Context): Unit = - val toInclude = widenParamRefs(elems.filter(!isAllowed(_)).asInstanceOf) - //println(i"HEAL $cs by widening to $toInclude") - toInclude.foreach(checkSubset(_, cs, tree.srcPos)) - cs.ensureWellformed(avoidance) - - private var allowed: SimpleIdentitySet[TermParamRef] = SimpleIdentitySet.empty + cs.ensureWellformed: elems => + ctx ?=> + var seen = new util.HashSet[CaptureRef] + def recur(elems: List[CaptureRef]): Unit = + for ref <- elems do + if !isAllowed(ref) && !seen.contains(ref) then + seen += ref + val widened = ref.captureSetOfInfo + val added = widened.filter(isAllowed(_)) + capt.println(i"heal $ref in $cs by widening to $added") + checkSubset(added, cs, tree.srcPos) + recur(widened.elems.toList) + recur(elems) def traverse(tp: Type) = tp match @@ -1132,6 +1138,8 @@ class CheckCaptures extends Recheck, SymTransformer: || // non-local symbols cannot have inferred types since external capture types are not inferred isLocal // local symbols still need explicit types if && !sym.owner.is(Trait) // they are defined in a trait, since we do OverridingPairs checking before capture inference + || + sym.allOverriddenSymbols.nonEmpty def isNotPureThis(ref: CaptureRef) = ref match { case ref: ThisType => !ref.cls.isPureClass case _ => true diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 4c32c2908635..4c899c579233 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -29,7 +29,7 @@ import dotty.tools.dotc.core.Annotations.Annotation class Setup( preRecheckPhase: DenotTransformer, thisPhase: DenotTransformer, - recheckDef: (tpd.ValOrDefDef, Symbol) => Context ?=> Unit) + checker: CheckCaptures#CaptureChecker) extends tpd.TreeTraverser: import tpd.* @@ -106,7 +106,7 @@ extends tpd.TreeTraverser: cls.paramGetters.foldLeft(tp) { (core, getter) => if getter.termRef.isTracked then val getterType = tp.memberInfo(getter).strippedDealias - RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var())) + RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(ctx.owner))) .showing(i"add capture refinement $tp --> $result", capt) else core @@ -164,7 +164,7 @@ extends tpd.TreeTraverser: resType = this(tp.resType)) case _ => mapOver(tp) - Setup.addVar(addCaptureRefinements(tp1)) + Setup.addVar(addCaptureRefinements(tp1), ctx.owner) end apply end mapInferred @@ -238,6 +238,7 @@ extends tpd.TreeTraverser: */ private class SubstParams(from: List[List[Symbol]], to: List[LambdaType])(using Context) extends DeepTypeMap, BiTypeMap: + thisMap => def apply(t: Type): Type = t match case t: NamedType => @@ -253,15 +254,17 @@ extends tpd.TreeTraverser: case _ => mapOver(t) - def inverse(t: Type): Type = t match - case t: ParamRef => - def recur(from: List[LambdaType], to: List[List[Symbol]]): Type = - if from.isEmpty then t - else if t.binder eq from.head then to.head(t.paramNum).namedType - else recur(from.tail, to.tail) - recur(to, from) - case _ => - mapOver(t) + lazy val inverse = new BiTypeMap: + def apply(t: Type): Type = t match + case t: ParamRef => + def recur(from: List[LambdaType], to: List[List[Symbol]]): Type = + if from.isEmpty then t + else if t.binder eq from.head then to.head(t.paramNum).namedType + else recur(from.tail, to.tail) + recur(to, from) + case _ => + mapOver(t) + def inverse = thisMap end SubstParams /** Update info of `sym` for CheckCaptures phase only */ @@ -273,14 +276,15 @@ extends tpd.TreeTraverser: case tree: DefDef => if isExcluded(tree.symbol) then return - tree.tpt match - case tpt: TypeTree if tree.symbol.allOverriddenSymbols.hasNext => - tree.paramss.foreach(traverse) - transformTT(tpt, boxed = false, exact = true) - traverse(tree.rhs) - //println(i"TYPE of ${tree.symbol.showLocated} = ${tpt.knownType}") - case _ => - traverseChildren(tree) + inContext(ctx.withOwner(tree.symbol)): + tree.tpt match + case tpt: TypeTree if tree.symbol.allOverriddenSymbols.hasNext => + tree.paramss.foreach(traverse) + transformTT(tpt, boxed = false, exact = true) + traverse(tree.rhs) + //println(i"TYPE of ${tree.symbol.showLocated} = ${tpt.knownType}") + case _ => + traverseChildren(tree) case tree @ ValDef(_, tpt: TypeTree, _) => transformTT(tpt, boxed = tree.symbol.is(Mutable), // types of mutable variables are boxed @@ -308,81 +312,99 @@ extends tpd.TreeTraverser: i"Sealed type variable $pname", "be instantiated to", i"This is often caused by a local capability$where\nleaking as part of its result.", tree.srcPos) + case tree: Template => + inContext(ctx.withOwner(tree.symbol.owner)): + traverseChildren(tree) case _ => traverseChildren(tree) - tree match - case tree: TypeTree => - transformTT(tree, boxed = false, exact = false) // other types are not boxed - case tree: ValOrDefDef => - val sym = tree.symbol - - // replace an existing symbol info with inferred types where capture sets of - // TypeParamRefs and TermParamRefs put in correspondence by BiTypeMaps with the - // capture sets of the types of the method's parameter symbols and result type. - def integrateRT( - info: Type, // symbol info to replace - psymss: List[List[Symbol]], // the local (type and term) parameter symbols corresponding to `info` - prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order - prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order - ): Type = - info match - case mt: MethodOrPoly => - val psyms = psymss.head - mt.companion(mt.paramNames)( - mt1 => - if !psyms.exists(_.isUpdatedAfter(preRecheckPhase)) && !mt.isParamDependent && prevLambdas.isEmpty then - mt.paramInfos - else - val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas) - psyms.map(psym => subst(psym.info).asInstanceOf[mt.PInfo]), - mt1 => - integrateRT(mt.resType, psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas) - ) - case info: ExprType => - info.derivedExprType(resType = - integrateRT(info.resType, psymss, prevPsymss, prevLambdas)) - case _ => - val restp = tree.tpt.knownType - if prevLambdas.isEmpty then restp - else SubstParams(prevPsymss, prevLambdas)(restp) - - if sym.exists && tree.tpt.hasRememberedType && !sym.isConstructor then - val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil) - .showing(i"update info $sym: ${sym.info} --> $result", capt) - if newInfo ne sym.info then - val completer = new LazyType: + postProcess(tree) + end traverse + + def postProcess(tree: Tree)(using Context): Unit = tree match + case tree: TypeTree => + transformTT(tree, boxed = false, exact = false) // other types are not boxed + case tree: ValOrDefDef => + val sym = tree.symbol + + def remapRoot(tp: Type): Type = mapRoots(sym.levelOwner).inverse(tp) + + // replace an existing symbol info with inferred types where capture sets of + // TypeParamRefs and TermParamRefs put in correspondence by BiTypeMaps with the + // capture sets of the types of the method's parameter symbols and result type. + def integrateRT( + info: Type, // symbol info to replace + psymss: List[List[Symbol]], // the local (type and term) parameter symbols corresponding to `info` + prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order + prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order + ): Type = + info match + case mt: MethodOrPoly => + val psyms = psymss.head + mt.companion(mt.paramNames)( + mt1 => + if !psyms.exists(_.isUpdatedAfter(preRecheckPhase)) && !mt.isParamDependent && prevLambdas.isEmpty then + mt.paramInfos + else + val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas) + psyms.map(psym => remapRoot(subst(psym.info)).asInstanceOf[mt.PInfo]), + mt1 => + integrateRT(remapRoot(mt.resType), psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas) + ) + case info: ExprType => + info.derivedExprType(resType = + integrateRT(info.resType, psymss, prevPsymss, prevLambdas)) + case _ => + val restp = tree.tpt.knownType + if prevLambdas.isEmpty then restp + else SubstParams(prevPsymss, prevLambdas)(restp) + + if sym.exists && tree.tpt.hasRememberedType && !sym.isConstructor then + val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil) + .showing(i"update info $sym: ${sym.info} --> $result", capt) + if newInfo ne sym.info then + updateInfo(sym, + if sym.isAnonymousFunction then + // closures are handled specially; the newInfo is constrained from + // the expected type and only afterwards we recheck the definition + newInfo + else new LazyType: def complete(denot: SymDenotation)(using Context) = + // infos other methods are determined from their definitions which + // are checked on depand denot.info = newInfo - recheckDef(tree, sym) - updateInfo(sym, completer) - case tree: Bind => - val sym = tree.symbol - updateInfo(sym, transformInferredType(sym.info, boxed = false)) - case tree: TypeDef => - tree.symbol match - case cls: ClassSymbol => - val cinfo @ ClassInfo(prefix, _, ps, decls, selfInfo) = cls.classInfo - if (selfInfo eq NoType) || cls.is(ModuleClass) && !cls.isStatic then - // add capture set to self type of nested classes if no self type is given explicitly - val localRefs = CaptureSet.Var() - val newInfo = ClassInfo(prefix, cls, ps, decls, - CapturingType(cinfo.selfType, localRefs) - .showing(i"inferred self type for $cls: $result", capt)) - updateInfo(cls, newInfo) - cls.thisType.asInstanceOf[ThisType].invalidateCaches() - if cls.is(ModuleClass) then - // if it's a module, the capture set of the module reference is the capture set of the self type - val modul = cls.sourceModule - updateInfo(modul, CapturingType(modul.info, localRefs)) - modul.termRef.invalidateCaches() - case _ => - val info = atPhase(preRecheckPhase)(tree.symbol.info) - val newInfo = transformExplicitType(info, boxed = false) - if newInfo ne info then - updateInfo(tree.symbol, newInfo) - capt.println(i"update info of ${tree.symbol} from $info to $newInfo") - case _ => - end traverse + checker.recheckDef(tree, sym)) + case tree: Bind => + val sym = tree.symbol + updateInfo(sym, transformInferredType(sym.info, boxed = false)) + case tree: TypeDef => + tree.symbol match + case cls: ClassSymbol => + val cinfo @ ClassInfo(prefix, _, ps, decls, selfInfo) = cls.classInfo + if (selfInfo eq NoType) || cls.is(ModuleClass) && !cls.isStatic then + // add capture set to self type of nested classes if no self type is given explicitly + val selfRefs = CaptureSet.Var(cls) + // it's debatable what the right level owner should be. A self type should + // be able to mention class parameters, which are owned by the class; that's + // why the class was picked as level owner. But self types should not be able + // to mention other fields. + val newInfo = ClassInfo(prefix, cls, ps, decls, + CapturingType(cinfo.selfType, selfRefs) + .showing(i"inferred self type for $cls: $result", capt)) + updateInfo(cls, newInfo) + cls.thisType.asInstanceOf[ThisType].invalidateCaches() + if cls.is(ModuleClass) then + // if it's a module, the capture set of the module reference is the capture set of the self type + val modul = cls.sourceModule + updateInfo(modul, CapturingType(modul.info, selfRefs)) + modul.termRef.invalidateCaches() + case _ => + val info = atPhase(preRecheckPhase)(tree.symbol.info) + val newInfo = transformExplicitType(info, boxed = false) + if newInfo ne info then + updateInfo(tree.symbol, newInfo) + capt.println(i"update info of ${tree.symbol} from $info to $newInfo") + case _ => + end postProcess def apply(tree: Tree)(using Context): Unit = traverse(tree)(using ctx.withProperty(Setup.IsDuringSetupKey, Some(()))) @@ -479,10 +501,10 @@ object Setup: /** Add a capture set variable to `tp` if necessary, or maybe pull out * an embedded capture set variable from a part of `tp`. */ - def addVar(tp: Type)(using Context): Type = + def addVar(tp: Type, owner: Symbol)(using Context): Type = decorate(tp, addedSet = _.dealias.match - case CapturingType(_, refs) => CaptureSet.Var(refs.elems) - case _ => CaptureSet.Var()) + case CapturingType(_, refs) => CaptureSet.Var(owner, refs.elems) + case _ => CaptureSet.Var(owner)) end Setup \ No newline at end of file diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 8a7f2ff4e051..f0a1453672e0 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -815,7 +815,7 @@ object Contexts { * Note: plain TypeComparers always take on the kind of the outer comparer if they are in the same context. * In other words: tracking or explaining is a sticky property in the same context. */ - private def comparer(using Context): TypeComparer = + def comparer(using Context): TypeComparer = util.Stats.record("comparing") val base = ctx.base if base.comparersInUse > 0 diff --git a/compiler/src/dotty/tools/dotc/core/Denotations.scala b/compiler/src/dotty/tools/dotc/core/Denotations.scala index a478d60ce348..640ba8015be7 100644 --- a/compiler/src/dotty/tools/dotc/core/Denotations.scala +++ b/compiler/src/dotty/tools/dotc/core/Denotations.scala @@ -884,7 +884,6 @@ object Denotations { /** Install this denotation to be the result of the given denotation transformer. * This is the implementation of the same-named method in SymDenotations. * It's placed here because it needs access to private fields of SingleDenotation. - * @pre Can only be called in `phase.next`. */ protected def installAfter(phase: DenotTransformer)(using Context): Unit = { val targetId = phase.next.id @@ -892,16 +891,21 @@ object Denotations { else { val current = symbol.current // println(s"installing $this after $phase/${phase.id}, valid = ${current.validFor}") - // printPeriods(current) + // println(current.definedPeriodsString) this.validFor = Period(ctx.runId, targetId, current.validFor.lastPhaseId) if (current.validFor.firstPhaseId >= targetId) current.replaceWith(this) + symbol.denot + // Let symbol point to updated denotation + // Without this we can get problems when we immediately recompute the denotation + // at another phase since the invariant that symbol used to point to a valid + // denotation is lost. else { current.validFor = Period(ctx.runId, current.validFor.firstPhaseId, targetId - 1) insertAfter(current) } + // println(current.definedPeriodsString) } - // printPeriods(this) } /** Apply a transformation `f` to all denotations in this group that start at or after diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index cd51d4bf79c2..8f99d4f5a240 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -287,6 +287,7 @@ object StdNames { // Compiler-internal val CAPTURE_ROOT: N = "cap" + val LOCAL_CAPTURE_ROOT: N = "" val CONSTRUCTOR: N = "" val STATIC_CONSTRUCTOR: N = "" val EVT2U: N = "evt2u$" diff --git a/compiler/src/dotty/tools/dotc/core/Substituters.scala b/compiler/src/dotty/tools/dotc/core/Substituters.scala index 3e32340b21bd..5a641416b3e1 100644 --- a/compiler/src/dotty/tools/dotc/core/Substituters.scala +++ b/compiler/src/dotty/tools/dotc/core/Substituters.scala @@ -165,7 +165,7 @@ object Substituters: final class SubstBindingMap(from: BindingType, to: BindingType)(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = subst(tp, from, to, this)(using mapCtx) - def inverse(tp: Type): Type = tp.subst(to, from) + def inverse = SubstBindingMap(to, from) } final class Subst1Map(from: Symbol, to: Type)(using Context) extends DeepTypeMap { @@ -182,7 +182,7 @@ object Substituters: final class SubstSymMap(from: List[Symbol], to: List[Symbol])(using Context) extends DeepTypeMap, BiTypeMap { def apply(tp: Type): Type = substSym(tp, from, to, this)(using mapCtx) - def inverse(tp: Type) = tp.substSym(to, from) // implicitly requires that `to` contains no duplicates. + def inverse = SubstSymMap(to, from) // implicitly requires that `to` contains no duplicates. } final class SubstThisMap(from: ClassSymbol, to: Type)(using Context) extends DeepTypeMap { diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index e763f6c7a303..164b47793969 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -676,6 +676,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp1: RefinedType => return isSubInfo(tp1.refinedInfo, tp2.refinedInfo) case _ => + end if val skipped2 = skipMatching(tp1w, tp2) if (skipped2 eq tp2) || !Config.fastPathForRefinedSubtype then @@ -3129,6 +3130,9 @@ object TypeComparer { def tracked[T](op: TrackingTypeComparer => T)(using Context): T = comparing(_.tracked(op)) + + def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult = + comparing(_.subCaptures(refs1, refs2, frozen)) } object TrackingTypeComparer: diff --git a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala index 24a207da6836..1dcd2301b1a7 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErrors.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErrors.scala @@ -177,7 +177,7 @@ object CyclicReference: def apply(denot: SymDenotation)(using Context): CyclicReference = val ex = new CyclicReference(denot) if ex.computeStackTrace then - cyclicErrors.println(s"Cyclic reference involving! $denot") + cyclicErrors.println(s"Cyclic reference involving $denot") val sts = ex.getStackTrace.asInstanceOf[Array[StackTraceElement]] for (elem <- sts take 200) cyclicErrors.println(elem.toString) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index d68ab1aedf49..44554cc251b8 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -36,7 +36,7 @@ import config.Printers.{core, typr, matchTypes} import reporting.{trace, Message} import java.lang.ref.WeakReference import compiletime.uninitialized -import cc.{CapturingType, CaptureSet, derivedCapturingType, isBoxedCapturing, EventuallyCapturingType, boxedUnlessFun} +import cc.{CapturingType, CaptureSet, derivedCapturingType, isBoxedCapturing, EventuallyCapturingType, boxedUnlessFun, ccNestingLevel} import CaptureSet.{CompareResult, IdempotentCaptRefMap, IdentityCaptRefMap} import scala.annotation.internal.sharable @@ -478,6 +478,11 @@ object Types { */ def isDeclaredVarianceLambda: Boolean = false + /** Is this type a CaptureRef that can be tracked? + * This is true for all ThisTypes or ParamRefs but only for some NamedTypes. + */ + def isTrackableRef(using Context): Boolean = false + /** Does this type contain wildcard types? */ final def containsWildcardTypes(using Context) = existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) @@ -2149,18 +2154,27 @@ object Types { private var myCaptureSetRunId: Int = NoRunId private var mySingletonCaptureSet: CaptureSet.Const | Null = null - /** Can the reference be tracked? This is true for all ThisTypes or ParamRefs - * but only for some NamedTypes. - */ - def canBeTracked(using Context): Boolean - /** Is the reference tracked? This is true if it can be tracked and the capture * set of the underlying type is not always empty. */ - final def isTracked(using Context): Boolean = canBeTracked && !captureSetOfInfo.isAlwaysEmpty + final def isTracked(using Context): Boolean = isTrackableRef && !captureSetOfInfo.isAlwaysEmpty - /** Is this reference the root capability `cap` ? */ - def isRootCapability(using Context): Boolean = false + /** Is this reference the generic root capability `cap` ? */ + def isGenericRootCapability(using Context): Boolean = false + + /** Is this reference a local root capability `{}` + * for some level owner? + */ + final def isLocalRootCapability(using Context): Boolean = + localRootOwner.exists + + /** If this is a local root capability, its owner, otherwise NoSymbol. + */ + def localRootOwner(using Context): Symbol = NoSymbol + + /** Is this reference the a (local or generic) root capability? */ + def isRootCapability(using Context): Boolean = + isGenericRootCapability || isLocalRootCapability /** Normalize reference so that it can be compared with `eq` for equality */ def normalizedRef(using Context): CaptureRef = this @@ -2190,7 +2204,9 @@ object Types { override def captureSet(using Context): CaptureSet = val cs = captureSetOfInfo - if canBeTracked && !cs.isAlwaysEmpty then singletonCaptureSet else cs + if isTrackableRef && !cs.isAlwaysEmpty then singletonCaptureSet else cs + + def ccNestingLevel(using Context): Int end CaptureRef /** A trait for types that bind other types that refer to them. @@ -2887,17 +2903,22 @@ object Types { * They are subsumed in the capture sets of the enclosing class. * TODO: ^^^ What about call-by-name? */ - def canBeTracked(using Context) = + override def isTrackableRef(using Context) = ((prefix eq NoPrefix) || symbol.is(ParamAccessor) && (prefix eq symbol.owner.thisType) || isRootCapability ) && !symbol.isOneOf(UnstableValueFlags) - override def isRootCapability(using Context): Boolean = + override def isGenericRootCapability(using Context): Boolean = name == nme.CAPTURE_ROOT && symbol == defn.captureRoot + override def localRootOwner(using Context): Symbol = + if name == nme.LOCAL_CAPTURE_ROOT then symbol.owner else NoSymbol + override def normalizedRef(using Context): CaptureRef = - if canBeTracked then symbol.termRef else this + if isTrackableRef then symbol.termRef else this + + def ccNestingLevel(using Context) = symbol.ccNestingLevel } abstract case class TypeRef(override val prefix: Type, @@ -3050,7 +3071,7 @@ object Types { // can happen in IDE if `cls` is stale } - def canBeTracked(using Context) = true + override def isTrackableRef(using Context) = true override def computeHash(bs: Binders): Int = doHash(bs, tref) @@ -3064,6 +3085,8 @@ object Types { def sameThis(that: Type)(using Context): Boolean = (that eq this) || that.match case that: ThisType => this.cls eq that.cls case _ => false + + def ccNestingLevel(using Context) = cls.ccNestingLevel } final class CachedThisType(tref: TypeRef) extends ThisType(tref) @@ -4661,9 +4684,10 @@ object Types { */ abstract case class TermParamRef(binder: TermLambda, paramNum: Int) extends ParamRef, CaptureRef { type BT = TermLambda - def canBeTracked(using Context) = true def kindString: String = "Term" def copyBoundType(bt: BT): Type = bt.paramRefs(paramNum) + override def isTrackableRef(using Context) = true + def ccNestingLevel(using Context) = 0 // !!! Is this the right level? } private final class TermParamRefImpl(binder: TermLambda, paramNum: Int) extends TermParamRef(binder, paramNum) @@ -5716,23 +5740,16 @@ object Types { trait BiTypeMap extends TypeMap: thisMap => - /** The inverse of the type map as a function */ - def inverse(tp: Type): Type - - /** The inverse of the type map as a BiTypeMap map, which - * has the original type map as its own inverse. - */ - def inverseTypeMap(using Context) = new BiTypeMap: - def apply(tp: Type) = thisMap.inverse(tp) - def inverse(tp: Type) = thisMap.apply(tp) + /** The inverse of the type map */ + def inverse: BiTypeMap /** A restriction of this map to a function on tracked CaptureRefs */ def forward(ref: CaptureRef): CaptureRef = this(ref) match - case result: CaptureRef if result.canBeTracked => result + case result: CaptureRef if result.isTrackableRef => result /** A restriction of the inverse to a function on tracked CaptureRefs */ def backward(ref: CaptureRef): CaptureRef = inverse(ref) match - case result: CaptureRef if result.canBeTracked => result + case result: CaptureRef if result.isTrackableRef => result end BiTypeMap abstract class TypeMap(implicit protected var mapCtx: Context) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 6cba2f78776b..cb1f7880a8aa 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -15,7 +15,7 @@ import util.SourcePosition import scala.util.control.NonFatal import scala.annotation.switch import config.{Config, Feature} -import cc.{CapturingType, EventuallyCapturingType, CaptureSet, isBoxed} +import cc.{CapturingType, EventuallyCapturingType, CaptureSet, isBoxed, ccNestingLevel} class PlainPrinter(_ctx: Context) extends Printer { @@ -149,12 +149,16 @@ class PlainPrinter(_ctx: Context) extends Printer { + defn.ObjectClass + defn.FromJavaObjectSymbol - def toTextCaptureSet(cs: CaptureSet): Text = + def toTextCaptureSet(cs: CaptureSet, describe: Boolean): Text = + def descr = Str(cs.description).provided(describe) + ~ cs.match + case cs: CaptureSet.Var if showNestingLevel => s"" + case _ => "" if printDebug && !cs.isConst then cs.toString else if ctx.settings.YccDebug.value then cs.show else if cs == CaptureSet.Fluid then "" - else if !cs.isConst && cs.elems.isEmpty then "?" - else "{" ~ Text(cs.elems.toList.map(toTextCaptureRef), ", ") ~ "}" + else if !cs.isConst && cs.elems.isEmpty then Str("?") ~~ descr + else "{" ~ Text(cs.elems.toList.map(toTextCaptureRef), ", ") ~ "}" ~~ descr /** Print capturing type, overridden in RefinedPrinter to account for * capturing function types. @@ -222,7 +226,7 @@ class PlainPrinter(_ctx: Context) extends Printer { }.close case tp @ EventuallyCapturingType(parent, refs) => val boxText: Text = Str("box ") provided tp.isBoxed //&& ctx.settings.YccDebug.value - val refsText = if refs.isUniversal then rootSetText else toTextCaptureSet(refs) + val refsText = if refs.isUniversal then rootSetText else toTextCaptureSet(refs, describe = false) toTextCapturing(parent, refsText, boxText) case tp: PreviousErrorType if ctx.settings.XprintTypes.value => "" // do not print previously reported error message because they may try to print this error type again recuresevely @@ -247,7 +251,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case ExprType(restp) => def arrowText: Text = restp match case ct @ EventuallyCapturingType(parent, refs) if ct.annot.symbol == defn.RetainsByNameAnnot => - if refs.isUniversal then Str("=>") else Str("->") ~ toTextCaptureSet(refs) + if refs.isUniversal then Str("=>") else Str("->") ~ toTextCaptureSet(refs, describe = false) case _ => if Feature.pureFunsEnabled then "->" else "=>" changePrec(GlobalPrec)(arrowText ~ " " ~ toText(restp)) @@ -354,7 +358,8 @@ class PlainPrinter(_ctx: Context) extends Printer { def toTextRef(tp: SingletonType): Text = controlled { tp match { case tp: TermRef => - toTextPrefixOf(tp) ~ selectionString(tp) + if tp.isLocalRootCapability then Str(s"") + else toTextPrefixOf(tp) ~ selectionString(tp) case tp: ThisType => nameString(tp.cls) + ".this" case SuperType(thistpe: SingletonType, _) => diff --git a/compiler/src/dotty/tools/dotc/printing/Printer.scala b/compiler/src/dotty/tools/dotc/printing/Printer.scala index b3e48ab2d843..0ddaccec47c8 100644 --- a/compiler/src/dotty/tools/dotc/printing/Printer.scala +++ b/compiler/src/dotty/tools/dotc/printing/Printer.scala @@ -10,6 +10,7 @@ import Types.{Type, SingletonType, LambdaParam, NamedType}, import typer.Implicits.SearchResult import util.SourcePosition import typer.ImportInfo +import cc.CaptureSet import scala.annotation.internal.sharable @@ -106,6 +107,9 @@ abstract class Printer { /** Textual representation of a reference in a capture set */ def toTextCaptureRef(tp: Type): Text + /** Textual representation of a reference in a capture set */ + def toTextCaptureSet(cs: CaptureSet, describe: Boolean = true): Text + /** Textual representation of symbol's declaration */ def dclText(sym: Symbol): Text diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index ff09a6084136..72c9dd6ac994 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -29,7 +29,7 @@ import config.{Config, Feature} import dotty.tools.dotc.util.SourcePosition import dotty.tools.dotc.ast.untpd.{MemberDef, Modifiers, PackageDef, RefTree, Template, TypeDef, ValOrDefDef} -import cc.{CaptureSet, toCaptureSet, IllegalCaptureRef} +import cc.{CaptureSet, toCaptureSet, IllegalCaptureRef, ccNestingLevelOpt} class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { @@ -634,7 +634,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { def toTextAnnot = toTextLocal(arg) ~~ annotText(annot.symbol.enclosingClass, annot) def toTextRetainsAnnot = - try changePrec(GlobalPrec)(toText(arg) ~ "^" ~ toTextCaptureSet(captureSet)) + try changePrec(GlobalPrec)(toText(arg) ~ "^" ~ toTextCaptureSet(captureSet, describe = false)) catch case ex: IllegalCaptureRef => toTextAnnot if annot.symbol.maybeOwner == defn.RetainsAnnot && Feature.ccEnabled && Config.printCaptureSetsAsPrefix && !printDebug @@ -865,10 +865,13 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { protected def optAscription[T <: Untyped](tpt: Tree[T]): Text = optText(tpt)(": " ~ _) + private def nestingLevel(sym: Symbol): Int = + sym.ccNestingLevelOpt.getOrElse(sym.nestingLevel) + private def idText(tree: untpd.Tree): Text = (if showUniqueIds && tree.hasType && tree.symbol.exists then s"#${tree.symbol.id}" else "") ~ (if showNestingLevel then tree.typeOpt match - case tp: NamedType if !tp.symbol.isStatic => s"%${tp.symbol.nestingLevel}" + case tp: NamedType if !tp.symbol.isStatic => s"%${nestingLevel(tp.symbol)}" case tp: TypeVar => s"%${tp.nestingLevel}" case tp: TypeParamRef => ctx.typerState.constraint.typeVarOfParam(tp) match case tvar: TypeVar => s"%${tvar.nestingLevel}" diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index 01aae8f8da18..bac78f39d8c8 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -15,6 +15,7 @@ import typer.ErrorReporting.err import typer.ProtoTypes.* import typer.TypeAssigner.seqLitType import typer.ConstFold +import typer.ErrorReporting.{Addenda, NothingToAdd} import NamerOps.methodType import config.Printers.recheckr import util.Property @@ -52,17 +53,18 @@ object Recheck: */ def updateInfoBetween(prevPhase: DenotTransformer, lastPhase: DenotTransformer, newInfo: Type)(using Context): Unit = if sym.info ne newInfo then + val flags = sym.flags sym.copySymDenotation( initFlags = - if sym.flags.isAllOf(ResetPrivateParamAccessor) - then sym.flags &~ ResetPrivate | Private - else sym.flags + if flags.isAllOf(ResetPrivateParamAccessor) + then flags &~ ResetPrivate | Private + else flags ).installAfter(lastPhase) // reset sym.copySymDenotation( info = newInfo, initFlags = - if newInfo.isInstanceOf[LazyType] then sym.flags &~ Touched - else sym.flags + if newInfo.isInstanceOf[LazyType] then flags &~ Touched + else flags ).installAfter(prevPhase) /** Does symbol have a new denotation valid from phase.next that is different @@ -96,17 +98,44 @@ object Recheck: case Some(tpe) => tree.withType(tpe).asInstanceOf[T] case None => tree - extension (tpe: Type) - - /** Map ExprType => T to () ?=> T (and analogously for pure versions). - * Even though this phase runs after ElimByName, ExprTypes can still occur - * as by-name arguments of applied types. See note in doc comment for - * ElimByName phase. Test case is bynamefun.scala. - */ - def mapExprType(using Context): Type = tpe match - case ExprType(rt) => defn.ByNameFunction(rt) - case _ => tpe + /** Map ExprType => T to () ?=> T (and analogously for pure versions). + * Even though this phase runs after ElimByName, ExprTypes can still occur + * as by-name arguments of applied types. See note in doc comment for + * ElimByName phase. Test case is bynamefun.scala. + */ + private def mapExprType(tp: Type)(using Context): Type = tp match + case ExprType(rt) => defn.ByNameFunction(rt) + case _ => tp + + /** Normalize `=> A` types to `() ?=> A` types + * - at the top level + * - in function and method parameter types + * - under annotations + */ + def normalizeByName(tp: Type)(using Context): Type = tp match + case tp: ExprType => + mapExprType(tp) + case tp: PolyType => + tp.derivedLambdaType(resType = normalizeByName(tp.resType)) + case tp: MethodType => + tp.derivedLambdaType( + paramInfos = tp.paramInfos.mapConserve(mapExprType), + resType = normalizeByName(tp.resType)) + case tp @ RefinedType(parent, nme.apply, rinfo) if defn.isFunctionType(tp) => + tp.derivedRefinedType(parent, nme.apply, normalizeByName(rinfo)) + case tp @ defn.FunctionOf(pformals, restpe, isContextual) => + val pformals1 = pformals.mapConserve(mapExprType) + val restpe1 = normalizeByName(restpe) + if (pformals1 ne pformals) || (restpe1 ne restpe) then + defn.FunctionOf(pformals1, restpe1, isContextual) + else + tp + case tp @ AnnotatedType(parent, ann) => + tp.derivedAnnotatedType(normalizeByName(parent), ann) + case _ => + tp +end Recheck /** A base class that runs a simplified typer pass over an already re-typed program. The pass * does not transform trees but returns instead the re-typed type of each tree as it is @@ -183,27 +212,16 @@ abstract class Recheck extends Phase, SymTransformer: else AnySelectionProto recheckSelection(tree, recheck(qual, proto).widenIfUnstable, name, pt) - /** When we select the `apply` of a function with type such as `(=> A) => B`, - * we need to convert the parameter type `=> A` to `() ?=> A`. See doc comment - * of `mapExprType`. - */ - def normalizeByName(mbr: SingleDenotation)(using Context): SingleDenotation = mbr.info match - case mt: MethodType if mt.paramInfos.exists(_.isInstanceOf[ExprType]) => - mbr.derivedSingleDenotation(mbr.symbol, - mt.derivedLambdaType(paramInfos = mt.paramInfos.map(_.mapExprType))) - case _ => - mbr - def recheckSelection(tree: Select, qualType: Type, name: Name, sharpen: Denotation => Denotation)(using Context): Type = if name.is(OuterSelectName) then tree.tpe else //val pre = ta.maybeSkolemizePrefix(qualType, name) - val mbr = normalizeByName( + val mbr = sharpen( qualType.findMember(name, qualType, excluded = if tree.symbol.is(Private) then EmptyFlags else Private - )).suchThat(tree.symbol == _)) + )).suchThat(tree.symbol == _) val newType = tree.tpe match case prevType: NamedType => val prevDenot = prevType.denot @@ -268,12 +286,15 @@ abstract class Recheck extends Phase, SymTransformer: protected def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type = mt.instantiate(argTypes) + protected def prepareFunction(funtpe: MethodType)(using Context): MethodType = funtpe + def recheckApply(tree: Apply, pt: Type)(using Context): Type = val funTp = recheck(tree.fun) // reuse the tree's type on signature polymorphic methods, instead of using the (wrong) rechecked one val funtpe = if tree.fun.symbol.originalSignaturePolymorphic.exists then tree.fun.tpe else funTp funtpe.widen match - case fntpe: MethodType => + case fntpe0: MethodType => + val fntpe = prepareFunction(fntpe0) assert(fntpe.paramInfos.hasSameLengthAs(tree.args)) val formals = if false && tree.symbol.is(JavaDefined) // see NOTE in mapJavaArgs @@ -281,7 +302,7 @@ abstract class Recheck extends Phase, SymTransformer: else fntpe.paramInfos def recheckArgs(args: List[Tree], formals: List[Type], prefs: List[ParamRef]): List[Type] = args match case arg :: args1 => - val argType = recheck(arg, formals.head.mapExprType) + val argType = recheck(arg, normalizeByName(formals.head)) val formals1 = if fntpe.isParamDependent then formals.tail.map(_.substParam(prefs.head, argType)) @@ -313,27 +334,33 @@ abstract class Recheck extends Phase, SymTransformer: recheck(tree.rhs, lhsType.widen) defn.UnitType - def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type = + private def recheckBlock(stats: List[Tree], expr: Tree)(using Context): Type = recheckStats(stats) val exprType = recheck(expr) + TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm)) + + def recheckBlock(tree: Block, pt: Type)(using Context): Type = tree match + case Block(Nil, expr: Block) => recheckBlock(expr, pt) + case Block((mdef : DefDef) :: Nil, closure: Closure) => + recheckClosureBlock(mdef, closure.withSpan(tree.span), pt) + case Block(stats, expr) => recheckBlock(stats, expr) // The expected type `pt` is not propagated. Doing so would allow variables in the // expected type to contain references to local symbols of the block, so the // local symbols could escape that way. - TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm)) - def recheckBlock(tree: Block, pt: Type)(using Context): Type = - recheckBlock(tree.stats, tree.expr, pt) + def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type = + recheckBlock(mdef :: Nil, expr) def recheckInlined(tree: Inlined, pt: Type)(using Context): Type = - recheckBlock(tree.bindings, tree.expansion, pt)(using inlineContext(tree)) + recheckBlock(tree.bindings, tree.expansion)(using inlineContext(tree)) def recheckIf(tree: If, pt: Type)(using Context): Type = recheck(tree.cond, defn.BooleanType) recheck(tree.thenp, pt) | recheck(tree.elsep, pt) - def recheckClosure(tree: Closure, pt: Type)(using Context): Type = + def recheckClosure(tree: Closure, pt: Type, forceDependent: Boolean = false)(using Context): Type = if tree.tpt.isEmpty then - tree.meth.tpe.widen.toFunctionType(tree.meth.symbol.is(JavaDefined)) + tree.meth.tpe.widen.toFunctionType(tree.meth.symbol.is(JavaDefined), alwaysDependent = forceDependent) else recheck(tree.tpt) @@ -534,13 +561,11 @@ abstract class Recheck extends Phase, SymTransformer: /** Check that widened types of `tpe` and `pt` are compatible. */ def checkConforms(tpe: Type, pt: Type, tree: Tree)(using Context): Unit = tree match - case _: DefTree | EmptyTree | _: TypeTree | _: Closure => - // Don't report closure nodes, since their span is a point; wait instead - // for enclosing block to preduce an error + case _: DefTree | EmptyTree | _: TypeTree => case _ => checkConformsExpr(tpe.widenExpr, pt.widenExpr, tree) - def checkConformsExpr(actual: Type, expected: Type, tree: Tree)(using Context): Unit = + def checkConformsExpr(actual: Type, expected: Type, tree: Tree, addenda: Addenda = NothingToAdd)(using Context): Unit = //println(i"check conforms $actual <:< $expected") def isCompatible(expected: Type): Boolean = @@ -553,7 +578,7 @@ abstract class Recheck extends Phase, SymTransformer: } if !isCompatible(expected) then recheckr.println(i"conforms failed for ${tree}: $actual vs $expected") - err.typeMismatch(tree.withType(actual), expected) + err.typeMismatch(tree.withType(actual), expected, addenda) else if debugSuccesses then tree match case _: Ident => diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 25cbfdfec600..68ea402eff3f 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -70,6 +70,15 @@ object ErrorReporting { case _ => foldOver(s, tp) tps.foldLeft("")(collectMatchTrace) + /** A mixin trait that can produce added elements for an error message */ + trait Addenda: + self => + def toAdd(using Context): List[String] = Nil + def ++ (follow: Addenda) = new Addenda: + override def toAdd(using Context) = self.toAdd ++ follow.toAdd + + object NothingToAdd extends Addenda + class Errors(using Context) { /** An explanatory note to be added to error messages @@ -162,7 +171,7 @@ object ErrorReporting { def patternConstrStr(tree: Tree): String = ??? - def typeMismatch(tree: Tree, pt: Type, implicitFailure: SearchFailureType = NoMatchingImplicits): Tree = { + def typeMismatch(tree: Tree, pt: Type, addenda: Addenda = NothingToAdd): Tree = { val normTp = normalize(tree.tpe, pt) val normPt = normalize(pt, pt) @@ -184,7 +193,7 @@ object ErrorReporting { "\nMaybe you are missing an else part for the conditional?" case _ => "" - errorTree(tree, TypeMismatch(treeTp, expectedTp, Some(tree), implicitFailure.whyNoConversion, missingElse)) + errorTree(tree, TypeMismatch(treeTp, expectedTp, Some(tree), (addenda.toAdd :+ missingElse)*)) } /** A subtype log explaining why `found` does not conform to `expected` */ diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index e576c6363e39..65352735beda 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -440,7 +440,7 @@ object Implicits: } } - abstract class SearchFailureType extends ErrorType { + abstract class SearchFailureType extends ErrorType, Addenda { def expectedType: Type def argument: Tree @@ -457,11 +457,6 @@ object Implicits: if (argument.isEmpty) i"match type ${clarify(expectedType)}" else i"convert from ${argument.tpe} to ${clarify(expectedType)}" } - - /** If search was for an implicit conversion, a note describing the failure - * in more detail - this is either empty or starts with a '\n' - */ - def whyNoConversion(using Context): String = "" } class NoMatchingImplicits(val expectedType: Type, val argument: Tree, constraint: Constraint = OrderingConstraint.empty) @@ -515,17 +510,21 @@ object Implicits: /** A failure value indicating that an implicit search for a conversion was not tried */ case class TooUnspecific(target: Type) extends NoMatchingImplicits(NoType, EmptyTree, OrderingConstraint.empty): - override def whyNoConversion(using Context): String = + + override def toAdd(using Context) = i""" |Note that implicit conversions were not tried because the result of an implicit conversion - |must be more specific than $target""" + |must be more specific than $target""" :: Nil override def msg(using Context) = super.msg.append("\nThe expected type $target is not specific enough, so no search was attempted") + override def toString = s"TooUnspecific" + end TooUnspecific /** An ambiguous implicits failure */ - class AmbiguousImplicits(val alt1: SearchSuccess, val alt2: SearchSuccess, val expectedType: Type, val argument: Tree) extends SearchFailureType { + class AmbiguousImplicits(val alt1: SearchSuccess, val alt2: SearchSuccess, val expectedType: Type, val argument: Tree) extends SearchFailureType: + def msg(using Context): Message = var str1 = err.refStr(alt1.ref) var str2 = err.refStr(alt2.ref) @@ -533,15 +532,16 @@ object Implicits: str1 = ctx.printer.toTextRef(alt1.ref).show str2 = ctx.printer.toTextRef(alt2.ref).show em"both $str1 and $str2 $qualify".withoutDisambiguation() - override def whyNoConversion(using Context): String = + + override def toAdd(using Context) = if !argument.isEmpty && argument.tpe.widen.isRef(defn.NothingClass) then - "" + Nil else val what = if (expectedType.isInstanceOf[SelectionProto]) "extension methods" else "conversions" i""" |Note that implicit $what cannot be applied because they are ambiguous; - |$explanation""" - } + |$explanation""" :: Nil + end AmbiguousImplicits class MismatchedImplicit(ref: TermRef, val expectedType: Type, diff --git a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala index dd766dc99c7e..d7e1a60f56fa 100644 --- a/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala +++ b/compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala @@ -14,7 +14,10 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] { def contains[E >: Elem <: AnyRef](x: E): Boolean def foreach(f: Elem => Unit): Unit def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean - def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] + def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + var acc: SimpleIdentitySet[B] = SimpleIdentitySet.empty + foreach(x => acc += f(x)) + acc def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A def toList: List[Elem] def iterator: Iterator[Elem] @@ -63,7 +66,7 @@ object SimpleIdentitySet { def contains[E <: AnyRef](x: E): Boolean = false def foreach(f: Nothing => Unit): Unit = () def exists[E <: AnyRef](p: E => Boolean): Boolean = false - def map[B <: AnyRef](f: Nothing => B): SimpleIdentitySet[B] = empty + override def map[B <: AnyRef](f: Nothing => B): SimpleIdentitySet[B] = empty def /: [A, E <: AnyRef](z: A)(f: (A, E) => A): A = z def toList = Nil def iterator = Iterator.empty @@ -79,7 +82,7 @@ object SimpleIdentitySet { def foreach(f: Elem => Unit): Unit = f(x0.asInstanceOf[Elem]) def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) - def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + override def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = Set1(f(x0.asInstanceOf[Elem])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(z, x0.asInstanceOf[E]) @@ -99,8 +102,10 @@ object SimpleIdentitySet { def foreach(f: Elem => Unit): Unit = { f(x0.asInstanceOf[Elem]); f(x1.asInstanceOf[Elem]) } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) || p(x1.asInstanceOf[E]) - def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = - Set2(f(x0.asInstanceOf[Elem]), f(x1.asInstanceOf[Elem])) + override def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + val y0 = f(x0.asInstanceOf[Elem]) + val y1 = f(x1.asInstanceOf[Elem]) + if y0 eq y1 then Set1(y0) else Set2(y0, y1) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: Nil @@ -133,8 +138,12 @@ object SimpleIdentitySet { } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = p(x0.asInstanceOf[E]) || p(x1.asInstanceOf[E]) || p(x2.asInstanceOf[E]) - def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = - Set3(f(x0.asInstanceOf[Elem]), f(x1.asInstanceOf[Elem]), f(x2.asInstanceOf[Elem])) + override def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = + val y0 = f(x0.asInstanceOf[Elem]) + val y1 = f(x1.asInstanceOf[Elem]) + val y2 = f(x2.asInstanceOf[Elem]) + if (y0 ne y1) && (y0 ne y2) && (y1 ne y2) then Set3(y0, y1, y2) + else super.map(f) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = f(f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]), x2.asInstanceOf[E]) def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: x2.asInstanceOf[Elem] :: Nil @@ -182,8 +191,6 @@ object SimpleIdentitySet { } def exists[E >: Elem <: AnyRef](p: E => Boolean): Boolean = xs.asInstanceOf[Array[E]].exists(p) - def map[B <: AnyRef](f: Elem => B): SimpleIdentitySet[B] = - SetN(xs.map(x => f(x.asInstanceOf[Elem]).asInstanceOf[AnyRef])) def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A = xs.asInstanceOf[Array[E]].foldLeft(z)(f) def toList: List[Elem] = { diff --git a/tests/neg-custom-args/captures/capt1.check b/tests/neg-custom-args/captures/capt1.check index 85d3b2a7ddcb..6b4c50b69ae4 100644 --- a/tests/neg-custom-args/captures/capt1.check +++ b/tests/neg-custom-args/captures/capt1.check @@ -15,7 +15,7 @@ -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:14:2 ----------------------------------------- 14 | def f(y: Int) = if x == null then y else y // error | ^ - | Found: Int ->{x} Int + | Found: (y: Int) ->{x} Int | Required: Matchable 15 | f | diff --git a/tests/neg-custom-args/captures/heal-tparam-cs.scala b/tests/neg-custom-args/captures/heal-tparam-cs.scala index 58d12f8b6ce5..c0fa29bcca3b 100644 --- a/tests/neg-custom-args/captures/heal-tparam-cs.scala +++ b/tests/neg-custom-args/captures/heal-tparam-cs.scala @@ -2,32 +2,33 @@ import language.experimental.captureChecking trait Cap { def use(): Unit } -def localCap[sealed T](op: (cap: Cap^{cap}) => T): T = ??? +def localCap[sealed T](op: (c: Cap^{cap}) => T): T = ??? def main(io: Cap^{cap}, net: Cap^{cap}): Unit = { - val test1 = localCap { cap => // error - () => { cap.use() } + + val test1 = localCap { c => // error + () => { c.use() } } - val test2: (cap: Cap^{cap}) -> () ->{cap} Unit = - localCap { cap => // should work - (cap1: Cap^{cap}) => () => { cap1.use() } + val test2: (c: Cap^{cap}) -> () ->{cap} Unit = + localCap { c => // should work + (c1: Cap^{cap}) => () => { c1.use() } } - val test3: (cap: Cap^{io}) -> () ->{io} Unit = - localCap { cap => // should work - (cap1: Cap^{io}) => () => { cap1.use() } + val test3: (c: Cap^{io}) -> () ->{io} Unit = + localCap { c => // should work + (c1: Cap^{io}) => () => { c1.use() } } - val test4: (cap: Cap^{io}) -> () ->{net} Unit = - localCap { cap => // error - (cap1: Cap^{io}) => () => { cap1.use() } + val test4: (c: Cap^{io}) -> () ->{net} Unit = + localCap { c => // error + (c1: Cap^{io}) => () => { c1.use() } } - def localCap2[sealed T](op: (cap: Cap^{io}) => T): T = ??? + def localCap2[sealed T](op: (c: Cap^{io}) => T): T = ??? val test5: () ->{io} Unit = - localCap2 { cap => // ok - () => { cap.use() } + localCap2 { c => // ok + () => { c.use() } } } diff --git a/tests/neg-custom-args/captures/refs.scala b/tests/neg-custom-args/captures/refs.scala new file mode 100644 index 000000000000..df38027a5643 --- /dev/null +++ b/tests/neg-custom-args/captures/refs.scala @@ -0,0 +1,42 @@ +import java.io.* + +class Ref[T](init: T): + var x: T = init + def setX(x: T): Unit = this.x = x + +def usingLogFile[sealed T](op: FileOutputStream^ => T): T = + val logFile = FileOutputStream("log") + val result = op(logFile) + logFile.close() + result + +type Proc = () => Unit +def test1 = + usingLogFile[Proc]: f => // error + () => + f.write(1) + () + +def test2 = + val r = new Ref[Proc](() => ()) + usingLogFile[Unit]: f => + r.setX(() => f.write(10)) // should be error + r.x() // crash: f is closed at that point + +def test3 = + val r = new Ref[Proc](() => ()) + usingLogFile[Unit]: f => + r.x = () => f.write(10) // should be error + r.x() // crash: f is closed at that point + +def test4 = + var r: Proc = () => () // error + usingLogFile[Unit]: f => + r = () => f.write(10) + r() // crash: f is closed at that point + + + + + + diff --git a/tests/neg-custom-args/captures/try.check b/tests/neg-custom-args/captures/try.check index 9afbe61d2280..c9b7910ad534 100644 --- a/tests/neg-custom-args/captures/try.check +++ b/tests/neg-custom-args/captures/try.check @@ -6,7 +6,7 @@ | This is often caused by a local capability in an argument of method handle | leaking as part of its result. -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:29:43 ------------------------------------------ -29 | val b = handle[Exception, () -> Nothing] { // error +29 | val b = handle[Exception, () -> Nothing] { // error | ^ | Found: (x: CT[Exception]^) ->? () ->{x} Nothing | Required: (x$0: CanThrow[Exception]) => () -> Nothing diff --git a/tests/neg-custom-args/captures/try.scala b/tests/neg-custom-args/captures/try.scala index 3c6f0605d8b9..55e065de9f9f 100644 --- a/tests/neg-custom-args/captures/try.scala +++ b/tests/neg-custom-args/captures/try.scala @@ -26,7 +26,7 @@ def test = (ex: Exception) => ??? } - val b = handle[Exception, () -> Nothing] { // error + val b = handle[Exception, () -> Nothing] { // error (x: CanThrow[Exception]) => () => raise(new Exception)(using x) } { (ex: Exception) => ??? diff --git a/tests/neg-custom-args/captures/usingLogFile-alt.check b/tests/neg-custom-args/captures/usingLogFile-alt.check index 9444bc9dc46a..7da4ba941b65 100644 --- a/tests/neg-custom-args/captures/usingLogFile-alt.check +++ b/tests/neg-custom-args/captures/usingLogFile-alt.check @@ -1,7 +1,7 @@ -- Error: tests/neg-custom-args/captures/usingLogFile-alt.scala:18:2 --------------------------------------------------- 18 | usingFile( // error | ^^^^^^^^^ - | Sealed type variable T cannot be instantiated to box () => Unit since - | that type captures the root capability `cap`. - | This is often caused by a local capability in an argument of method usingFile - | leaking as part of its result. + | reference (file : java.io.OutputStream^{}) is not included in allowed capture set {x$0} + | + | Note that reference (file : java.io.OutputStream^{}), defined at level 1 + | cannot be included in outer capture set {x$0}, defined at level 0 in package diff --git a/tests/neg-custom-args/captures/usingLogFile.check b/tests/neg-custom-args/captures/usingLogFile.check index ff4c9fd3105f..d67c59b7c512 100644 --- a/tests/neg-custom-args/captures/usingLogFile.check +++ b/tests/neg-custom-args/captures/usingLogFile.check @@ -41,7 +41,7 @@ -- Error: tests/neg-custom-args/captures/usingLogFile.scala:71:16 ------------------------------------------------------ 71 | val later = usingFile("logfile", // error | ^^^^^^^^^ - | Sealed type variable T cannot be instantiated to box () => Unit since - | that type captures the root capability `cap`. - | This is often caused by a local capability in an argument of method usingFile - | leaking as part of its result. + | reference (_$1 : java.io.OutputStream^{}) is not included in allowed capture set {x$0} + | + | Note that reference (_$1 : java.io.OutputStream^{}), defined at level 2 + | cannot be included in outer capture set {x$0}, defined at level 1 in method test diff --git a/tests/pos-custom-args/captures/bynamefun.scala b/tests/pos-custom-args/captures/bynamefun.scala index 86bad201ffc3..414f0c46c42f 100644 --- a/tests/pos-custom-args/captures/bynamefun.scala +++ b/tests/pos-custom-args/captures/bynamefun.scala @@ -1,11 +1,14 @@ object test: class Plan(elem: Plan) object SomePlan extends Plan(???) + type PP = (-> Plan) -> Plan def f1(expr: (-> Plan) -> Plan): Plan = expr(SomePlan) f1 { onf => Plan(onf) } def f2(expr: (=> Plan) -> Plan): Plan = ??? f2 { onf => Plan(onf) } def f3(expr: (-> Plan) => Plan): Plan = ??? - f1 { onf => Plan(onf) } + f3 { onf => Plan(onf) } def f4(expr: (=> Plan) => Plan): Plan = ??? - f2 { onf => Plan(onf) } + f4 { onf => Plan(onf) } + def f5(expr: PP): Plan = expr(SomePlan) + f5 { onf => Plan(onf) } \ No newline at end of file