Skip to content

Commit 2941551

Browse files
committed
Fix paramCount
1 parent 7f32f50 commit 2941551

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

compiler/src/dotty/tools/dotc/core/Decorators.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,22 @@ object Decorators {
114114
else x1 :: xs1
115115
}
116116

117+
/** Like `xs.lazyZip(xs.indices).map(f)`, but returns list `xs` itself
118+
* - instead of a copy - if function `f` maps all elements of
119+
* `xs` to themselves.
120+
*/
121+
def mapWithIndexConserve[U](f: (T, Int) => T): List[T] =
122+
def recur(xs: List[T], idx: Int): List[T] =
123+
if xs.isEmpty then Nil
124+
else
125+
val x1 = f(xs.head, idx)
126+
val xs1 = recur(xs.tail, idx + 1)
127+
if (x1.asInstanceOf[AnyRef] eq xs.head.asInstanceOf[AnyRef])
128+
&& (xs1 eq xs.tail)
129+
then xs
130+
else x1 :: xs1
131+
recur(xs, 0)
132+
117133
final def hasSameLengthAs[U](ys: List[U]): Boolean = {
118134
@tailrec def loop(xs: List[T], ys: List[U]): Boolean =
119135
if (xs.isEmpty) ys.isEmpty

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,27 +1656,39 @@ trait Applications extends Compatibility {
16561656
alts.filter(sizeFits(_))
16571657

16581658
def narrowByShapes(alts: List[TermRef]): List[TermRef] =
1659-
1659+
16601660
/** Normalization steps before shape-checking arguments:
16611661
*
16621662
* { expr } --> expr
16631663
* (x1, ..., xn) => expr --> ((x1, ..., xn)) => expr
1664-
* if n > 1, no alternative takes `n` parameters,
1665-
* and at least one alternative takes 1 parameter.
1664+
* if n > 1, no alternative has a corresponding formal parameter that
1665+
* is an n-ary function, and at least one alternative has a corresponding
1666+
* formal parameter that is a unary function.
16661667
*/
1667-
def normArg(arg: untpd.Tree): untpd.Tree = arg match
1668-
case Block(Nil, expr) => normArg(expr)
1669-
case x @ untpd.Function(args, body) =>
1668+
def normArg(arg: untpd.Tree, idx: Int): untpd.Tree = arg match
1669+
case Block(Nil, expr) => normArg(expr, idx)
1670+
case untpd.Function(args, body) =>
1671+
1672+
// If ref refers to a method whose parameter at index `idx` is a function type,
1673+
// the arity of that function, otherise 0.
1674+
def paramCount(ref: TermRef) =
1675+
val formals = ref.widen.firstParamTypes
1676+
if formals.length > idx then
1677+
formals(idx) match
1678+
case defn.FunctionOf(args, _, _, _) => args.length
1679+
case _ => 0
1680+
else 0
1681+
16701682
val numArgs = args.length
1671-
def paramCount(ref: TermRef) = ref.widen.firstParamTypes.length
16721683
if numArgs > 1
16731684
&& !alts.exists(paramCount(_) == numArgs)
16741685
&& alts.exists(paramCount(_) == 1)
16751686
then untpd.Function(untpd.Tuple(args) :: Nil, body)
16761687
else arg
16771688
case _ => arg
1678-
1679-
val normArgs = args.mapConserve(normArg)
1689+
end normArg
1690+
1691+
val normArgs = args.mapWithIndexConserve(normArg)
16801692
if (normArgs exists untpd.isFunctionWithUnknownParamType)
16811693
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
16821694
else narrowByTypes(alts, normArgs map typeShape, resultType)

0 commit comments

Comments
 (0)