Skip to content

Commit 25104d1

Browse files
committed
Improve fluidify
- Also add Fluid to function types - Add Fluid recursively in arguments of fluidified types - But stop when a capturing type is encountered Also, fix needsVariable for FromJavaObject. FromJavaObject should behave like Any (i.e. no capture set variable needs to be added).
1 parent 12735bd commit 25104d1

File tree

5 files changed

+31
-24
lines changed

5 files changed

+31
-24
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -291,20 +291,23 @@ class CheckCaptures extends Recheck, SymTransformer:
291291
if sym.exists && curEnv.isOpen then markFree(capturedVars(sym), pos)
292292

293293
private def handleBackwardsCompat(tp: Type, sym: Symbol, initialVariance: Int = 1)(using Context): Type =
294-
val fluidify = new TypeMap:
294+
val fluidify = new TypeMap with IdempotentCaptRefMap:
295295
variance = initialVariance
296296
def apply(t: Type): Type = t match
297-
case tp: MethodType =>
298-
mapOver(tp)
299-
case tp: TypeLambda =>
300-
tp.derivedLambdaType(resType = this(tp.resType))
301-
case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) =>
302-
tp.derivedRefinedType(parent, rname, this(rinfo))
303-
case tp @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp) =>
304-
mapOver(tp)
297+
case t: MethodType =>
298+
mapOver(t)
299+
case t: TypeLambda =>
300+
t.derivedLambdaType(resType = this(t.resType))
301+
case CapturingType(_, _) =>
302+
t
305303
case _ =>
306-
if variance > 0 then t
307-
else Setup.decorate(t, Function.const(CaptureSet.Fluid))
304+
val t1 = t match
305+
case t @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(t) =>
306+
t.derivedRefinedType(parent, rname, this(rinfo))
307+
case _ =>
308+
mapOver(t)
309+
if variance > 0 then t1
310+
else Setup.decorate(t1, Function.const(CaptureSet.Fluid))
308311

309312
def isPreCC(sym: Symbol): Boolean =
310313
sym.isTerm && sym.maybeOwner.isClass

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,17 @@ object Setup:
418418
def needsVariable(tp: Type)(using Context): Boolean = {
419419
tp.typeParams.isEmpty && tp.match
420420
case tp: (TypeRef | AppliedType) =>
421-
val tp1 = tp.dealias
422-
if tp1 ne tp then needsVariable(tp1)
421+
val sym = tp.typeSymbol
422+
if sym.isClass then
423+
!sym.isPureClass && sym != defn.AnyClass
423424
else
424-
val sym = tp1.typeSymbol
425-
if sym.isClass then
426-
!sym.isPureClass
427-
&& sym != defn.AnyClass
428-
&& sym != defn.FromJavaObjectSymbol
429-
else superTypeIsImpure(tp1)
425+
sym != defn.FromJavaObjectSymbol
426+
// For capture checking, we assume Object from Java is the same as Any
427+
&& {
428+
val tp1 = tp.dealias
429+
if tp1 ne tp then needsVariable(tp1)
430+
else superTypeIsImpure(tp1)
431+
}
430432
case tp: (RefinedOrRecType | MatchType) =>
431433
needsVariable(tp.underlying)
432434
case tp: AndType =>
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import caps.unsafe.*
21
def test =
32
val tasks = new collection.mutable.ArrayBuffer[() => Unit]
4-
val _: Unit = tasks.foreach(((task: () => Unit) => task()).unsafeBoxFunArg)
3+
val _: Unit = tasks.foreach(((task: () => Unit) => task()))

tests/pos/cc-backwards-compat/A.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
package p
2-
class A:
2+
class A(f: Int => Int):
3+
def foo(f: Int => Int) = ???
34
def map(other: Iter): Iter = other
45
def pair[T](x: T): (T, T) = (x, x)

tests/pos/cc-backwards-compat/Iter.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ class Iter:
55
self: Iter^ =>
66

77
def test(it: Iter^) =
8-
val a = A()
9-
//val b = a.map(it) // does not work yet
8+
val f: Int ->{it} Int = ???
9+
val a = new A(f)
10+
val b = a.map(it) // does not work yet
1011
val c = a.pair(it)
12+
val d = a.foo(f)

0 commit comments

Comments
 (0)