Skip to content

Commit dbe1831

Browse files
committed
Implement polymorphic lambdas using Closure nodes for efficiency
Previously, we desugared them manually into anonymous class instances, but by using a Closure node instead, we ensure that they get translated into indy lambdas on the JVM. Also cleaned up and added a TODO in the desugaring of polymorphic function types into refinement types since I realized that purity wasn't taken into account.
1 parent 9d08db1 commit dbe1831

File tree

4 files changed

+87
-77
lines changed

4 files changed

+87
-77
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 41 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,40 @@ object desugar {
10361036
name
10371037
}
10381038

1039+
/** Strip parens and empty blocks around the body of `tree`. */
1040+
def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction =
1041+
def stripped(body: Tree): Tree = body match
1042+
case Parens(body1) =>
1043+
stripped(body1)
1044+
case Block(Nil, body1) =>
1045+
stripped(body1)
1046+
case _ => body
1047+
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
1048+
1049+
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
1050+
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1051+
*/
1052+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1053+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1054+
val funFlags = fun match
1055+
case fun: FunctionWithMods =>
1056+
fun.mods.flags
1057+
case _ => EmptyFlags
1058+
1059+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1060+
// val isImpure = funFlags.is(Impure)
1061+
1062+
// Function flags to be propagated to each parameter in the desugared method type.
1063+
val paramFlags = funFlags.toTermFlags & Given
1064+
val vparams = vparamTypes.zipWithIndex.map:
1065+
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
1066+
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1067+
1068+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1069+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1070+
)).withSpan(tree.span)
1071+
end makePolyFunctionType
1072+
10391073
/** Invent a name for an anonympus given of type or template `impl`. */
10401074
def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName =
10411075
val str = impl match
@@ -1429,14 +1463,17 @@ object desugar {
14291463
}
14301464

14311465
/** Make closure corresponding to function.
1432-
* params => body
1466+
* [tparams] => params => body
14331467
* ==>
1434-
* def $anonfun(params) = body
1468+
* def $anonfun[tparams](params) = body
14351469
* Closure($anonfun)
14361470
*/
1437-
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1471+
def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1472+
val paramss: List[ParamClause] =
1473+
if tparams.isEmpty then vparams :: Nil
1474+
else tparams :: vparams :: Nil
14381475
Block(
1439-
DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body)
1476+
DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body)
14401477
.withSpan(span)
14411478
.withMods(synthetic | Artifact),
14421479
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
@@ -1728,56 +1765,6 @@ object desugar {
17281765
}
17291766
}
17301767

1731-
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
1732-
case Parens(body1) =>
1733-
makePolyFunction(targs, body1, pt)
1734-
case Block(Nil, body1) =>
1735-
makePolyFunction(targs, body1, pt)
1736-
case Function(vargs, res) =>
1737-
assert(targs.nonEmpty)
1738-
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
1739-
val mods = body match {
1740-
case body: FunctionWithMods => body.mods
1741-
case _ => untpd.EmptyModifiers
1742-
}
1743-
val polyFunctionTpt = ref(defn.PolyFunctionType)
1744-
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1745-
if (ctx.mode.is(Mode.Type)) {
1746-
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1747-
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1748-
1749-
val applyVParams = vargs.zipWithIndex.map {
1750-
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
1751-
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags)
1752-
}
1753-
RefinedTypeTree(polyFunctionTpt, List(
1754-
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic)
1755-
))
1756-
}
1757-
else {
1758-
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1759-
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1760-
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1761-
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1762-
1763-
def typeTree(tp: Type) = tp match
1764-
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1765-
untpd.DependentTypeTree((tsyms, vsyms) =>
1766-
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1767-
case _ => TypeTree()
1768-
1769-
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1770-
.map(varg => varg.withAddedFlags(mods.flags | Param))
1771-
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1772-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
1773-
))
1774-
}
1775-
case _ =>
1776-
// may happen for erroneous input. An error will already have been reported.
1777-
assert(ctx.reporter.errorsReported)
1778-
EmptyTree
1779-
}
1780-
17811768
// begin desugar
17821769

17831770
// Special case for `Parens` desugaring: unlike all the desugarings below,
@@ -1790,8 +1777,6 @@ object desugar {
17901777
}
17911778

