Skip to content

Commit 8cbc022

Browse files
committed
Turn on separation checking for applications
- Use unsafeAssumeSeparate(...) as an escape hatch
1 parent 3a26fe8 commit 8cbc022

File tree

24 files changed

+254
-66
lines changed

24 files changed

+254
-66
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,12 @@ object CheckCaptures:
239239

240240
/** Was a new type installed for this tree? */
241241
def hasNuType: Boolean
242+
243+
/** Is this tree passed to a parameter or assigned to a value with a type
244+
* that contains cap in no-flip covariant position, which will necessite
245+
* a separation check?
246+
*/
247+
def needsSepCheck: Boolean
242248
end CheckerAPI
243249

244250
class CheckCaptures extends Recheck, SymTransformer:
@@ -279,6 +285,12 @@ class CheckCaptures extends Recheck, SymTransformer:
279285
*/
280286
private val todoAtPostCheck = new mutable.ListBuffer[() => Unit]
281287

288+
/** Trees that will need a separation check because they contain cap */
289+
private val sepCheckable = util.EqHashSet[Tree]()
290+
291+
extension [T <: Tree](tree: T)
292+
def needsSepCheck: Boolean = sepCheckable.contains(tree)
293+
282294
/** Instantiate capture set variables appearing contra-variantly to their
283295
* upper approximation.
284296
*/
@@ -636,11 +648,11 @@ class CheckCaptures extends Recheck, SymTransformer:
636648
val meth = tree.fun.symbol
637649
if meth == defn.Caps_unsafeAssumePure then
638650
val arg :: Nil = tree.args: @unchecked
639-
val argType0 = recheck(arg, pt.capturing(CaptureSet.universal))
651+
val argType0 = recheck(arg, pt.stripCapturing.capturing(CaptureSet.universal))
640652
val argType =
641653
if argType0.captureSet.isAlwaysEmpty then argType0
642654
else argType0.widen.stripCapturing
643-
capt.println(i"rechecking $arg with $pt: $argType")
655+
capt.println(i"rechecking unsafeAssumePure of $arg with $pt: $argType")
644656
super.recheckFinish(argType, tree, pt)
645657
else
646658
val res = super.recheckApply(tree, pt)
@@ -660,6 +672,9 @@ class CheckCaptures extends Recheck, SymTransformer:
660672
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
661673
markFree(argType.deepCaptureSet, arg.srcPos)
662674
case _ =>
675+
if formal.containsCap then
676+
arg.updNuType(freshenedFormal)
677+
sepCheckable += arg
663678
argType
664679

