Skip to content

Commit 6c949a9

Browse files
committed
Special treatment of types of function members
Members andThen, compose, curried, tupled of function types are now given special types for capture checking that reflect fine-grained capture dependencies.
1 parent 1a30250 commit 6c949a9

File tree

4 files changed

+78
-7
lines changed

4 files changed

+78
-7
lines changed

compiler/src/dotty/tools/dotc/cc/Synthetics.scala

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import NameKinds.DefaultGetterName
1010
import Phases.checkCapturesPhase
1111
import config.Printers.capt
1212

13-
/** Classification and transformation methods for synthetic
14-
* case class methods that need to be treated specially.
13+
/** Classification and transformation methods for function methods and
14+
* synthetic case class methods that need to be treated specially.
1515
* In particular, compute capturing types for some of these methods which
1616
* have inferred (result-)types that need to be established under separate
1717
* compilation.
@@ -27,6 +27,9 @@ object Synthetics:
2727
case DefaultGetterName(nme.copy, _) => sym.is(Synthetic) && sym.owner.isClass && sym.owner.is(Case)
2828
case _ => false
2929

30+
private val functionCombinatorNames = Set[Name](
31+
nme.andThen, nme.compose, nme.curried, nme.tupled)
32+
3033
/** Is `sym` a synthetic apply, copy, or copy default getter method?
3134
* The types of these symbols are transformed in a special way without
3235
* looking at the definitions's RHS
@@ -37,6 +40,7 @@ object Synthetics:
3740
|| isSyntheticCopyDefaultGetterMethod(symd)
3841
|| (symd.symbol eq defn.Object_eq)
3942
|| (symd.symbol eq defn.Object_ne)
43+
|| defn.isFunctionClass(symd.owner) && functionCombinatorNames.contains(symd.name)
4044

4145
/** Method is excluded from regular capture checking.
4246
* Excluded are synthetic class members
@@ -156,6 +160,37 @@ object Synthetics:
156160
case info: PolyType =>
157161
info.derivedLambdaType(resType = dropUnapplyCaptures(info.resType))
158162

163+
private def transformComposeCaptures(symd: SymDenotation, toCC: Boolean)(using Context): Type =
164+
val (pt: PolyType) = symd.info: @unchecked
165+
val (mt: MethodType) = pt.resType: @unchecked
166+
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
167+
val mt1 =
168+
if toCC then
169+
MethodType(mt.paramNames)(
170+
mt1 => mt.paramInfos.map(_.capturing(CaptureSet.universal)),
171+
mt1 => CapturingType(mt.resType, CaptureSet(enclThis, mt1.paramRefs.head)))
172+
else
173+
MethodType(mt.paramNames)(
174+
mt1 => mt.paramInfos.map(_.stripCapturing),
175+
mt1 => mt.resType.stripCapturing)
176+
pt.derivedLambdaType(resType = mt1)
177+
178+
def transformCurriedTupledCaptures(symd: SymDenotation, toCC: Boolean)(using Context): Type =
179+
val (et: ExprType) = symd.info: @unchecked
180+
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
181+
def mapFinalResult(tp: Type, f: Type => Type): Type =
182+
val defn.FunctionOf(args, res, isContextual) = tp: @unchecked
183+
if defn.isFunctionNType(res) then
184+
defn.FunctionOf(args, mapFinalResult(res, f), isContextual)
185+
else
186+
f(tp)
187+
val resType1 =
188+
if toCC then
189+
mapFinalResult(et.resType, CapturingType(_, CaptureSet(enclThis)))
190+
else
191+
et.resType.stripCapturing
192+
ExprType(resType1)
193+
159194
/** If `sym` refers to a synthetic apply, unapply, copy, or copy default getter method
160195
* of a case class, transform it to account for capture information.
161196
* The method is run in phase CheckCaptures.Pre
@@ -168,6 +203,10 @@ object Synthetics:
168203
sym.copySymDenotation(info = addUnapplyCaptures(sym.info))
169204
case nme.apply | nme.copy =>
170205
sym.copySymDenotation(info = addCaptureDeps(sym.info))
206+
case nme.andThen | nme.compose =>
207+
sym.copySymDenotation(info = transformComposeCaptures(sym, toCC = true))
208+
case nme.curried | nme.tupled =>
209+
sym.copySymDenotation(info = transformCurriedTupledCaptures(sym, toCC = true))
171210
case n if n == nme.eq || n == nme.ne =>
172211
sym.copySymDenotation(info =
173212
MethodType(defn.ObjectType.capturing(CaptureSet.universal) :: Nil, defn.BooleanType))
@@ -183,6 +222,10 @@ object Synthetics:
183222
sym.copySymDenotation(info = dropUnapplyCaptures(sym.info))
184223
case nme.apply | nme.copy =>
185224
sym.copySymDenotation(info = dropCaptureDeps(sym.info))
225+
case nme.andThen | nme.compose =>
226+
sym.copySymDenotation(info = transformComposeCaptures(sym, toCC = false))
227+
case nme.curried | nme.tupled =>
228+
sym.copySymDenotation(info = transformCurriedTupledCaptures(sym, toCC = false))
186229
case n if n == nme.eq || n == nme.ne =>
187230
sym.copySymDenotation(info = defn.methOfAnyRef(defn.BooleanType))
188231

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ object StdNames {
395395
val UNIT : N = "UNIT"
396396
val acc: N = "acc"
397397
val adhocExtensions: N = "adhocExtensions"
398+
val andThen: N = "andThen"
398399
val annotation: N = "annotation"
399400
val any2stringadd: N = "any2stringadd"
400401
val anyHash: N = "anyHash"
@@ -443,11 +444,13 @@ object StdNames {
443444
val command: N = "command"
444445
val common: N = "common"
445446
val compiletime : N = "compiletime"
447+
val compose: N = "compose"
446448
val conforms_ : N = "$conforms"
447449
val contents: N = "contents"
448450
val copy: N = "copy"
449-
val currentMirror: N = "currentMirror"
450451
val create: N = "create"
452+
val currentMirror: N = "currentMirror"
453+
val curried: N = "curried"
451454
val definitions: N = "definitions"
452455
val delayedInit: N = "delayedInit"
453456
val delayedInitArg: N = "delayedInit$body"
@@ -622,6 +625,7 @@ object StdNames {
622625
val transparent : N = "transparent"
623626
val tree : N = "tree"
624627
val true_ : N = "true"
628+
val tupled: N = "tupled"
625629
val typedProductIterator: N = "typedProductIterator"
626630
val typeTagToManifest: N = "typeTagToManifest"
627631
val unapply: N = "unapply"
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
class ContextClass
2+
type Context = ContextClass^
3+
4+
def Test(using ctx1: Context, ctx2: Context) =
5+
val f: Int => Int = identity
6+
val g1: Int ->{ctx1} Int = identity
7+
val g2: Int ->{ctx2} Int = identity
8+
val h: Int -> Int = identity
9+
val a1 = f.andThen(f); val _: Int ->{f} Int = a1
10+
val a2 = f.andThen(g1); val _: Int ->{f, g1} Int = a2
11+
val a3 = f.andThen(g2); val _: Int ->{f, g2} Int = a3
12+
val a4 = f.andThen(h); val _: Int ->{f} Int = a4
13+
val b1 = g1.andThen(f); val _: Int ->{f, g1} Int = b1
14+
val b2 = g1.andThen(g1); val _: Int ->{g1} Int = b2
15+
val b3 = g1.andThen(g2); val _: Int ->{g1, g2} Int = b3
16+
val b4 = g1.andThen(h); val _: Int ->{g1} Int = b4
17+
val c1 = h.andThen(f); val _: Int ->{f} Int = c1
18+
val c2 = h.andThen(g1); val _: Int ->{g1} Int = c2
19+
val c3 = h.andThen(g2); val _: Int ->{g2} Int = c3
20+
val c4 = h.andThen(h); val _: Int -> Int = c4
21+
22+
val f2: (Int, Int) => Int = _ + _
23+
val f2c = f2.curried; val _: Int -> Int ->{f2} Int = f2c
24+
val f2t = f2.tupled; val _: ((Int, Int)) ->{f2} Int = f2t
25+
26+
val f3: (Int, Int, Int) => Int = ???
27+
val f3c = f3.curried; val _: Int -> Int -> Int ->{f3} Int = f3c
28+
val f3t = f3.tupled; val _: ((Int, Int, Int)) ->{f3} Int = f3t

tests/pos-custom-args/captures/inlined-closure.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,3 @@ def Test(using ctx: Context) =
1212
info.paramRefs.filter(_.isTracked)
1313
val p = atPhase()((_: ParamRef).isTracked)
1414
val _: ParamRef ->{ctx} Boolean = p
15-
16-
//val f: String => ParamRef = ???
17-
//val q = f.andThen(p)
18-

0 commit comments

Comments
 (0)