17921779
val desugared = tree match {
1793-
case PolyFunction(targs, body) =>
1794-
makePolyFunction(targs, body, pt) orElse tree
17951780
case SymbolLit(str) =>
17961781
Apply(
17971782
ref(defn.ScalaSymbolClass.companionModule.termRef),

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,8 @@ object Types {
18721872
if alwaysDependent || mt.isResultDependent then
18731873
RefinedType(funType, nme.apply, mt)
18741874
else funType
1875+
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
1876+
RefinedType(defn.PolyFunctionType, nme.apply, poly)
18751877
}
18761878

18771879
/** The signature of this type. This is by default NotAMethod,

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,12 +1625,32 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16251625
)
16261626
cpy.ValDef(param)(tpt = paramTpt)
16271627
if isErased then param0.withAddedFlags(Flags.Erased) else param0
1628-
desugared = desugar.makeClosure(inferredParams, fnBody, resultTpt, tree.span)
1628+
desugared = desugar.makeClosure(Nil, inferredParams, fnBody, resultTpt, tree.span)
16291629

16301630
typed(desugared, pt)
16311631
.showing(i"desugared fun $tree --> $desugared with pt = $pt", typr)
16321632
}
16331633

1634+
1635+
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
1636+
val tree1 = desugar.normalizePolyFunction(tree)
1637+
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1638+
else typedPolyFunctionValue(tree1, pt)
1639+
1640+
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
1641+
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
1642+
val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked
1643+
1644+
val resultTpt = pt.dealias match
1645+
case RefinedType(parent, nme.apply, poly @ PolyType(_, mt: MethodType)) if parent.classSymbol eq defn.PolyFunctionClass =>
1646+
untpd.DependentTypeTree((tsyms, vsyms) =>
1647+
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
1648+
case _ => untpd.TypeTree()
1649+
1650+
val desugared = desugar.makeClosure(tparams, vparams, body, resultTpt, tree.span)
1651+
typed(desugared, pt)
1652+
end typedPolyFunctionValue
1653+
16341654
def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
16351655
val env1 = tree.env mapconserve (typed(_))
16361656
val meth1 = typedUnadapted(tree.meth)
@@ -1668,6 +1688,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
16681688
else
16691689
EmptyTree
16701690
}
1691+
case _: PolyType =>
1692+
// Polymorphic SAMs are not currently supported (#6904).
1693+
EmptyTree
16711694
case tp =>
16721695
if !tp.isErroneous then
16731696
throw new java.lang.Error(i"internal error: closing over non-method $tp, pos = ${tree.span}")
@@ -2425,7 +2448,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24252448
case rhs => typedExpr(rhs, tpt1.tpe.widenExpr)
24262449
}
24272450
val vdef1 = assignType(cpy.ValDef(vdef)(name, tpt1, rhs1), sym)
2428-
postProcessInfo(sym)
2451+
postProcessInfo(vdef1, sym)
24292452
vdef1.setDefTree
24302453
}
24312454

@@ -2534,19 +2557,31 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
25342557

25352558
val ddef2 = assignType(cpy.DefDef(ddef)(name, paramss1, tpt1, rhs1), sym)
25362559

2537-
postProcessInfo(sym)
2560+
postProcessInfo(ddef2, sym)
25382561
ddef2.setDefTree
25392562
//todo: make sure dependent method types do not depend on implicits or by-name params
25402563
}
25412564

25422565
/** (1) Check that the signature of the class member does not return a repeated parameter type
25432566
* (2) If info is an erased class, set erased flag of member
2567+
* (3) Check that erased classes are not parameters of polymorphic functions.
25442568
*/
2545-
private def postProcessInfo(sym: Symbol)(using Context): Unit =
2569+
private def postProcessInfo(mdef: MemberDef, sym: Symbol)(using Context): Unit =
25462570
if (!sym.isOneOf(Synthetic | InlineProxy | Param) && sym.info.finalResultType.isRepeatedParam)
25472571
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
25482572
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
25492573
sym.setFlag(Erased)
2574+
if
2575+
sym.info.isInstanceOf[PolyType] &&
2576+
((sym.name eq nme.ANON_FUN) ||
2577+
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
2578+
then
2579+
mdef match
2580+
case DefDef(_, _ :: vparams :: Nil, _, _) =>
2581+
vparams.foreach: vparam =>
2582+
if vparam.symbol.is(Erased) then
2583+
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
2584+
case _ =>
25502585

25512586
def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
25522587
val TypeDef(name, rhs) = tdef
@@ -2693,19 +2728,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
26932728
// check value class constraints
26942729
checkDerivedValueClass(cls, body1)
26952730

2696-
// check PolyFunction constraints (no erased functions!)
2697-
if parents1.exists(_.tpe.classSymbol eq defn.PolyFunctionClass) then
2698-
body1.foreach {
2699-
case ddef: DefDef =>
2700-
ddef.paramss.foreach { params =>
2701-
val erasedParam = params.collectFirst { case vdef: ValDef if vdef.symbol.is(Erased) => vdef }
2702-
erasedParam.foreach { p =>
2703-
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", p.srcPos)
2704-
}
2705-
}
2706-
case _ =>
2707-
}
2708-
27092731
val effectiveOwner = cls.owner.skipWeakOwner
27102732
if !cls.isRefinementClass
27112733
&& !cls.isAllOf(PrivateLocal)
@@ -3057,6 +3079,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
30573079
case tree: untpd.Block => typedBlock(desugar.block(tree), pt)(using ctx.fresh.setNewScope)
30583080
case tree: untpd.If => typedIf(tree, pt)
30593081
case tree: untpd.Function => typedFunction(tree, pt)
3082+
case tree: untpd.PolyFunction => typedPolyFunction(tree, pt)
30603083
case tree: untpd.Closure => typedClosure(tree, pt)
30613084
case tree: untpd.Import => typedImport(tree)
30623085
case tree: untpd.Export => typedExport(tree)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:53 ---------------------------------------------
1+
-- [E007] Type Mismatch Error: tests/neg/polymorphic-functions1.scala:1:33 ---------------------------------------------
22
1 |val f: [T] => (x: T) => x.type = [T] => (x: Int) => x // error
3-
| ^
4-
| Found: [T] => (x: Int) => x.type
5-
| Required: [T] => (x: T) => x.type
3+
| ^^^^^^^^^^^^^^^^^^^^
4+
| Found: [T] => (x: Int) => x.type
5+
| Required: [T] => (x: T) => x.type
66
|
77
| longer explanation available when compiling with `-explain`

0 commit comments

Comments
 (0)