665680
/** Map existential captures in result to `cap` and implement the following
@@ -1785,6 +1800,7 @@ class CheckCaptures extends Recheck, SymTransformer:
17851800
end checker
17861801

17871802
checker.traverse(unit)(using ctx.withOwner(defn.RootClass))
1803+
if ccConfig.useFresh then SepChecker(this).traverse(unit)
17881804
if !ctx.reporter.errorsReported then
17891805
// We dont report errors here if previous errors were reported, because other
17901806
// errors often result in bad applied types, but flagging these bad types gives
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
import ast.tpd
5+
import collection.mutable
6+
7+
import core.*
8+
import Symbols.*, Types.*
9+
import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
10+
import CaptureSet.{Refs, emptySet}
11+
import config.Printers.capt
12+
import StdNames.nme
13+
14+
class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
15+
import tpd.*
16+
import checker.*
17+
18+
extension (cs: CaptureSet)
19+
def footprint(using Context): CaptureSet =
20+
def recur(elems: CaptureSet.Refs, newElems: List[CaptureRef]): CaptureSet.Refs = newElems match
21+
case newElem :: newElems1 =>
22+
val superElems = newElem.captureSetOfInfo.elems.filter: superElem =>
23+
!superElem.isMaxCapability && !elems.contains(superElem)
24+
recur(superElems ++ elems, superElems.toList ++ newElems1)
25+
case Nil => elems
26+
val elems: CaptureSet.Refs = cs.elems.filter(!_.isMaxCapability)
27+
CaptureSet(recur(elems, elems.toList))
28+
29+
def overlapWith(other: CaptureSet)(using Context): CaptureSet.Refs =
30+
val refs1 = cs.elems
31+
val refs2 = other.elems
32+
def common(refs1: CaptureSet.Refs, refs2: CaptureSet.Refs) =
33+
refs1.filter: ref =>
34+
ref.isExclusive && refs2.exists(_.stripReadOnly eq ref)
35+
common(refs1, refs2) ++ common(refs2, refs1)
36+
37+
private def hidden(elem: CaptureRef)(using Context): CaptureSet.Refs = elem match
38+
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ hidden(hcs)
39+
case ReadOnlyCapability(ref) => hidden(ref).map(_.readOnly)
40+
case _ => emptySet
41+
42+
private def hidden(cs: CaptureSet)(using Context): CaptureSet.Refs =
43+
val seen: util.EqHashSet[CaptureRef] = new util.EqHashSet
44+
45+
def hiddenByElem(elem: CaptureRef): CaptureSet.Refs =
46+
if seen.add(elem) then elem match
47+
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs)
48+
case ReadOnlyCapability(ref) => hiddenByElem(ref).map(_.readOnly)
49+
case _ => emptySet
50+
else emptySet
51+
52+
def recur(cs: CaptureSet): CaptureSet.Refs =
53+
(emptySet /: cs.elems): (elems, elem) =>
54+
elems ++ hiddenByElem(elem)
55+
56+
recur(cs)
57+
end hidden
58+
59+
private def checkApply(fn: Tree, args: List[Tree])(using Context): Unit =
60+
val fnCaptures = fn.nuType.deepCaptureSet
61+
62+
def captures(arg: Tree) =
63+
val argType = arg.nuType
64+
argType match
65+
case AnnotatedType(formal1, ann) if ann.symbol == defn.UseAnnot =>
66+
argType.deepCaptureSet
67+
case _ =>
68+
argType.captureSet
69+
70+
val argCaptures = args.map(captures)
71+
capt.println(i"check separate $fn($args), fnCaptures = $fnCaptures, argCaptures = $argCaptures")
72+
var footprint = argCaptures.foldLeft(fnCaptures.footprint): (fp, ac) =>
73+
fp ++ ac.footprint
74+
val paramNames = fn.nuType.widen match
75+
case MethodType(pnames) => pnames
76+
case _ => args.indices.map(nme.syntheticParamName(_))
77+
for (arg, ac, pname) <- args.lazyZip(argCaptures).lazyZip(paramNames) do
78+
if arg.needsSepCheck then
79+
val hiddenInArg = CaptureSet(hidden(ac))
80+
//println(i"check sep $arg / $footprint / $hiddenInArg")
81+
val overlap = hiddenInArg.footprint.overlapWith(footprint)
82+
if !overlap.isEmpty then
83+
def whatStr = if overlap.size == 1 then "this capability" else "these capabilities"
84+
def funStr =
85+
if fn.symbol.exists then i"${fn.symbol}"
86+
else "the function"
87+
report.error(
88+
em"""Separation failure: argument to capture-polymorphic parameter $pname: ${arg.nuType}
89+
|captures ${CaptureSet(overlap)} and also passes $whatStr separately to $funStr""",
90+
arg.srcPos)
91+
footprint ++= hiddenInArg
92+
93+
private def traverseApply(tree: Tree, argss: List[List[Tree]])(using Context): Unit = tree match
94+
case Apply(fn, args) => traverseApply(fn, args :: argss)
95+
case TypeApply(fn, args) => traverseApply(fn, argss) // skip type arguments
96+
case _ =>
97+
if argss.nestedExists(_.needsSepCheck) then
98+
checkApply(tree, argss.flatten)
99+
100+
def traverse(tree: Tree)(using Context): Unit =
101+
tree match
102+
case tree: GenericApply =>
103+
if tree.symbol != defn.Caps_unsafeAssumeSeparate then
104+
tree.tpe match
105+
case _: MethodOrPoly =>
106+
case _ => traverseApply(tree, Nil)
107+
traverseChildren(tree)
108+
case _ =>
109+
traverseChildren(tree)
110+
end SepChecker
111+
112+
113+
114+
115+
116+

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,7 @@ class Definitions {
10001000
@tu lazy val Caps_Exists: ClassSymbol = requiredClass("scala.caps.Exists")
10011001
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
10021002
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
1003+
@tu lazy val Caps_unsafeAssumeSeparate: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumeSeparate")
10031004
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
10041005
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")
10051006
@tu lazy val Caps_Mutable: ClassSymbol = requiredClass("scala.caps.Mutable")

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4178,7 +4178,7 @@ object Types extends TypeUtils {
41784178
tl => params.map(p => tl.integrate(params, adaptParamInfo(p))),
41794179
tl => tl.integrate(params, resultType))
41804180

4181-
/** Adapt info of parameter symbol to be integhrated into corresponding MethodType
4181+
/** Adapt info of parameter symbol to be integrated into corresponding MethodType
41824182
* using the scheme described in `fromSymbols`.
41834183
*/
41844184
def adaptParamInfo(param: Symbol, pinfo: Type)(using Context): Type =

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,11 @@ abstract class Recheck extends Phase, SymTransformer:
167167
* from the current type.
168168
*/
169169
def setNuType(tpe: Type): Unit =
170-
if nuTypes.lookup(tree) == null && (tpe ne tree.tpe) then nuTypes(tree) = tpe
170+
if nuTypes.lookup(tree) == null then updNuType(tpe)
171+
172+
/** Set new type of the tree unconditionally. */
173+
def updNuType(tpe: Type): Unit =
174+
if tpe ne tree.tpe then nuTypes(tree) = tpe
171175

