Skip to content

Commit 7a683ab

Browse files
authored
Make polymorphic functions more efficient and expressive (#17548)
This PR enhances polymorphic function types in two ways: - Dependent result types can now be inferred from the expected type - polymorphic lambdas are now implemented using JVM lambdas when possible instead of anonymous classes. Additionally, we fix the logic for renaming bound variables when pretty-printing lambdas and fix the handling of `this` in refinements.
2 parents a8e9312 + d7a345f commit 7a683ab

25 files changed

+248
-116
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 42 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,40 @@ object desugar {
10611061
name
10621062
}
10631063

1064+
/** Strip parens and empty blocks around the body of `tree`. */
1065+
def normalizePolyFunction(tree: PolyFunction)(using Context): PolyFunction =
1066+
def stripped(body: Tree): Tree = body match
1067+
case Parens(body1) =>
1068+
stripped(body1)
1069+
case Block(Nil, body1) =>
1070+
stripped(body1)
1071+
case _ => body
1072+
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
1073+
1074+
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
1075+
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1076+
*/
1077+
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
1078+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1079+
val funFlags = fun match
1080+
case fun: FunctionWithMods =>
1081+
fun.mods.flags
1082+
case _ => EmptyFlags
1083+
1084+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1085+
// val isImpure = funFlags.is(Impure)
1086+
1087+
// Function flags to be propagated to each parameter in the desugared method type.
1088+
val paramFlags = funFlags.toTermFlags & Given
1089+
val vparams = vparamTypes.zipWithIndex.map:
1090+
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
1091+
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1092+
1093+
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1094+
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
1095+
)).withSpan(tree.span)
1096+
end makePolyFunctionType
1097+
10641098
/** Invent a name for an anonympus given of type or template `impl`. */
10651099
def inventGivenOrExtensionName(impl: Tree)(using Context): SimpleName =
10661100
val str = impl match
@@ -1454,17 +1488,20 @@ object desugar {
14541488
}
14551489

14561490
/** Make closure corresponding to function.
1457-
* params => body
1491+
* [tparams] => params => body
14581492
* ==>
1459-
* def $anonfun(params) = body
1493+
* def $anonfun[tparams](params) = body
14601494
* Closure($anonfun)
14611495
*/
1462-
def makeClosure(params: List[ValDef], body: Tree, tpt: Tree | Null = null, isContextual: Boolean, span: Span)(using Context): Block =
1496+
def makeClosure(tparams: List[TypeDef], vparams: List[ValDef], body: Tree, tpt: Tree | Null = null, span: Span)(using Context): Block =
1497+
val paramss: List[ParamClause] =
1498+
if tparams.isEmpty then vparams :: Nil
1499+
else tparams :: vparams :: Nil
14631500
Block(
1464-
DefDef(nme.ANON_FUN, params :: Nil, if (tpt == null) TypeTree() else tpt, body)
1501+
DefDef(nme.ANON_FUN, paramss, if (tpt == null) TypeTree() else tpt, body)
14651502
.withSpan(span)
14661503
.withMods(synthetic | Artifact),
1467-
Closure(Nil, Ident(nme.ANON_FUN), if (isContextual) ContextualEmptyTree else EmptyTree))
1504+
Closure(Nil, Ident(nme.ANON_FUN), EmptyTree))
14681505

