Skip to content

Commit 62b2f1c

Browse files
committed
Follow upper bounds of type variables when computing dcs
1 parent 8f8a15f commit 62b2f1c

File tree

9 files changed

+79
-2
lines changed

9 files changed

+79
-2
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,13 +1120,17 @@ object CaptureSet:
11201120
*/
11211121
def ofTypeDeeply(tp: Type)(using Context): CaptureSet =
11221122
val collect = new TypeAccumulator[CaptureSet]:
1123+
val seen = util.HashSet[Symbol]()
11231124
def apply(cs: CaptureSet, t: Type) =
11241125
if variance <= 0 then cs
11251126
else t.dealias match
11261127
case t @ CapturingType(p, cs1) =>
11271128
this(cs, p) ++ cs1
11281129
case t @ AnnotatedType(parent, ann) =>
11291130
this(cs, parent)
1131+
case t: TypeRef if t.symbol.isAbstractOrParamType && !seen.contains(t.symbol) =>
1132+
seen += t.symbol
1133+
this(cs, t.info.bounds.hi)
11301134
case t @ FunctionOrMethod(args, res @ Existential(_, _))
11311135
if args.forall(_.isAlwaysPure) =>
11321136
this(cs, Existential.toCap(res))

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ class CheckCaptures extends Recheck, SymTransformer:
660660
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
661661
val getter = cls.info.member(getterName).suchThat(_.isRefiningParamAccessor).symbol
662662
if !getter.is(Private) && getter.hasTrackedParts then
663-
refined = RefinedType(refined, getterName, argType)
663+
refined = RefinedType(refined, getterName, argType.unboxed)
664664
allCaptures ++= argType.captureSet
665665
(refined, allCaptures)
666666

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,15 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
8181
private def newFlagsFor(symd: SymDenotation)(using Context): FlagSet =
8282

8383
object containsCovarRetains extends TypeAccumulator[Boolean]:
84+
val seen = util.HashSet[Symbol]()
8485
def apply(x: Boolean, tp: Type): Boolean =
8586
if x then true
8687
else if tp.derivesFromCapability && variance >= 0 then true
8788
else tp match
8889
case AnnotatedType(_, ann) if ann.symbol.isRetains && variance >= 0 => true
90+
case t: TypeRef if t.symbol.isAbstractOrParamType && !seen.contains(t.symbol) =>
91+
seen += t.symbol
92+
apply(x, t.info.bounds.hi)
8993
case _ => foldOver(x, tp)
9094
def apply(tp: Type): Boolean = apply(false, tp)
9195

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
-- Error: tests/neg-custom-args/captures/dcs-tvar.scala:6:15 -----------------------------------------------------------
2+
6 | () => runOps(xs) // error
3+
| ^^
4+
| reference xs* is not included in the allowed capture set {}
5+
| of an enclosing function literal with expected type () -> Unit
6+
-- Error: tests/neg-custom-args/captures/dcs-tvar.scala:9:15 -----------------------------------------------------------
7+
9 | () => runOps(xs) // error
8+
| ^^
9+
| reference xs* is not included in the allowed capture set {}
10+
| of an enclosing function literal with expected type () -> Unit
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import caps.use
2+
3+
def runOps(@use xs: List[() => Unit]): Unit = ???
4+
5+
def f[T <: List[() => Unit]](xs: T): () -> Unit =
6+
() => runOps(xs) // error
7+
8+
def g[T <: List[U], U <: () => Unit](xs: T): () -> Unit =
9+
() => runOps(xs) // error

tests/neg-custom-args/captures/i21646.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ class Resource[T <: Capability](gen: T):
99

1010
@main def run =
1111
val myFile: File = ???
12-
val r = Resource(myFile) // error
12+
val r = Resource(myFile) // now ok, was error
1313
()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
class IO
2+
3+
def f(xs: List[() => Unit]): () => Unit = () =>
4+
println(xs.head) // error
5+
6+
def test(io: IO^)(ys: List[() ->{io} Unit]) =
7+
val x = () =>
8+
val z = f(ys)
9+
z()
10+
val _: () -> Unit = x // !!! ys* gets lost
11+
()
12+
13+
def test(io: IO^) =
14+
def ys: List[() ->{io} Unit] = ???
15+
val x = () =>
16+
val z = f(ys)
17+
z()
18+
val _: () -> Unit = x // !!! io gets lost
19+
()
20+
21+
22+
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import caps.use
2+
3+
def test(io: Object^, async: Object^) =
4+
5+
trait A:
6+
def f(@use x: List[() ->{io} Unit]): Unit
7+
8+
class B extends A:
9+
def f(@use x: List[() => Unit]): Unit = // error, would be unsound if allowed
10+
x.foreach(_())
11+
12+
class C extends A:
13+
def f(@use x: List[() ->{io, async} Unit]): Unit = // error, this one could be soundly allowed actually
14+
x.foreach(_())
15+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import language.experimental.captureChecking
2+
import caps.Capability
3+
4+
trait File extends Capability
5+
6+
class Resource[T <: Capability](gen: T):
7+
def use[U](f: T => U): U =
8+
f(gen) // OK, was error under unsealed
9+
10+
@main def run =
11+
val myFile: File = ???
12+
val r = Resource(myFile) // error
13+
()

0 commit comments

Comments
 (0)