172176
/** The new type of the tree, or if none was installed, the original type */
173177
def nuType(using Context): Type =

library/src/scala/caps.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,9 @@ import annotation.{experimental, compileTimeOnly, retainsCap}
7979
*/
8080
def unsafeAssumePure: T = x
8181

82+
/** A wrapper around code for which separation checks are suppressed.
83+
*/
84+
def unsafeAssumeSeparate[T](op: T): T = op
85+
8286
end unsafe
87+
end caps

scala2-library-cc/src/scala/collection/IndexedSeqView.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ object IndexedSeqView {
136136

137137
@SerialVersionUID(3L)
138138
class Concat[A](prefix: SomeIndexedSeqOps[A]^, suffix: SomeIndexedSeqOps[A]^)
139-
extends SeqView.Concat[A](prefix, suffix) with IndexedSeqView[A]
139+
extends SeqView.Concat[A](prefix, caps.unsafe.unsafeAssumeSeparate(suffix)) with IndexedSeqView[A]
140140

141141
@SerialVersionUID(3L)
142142
class Take[A](underlying: SomeIndexedSeqOps[A]^, n: Int)

scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
682682
remaining -= 1
683683
scout = scout.tail
684684
}
685-
dropRightState(scout)
685+
caps.unsafe.unsafeAssumeSeparate(dropRightState(scout))
686686
}
687687
}
688688

@@ -879,6 +879,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
879879
if (!cursor.stateDefined) b.append(sep).append("<not computed>")
880880
} else {
881881
@inline def same(a: LazyListIterable[A]^, b: LazyListIterable[A]^): Boolean = (a eq b) || (a.state eq b.state)
882+
// !!!CC with qualifiers, same should have cap.rd parameters
882883
// Cycle.
883884
// If we have a prefix of length P followed by a cycle of length C,
884885
// the scout will be at position (P%C) in the cycle when the cursor
@@ -890,7 +891,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
890891
// the start of the loop.
891892
var runner = this
892893
var k = 0
893-
while (!same(runner, scout)) {
894+
while (!caps.unsafe.unsafeAssumeSeparate(same(runner, scout))) {
894895
runner = runner.tail
895896
scout = scout.tail
896897
k += 1
@@ -900,11 +901,11 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
900901
// everything once. If cursor is already at beginning, we'd better
901902
// advance one first unless runner didn't go anywhere (in which case
902903
// we've already looped once).
903-
if (same(cursor, scout) && (k > 0)) {
904+
if (caps.unsafe.unsafeAssumeSeparate(same(cursor, scout)) && (k > 0)) {
904905
appendCursorElement()
905906
cursor = cursor.tail
906907
}
907-
while (!same(cursor, scout)) {
908+
while (!caps.unsafe.unsafeAssumeSeparate(same(cursor, scout))) {
908909
appendCursorElement()
909910
cursor = cursor.tail
910911
}
@@ -1052,7 +1053,9 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10521053
val head = it.next()
10531054
rest = rest.tail
10541055
restRef = rest // restRef.elem = rest
1055-
sCons(head, newLL(stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state)))
1056+
sCons(head, newLL(
1057+
caps.unsafe.unsafeAssumeSeparate(
1058+
stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state))))
10561059
} else State.Empty
10571060
}
10581061
}

scala2-library-cc/src/scala/collection/mutable/CheckedIndexedSeqView.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ private[mutable] object CheckedIndexedSeqView {
7575

7676
@SerialVersionUID(3L)
7777
class Concat[A](prefix: SomeIndexedSeqOps[A]^, suffix: SomeIndexedSeqOps[A]^)(protected val mutationCount: () => Int)
78-
extends IndexedSeqView.Concat[A](prefix, suffix) with CheckedIndexedSeqView[A]
78+
extends IndexedSeqView.Concat[A](prefix, caps.unsafe.unsafeAssumeSeparate(suffix)) with CheckedIndexedSeqView[A]
7979

8080
@SerialVersionUID(3L)
8181
class Take[A](underlying: SomeIndexedSeqOps[A]^, n: Int)(protected val mutationCount: () => Int)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
-- Error: tests/neg-custom-args/captures/cc-dep-param.scala:8:6 --------------------------------------------------------
2+
8 | foo(a, useA) // error: separation failure
3+
| ^
4+
| Separation failure: argument to capture-polymorphic parameter x$0: Foo[Int]^
5+
| captures {a} and also passes this capability separately to method foo

0 commit comments

Comments
 (0)