14691506
/** If `nparams` == 1, expand partial function
14701507
*
@@ -1753,62 +1790,6 @@ object desugar {
17531790
}
17541791
}
17551792

1756-
def makePolyFunction(targs: List[Tree], body: Tree, pt: Type): Tree = body match {
1757-
case Parens(body1) =>
1758-
makePolyFunction(targs, body1, pt)
1759-
case Block(Nil, body1) =>
1760-
makePolyFunction(targs, body1, pt)
1761-
case Function(vargs, res) =>
1762-
assert(targs.nonEmpty)
1763-
// TODO: Figure out if we need a `PolyFunctionWithMods` instead.
1764-
val mods = body match {
1765-
case body: FunctionWithMods => body.mods
1766-
case _ => untpd.EmptyModifiers
1767-
}
1768-
val polyFunctionTpt = ref(defn.PolyFunctionType)
1769-
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1770-
if (ctx.mode.is(Mode.Type)) {
1771-
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1772-
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1773-
1774-
val applyVParams = vargs.zipWithIndex.map {
1775-
case (p: ValDef, _) => p.withAddedFlags(mods.flags)
1776-
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(mods.flags.toTermFlags)
1777-
}
1778-
RefinedTypeTree(polyFunctionTpt, List(
1779-
DefDef(nme.apply, applyTParams :: applyVParams :: Nil, res, EmptyTree).withFlags(Synthetic)
1780-
))
1781-
}
1782-
else {
1783-
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1784-
// with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1785-
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1786-
// where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1787-
1788-
def typeTree(tp: Type) = tp match
1789-
case RefinedType(parent, nme.apply, PolyType(_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1790-
var bail = false
1791-
def mapper(tp: Type, topLevel: Boolean = false): Tree = tp match
1792-
case tp: TypeRef => ref(tp)
1793-
case tp: TypeParamRef => Ident(applyTParams(tp.paramNum).name)
1794-
case AppliedType(tycon, args) => AppliedTypeTree(mapper(tycon), args.map(mapper(_)))
1795-
case _ => if topLevel then TypeTree() else { bail = true; genericEmptyTree }
1796-
val mapped = mapper(mt.resultType, topLevel = true)
1797-
if bail then TypeTree() else mapped
1798-
case _ => TypeTree()
1799-
1800-
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1801-
.map(varg => varg.withAddedFlags(mods.flags | Param))
1802-
New(Template(emptyConstructor, List(polyFunctionTpt), Nil, EmptyValDef,
1803-
List(DefDef(nme.apply, applyTParams :: applyVParams :: Nil, typeTree(pt), res))
1804-
))
1805-
}
1806-
case _ =>
1807-
// may happen for erroneous input. An error will already have been reported.
1808-
assert(ctx.reporter.errorsReported)
1809-
EmptyTree
1810-
}
1811-
18121793
// begin desugar
18131794

18141795
// Special case for `Parens` desugaring: unlike all the desugarings below,
@@ -1821,8 +1802,6 @@ object desugar {
18211802
}
18221803

18231804
val desugared = tree match {
1824-
case PolyFunction(targs, body) =>
1825-
makePolyFunction(targs, body, pt) orElse tree
18261805
case SymbolLit(str) =>
18271806
Apply(
18281807
ref(defn.ScalaSymbolClass.companionModule.termRef),

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,7 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
420420
case Closure(_, meth, _) => true
421421
case Block(Nil, expr) => isContextualClosure(expr)
422422
case Block(DefDef(nme.ANON_FUN, params :: _, _, _) :: Nil, cl: Closure) =>
423-
if params.isEmpty then
424-
cl.tpt.eq(untpd.ContextualEmptyTree) || defn.isContextFunctionType(cl.tpt.typeOpt)
425-
else
426-
isUsingClause(params)
423+
isUsingClause(params)
427424
case _ => false
428425
}
429426

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,6 @@ object Trees {
11921192

11931193
@sharable val EmptyTree: Thicket = genericEmptyTree
11941194
@sharable val EmptyValDef: ValDef = genericEmptyValDef
1195-
@sharable val ContextualEmptyTree: Thicket = new EmptyTree() // an empty tree marking a contextual closure
11961195

11971196
// ----- Auxiliary creation methods ------------------
11981197

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
151151
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
152152

153153
/** Short-lived usage in typer, does not need copy/transform/fold infrastructure */
154-
case class DependentTypeTree(tp: List[Symbol] => Type)(implicit @constructorOnly src: SourceFile) extends Tree
154+
case class DependentTypeTree(tp: (List[TypeSymbol], List[TermSymbol]) => Type)(implicit @constructorOnly src: SourceFile) extends Tree
155155

156156
@sharable object EmptyTypeIdent extends Ident(tpnme.EMPTY)(NoSource) with WithoutTypeOrPos[Untyped] {
157157
override def isEmpty: Boolean = true

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,12 @@ object NameOps {
236236
*/
237237
def isPlainFunction(using Context): Boolean = functionArity >= 0
238238

239-
/** Is a function name that contains `mustHave` as a substring */
240-
private def isSpecificFunction(mustHave: String)(using Context): Boolean =
239+
/** Is a function name that contains `mustHave` as a substring
240+
* and has arity `minArity` or greater.
241+
*/
242+
private def isSpecificFunction(mustHave: String, minArity: Int = 0)(using Context): Boolean =
241243
val suffixStart = functionSuffixStart
242-
isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= 0
244+
isFunctionPrefix(suffixStart, mustHave) && funArity(suffixStart) >= minArity
243245

244246
def isContextFunction(using Context): Boolean = isSpecificFunction("Context")
245247
def isImpureFunction(using Context): Boolean = isSpecificFunction("Impure")

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,8 @@ object Types {
18721872
if alwaysDependent || mt.isResultDependent then
18731873
RefinedType(funType, nme.apply, mt)
18741874
else funType
1875+
case poly @ PolyType(_, mt: MethodType) if !mt.isParamDependent =>
1876+
RefinedType(defn.PolyFunctionType, nme.apply, poly)
18751877
}
18761878

18771879
/** The signature of this type. This is by default NotAMethod,

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,7 @@ object Parsers {
15111511
TermLambdaTypeTree(params.asInstanceOf[List[ValDef]], resultType)
15121512
else if imods.isOneOf(Given | Impure) || erasedArgs.contains(true) then
15131513
if imods.is(Given) && params.isEmpty then
1514+
imods &~= Given
15141515
syntaxError(em"context function types require at least one parameter", paramSpan)
15151516
FunctionWithMods(params, resultType, imods, erasedArgs.toList)
15161517
else if !ctx.settings.YkindProjector.isDefault then

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ class PlainPrinter(_ctx: Context) extends Printer {
297297

298298
protected def paramsText(lam: LambdaType): Text = {
299299
val erasedParams = lam.erasedParams
300-
def paramText(name: Name, tp: Type, erased: Boolean) =
301-
keywordText("erased ").provided(erased) ~ toText(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true)
302-
Text(lam.paramNames.lazyZip(lam.paramInfos).lazyZip(erasedParams).map(paramText), ", ")
300+
def paramText(ref: ParamRef, erased: Boolean) =
301+
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ lambdaHash(lam) ~ toTextRHS(ref.underlying, isParameter = true)
302+
Text(lam.paramRefs.lazyZip(erasedParams).map(paramText), ", ")
303303
}
304304

305305
protected def ParamRefNameString(name: Name): String = nameString(name)
@@ -363,7 +363,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
363363
case tp @ ConstantType(value) =>
364364
toText(value)
365365
case pref: TermParamRef =>
366-
nameString(pref.binder.paramNames(pref.paramNum)) ~ lambdaHash(pref.binder)
366+
ParamRefNameString(pref) ~ lambdaHash(pref.binder)
367367
case tp: RecThis =>
368368
val idx = openRecs.reverse.indexOf(tp.binder)
369369
if (idx >= 0) selfRecName(idx + 1)

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
174174
~ " " ~ argText(args.last)
175175
}
176176

177-
private def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match
177+
protected def toTextMethodAsFunction(info: Type, isPure: Boolean, refs: Text = Str("")): Text = info match
178178
case info: MethodType =>
179179
val capturesRoot = refs == rootSetText
180180
changePrec(GlobalPrec) {

compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
196196
case AmbiguousExtensionMethodID // errorNumber 180
197197
case UnqualifiedCallToAnyRefMethodID // errorNumber: 181
198198
case NotConstantID // errorNumber: 182
199+
case ClosureCannotHaveInternalParameterDependenciesID // errorNumber: 183
199200

200201
def errorNumber = ordinal - 1
201202

0 commit comments

Comments
 (0)