Skip to content

fix: show correctly typed hole on applyDynamic #23420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import scala.annotation.tailrec
import dotty.tools.dotc.core.Denotations.SingleDenotation
import dotty.tools.dotc.core.Denotations.MultiDenotation
import dotty.tools.dotc.util.Spans.Span
import dotty.tools.dotc.core.Symbols

object ApplyExtractor:
def unapply(path: List[Tree])(using Context): Option[Apply] =
Expand All @@ -44,8 +45,10 @@ object ApplyExtractor:


object ApplyArgsExtractor:
// normally symbol but for refinment types method type
type Method = Symbol | Type
def getArgsAndParams(
optIndexedContext: Option[IndexedContext],
indexedContext: IndexedContext,
apply: Apply,
span: Span
)(using Context): List[(List[Tree], List[ParamSymbol])] =
Expand Down Expand Up @@ -78,47 +81,56 @@ object ApplyArgsExtractor:

// fallback for when multiple overloaded methods match the supplied args
def fallbackFindMatchingMethods() =
def matchingMethodsSymbols(
indexedContext: IndexedContext,
method: Tree
): List[Symbol] =
def matchingMethodsSymbols(method: Tree): List[Method] =
method match
case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil)
case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil)
case Select(t @ This(_), name) =>
val res = indexedContext.findSymbol(name).getOrElse(Nil).filter(_.exists)
res ++ findRefinments(t.symbol.info, name)
case sel @ Select(from, name) =>
val symbol = from.symbol
val ownerSymbol =
if symbol.is(Method) && symbol.owner.isClass then
Some(symbol.owner)
else Try(symbol.info.classSymbol).toOption
ownerSymbol.map(sym => sym.info.member(name)).collect{
val res = ownerSymbol.map(sym => sym.info.member(name)).collect{
case single: SingleDenotation => List(single.symbol)
case multi: MultiDenotation => multi.allSymbols
}.getOrElse(Nil)
case Apply(fun, _) => matchingMethodsSymbols(indexedContext, fun)
res ++ findRefinments(symbol.info, name)
case Apply(fun, _) => matchingMethodsSymbols(fun)
case TypeApply(fun, args) =>
matchingMethodsSymbols(fun).map {
case t: PolyType => t.appliedTo(args.map(_.tpe))
case s => s
}
case _ => Nil
val matchingMethods =
for
indexedContext <- optIndexedContext.toList
potentialMatch <- matchingMethodsSymbols(indexedContext, method)
if potentialMatch.is(Flags.Method) &&
potentialMatch.vparamss.length >= argss.length &&
Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption
.getOrElse(false) &&
potentialMatch.vparamss
potentialMatch <- matchingMethodsSymbols(method)
if potentialMatch match
case s: Symbol => s.is(Flags.Method) && Try(s.isAccessibleFrom(apply.symbol.info)).toOption.getOrElse(false)
case _ => true
if potentialMatch.vparamss.length >= argss.length &&
(potentialMatch match {
case s: Symbol =>
s.symVparamss
.zip(argss)
.reverse
.zipWithIndex
.forall { case (pair, index) =>
FuzzyArgMatcher(potentialMatch.tparams)
FuzzyArgMatcher(s.symTparams)
.doMatch(allArgsProvided = index != 0, span)
.tupled(pair)
}
case _ => true
})

yield potentialMatch
matchingMethods
end fallbackFindMatchingMethods

val matchingMethods: List[Symbol] =
val matchingMethods: List[Method] =
if method.symbol.paramSymss.nonEmpty then
val allArgsAreSupplied =
val vparamss = method.symbol.vparamss
Expand Down Expand Up @@ -157,11 +169,10 @@ object ApplyArgsExtractor:
// def curry(x: Int)(apple: String, banana: String) = ???
// curry(1)(apple = "test", b@@)
// ```
val (baseParams0, baseArgs) =
val (defaultBaseParams, baseArgs) =
vparamss.zip(argss).lastOption.getOrElse((Nil, Nil))

val baseParams: List[ParamSymbol] =
def defaultBaseParams = baseParams0.map(JustSymbol(_))
@tailrec
def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] =
if level > 0 then
Expand All @@ -176,8 +187,8 @@ object ApplyArgsExtractor:
else
refinedType match
case RefinedType(AppliedType(_, args), _, MethodType(ri)) =>
baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) =>
RefinedSymbol(sym, name, arg)
defaultBaseParams.zip(ri).zip(args).map { case ((sym, name), arg) =>
RefinedSymbol(sym.symbol, name, arg)
}
case _ => defaultBaseParams
// finds param refinements for lambda expressions
Expand All @@ -198,11 +209,35 @@ object ApplyArgsExtractor:
(baseArgs, baseParams)
}

extension (method: Symbol)
def vparamss(using Context) = method.filteredParamss(_.isTerm)
def tparams(using Context) = method.filteredParamss(_.isType).flatten
def filteredParamss(f: Symbol => Boolean)(using Context) =
method.paramSymss.filter(params => params.forall(f))
@tailrec
private def findRefinments(tpe: Type, name: Name, acc: List[Method] = Nil): List[Method] =
tpe match
case RefinedType(parent, `name`, refinedInfo) =>
findRefinments(parent, name, refinedInfo :: acc)
case RefinedType(parent, _, s) => findRefinments(parent, name, acc)
case _ => acc.reverse


extension (method: Method)
def vparamss(using Context): List[List[ParamSymbol]] =
method match
case s: Symbol => s.symVparamss.map(_.map(JustSymbol(_)))
case m: MethodType =>
m.paramInfoss.zipWithIndex.map {
case (params, idx) =>
params.zip(m.paramNamess.get(idx).getOrElse(Nil)).map{
case (tpe, name) => RefinedSymbol(Symbols.NoSymbol, name, tpe)
}
}
case _ => Nil

extension (sym: Symbol)
def symVparamss(using Context): List[List[Symbol]] = filteredParamss(sym, _.isTerm)

def symTparams(using Context): List[Symbol] = filteredParamss(sym, _.isType).flatten

private def filteredParamss(s: Symbol, f: Symbol => Boolean)(using Context): List[List[Symbol]] =
s.paramSymss.filter(params => params.forall(f))
sealed trait ParamSymbol:
def name: Name
def info: Type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ class InferExpectedType(
val indexedCtx = IndexedContext(pos)(using locatedCtx)
val printer =
ShortenedTypePrinter(search, IncludeDefaultParam.ResolveLater)(using indexedCtx)
InterCompletionType.inferType(path)(using newctx).map{
InferCompletionType.inferType(path)(using newctx).map{
tpe => printer.tpe(tpe)
}
case None => None

object InterCompletionType:
object InferCompletionType:
def inferType(path: List[Tree])(using Context): Option[Type] =
path match
case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(s: Select)) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span)
Expand Down Expand Up @@ -94,7 +94,7 @@ object InterCompletionType:
else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind))
// f(@@)
case ApplyExtractor(app) =>
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(None, app, span).headOption
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(IndexedContext.Empty, app, span).headOption
argsAndParams.flatMap:
case (args, params) =>
val idx = args.indexWhere(_.span.contains(span))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ class Completions(
config.isCompletionSnippetsEnabled()
)
(args, false)
val singletonCompletions = InterCompletionType.inferType(path).map(
val singletonCompletions = InferCompletionType.inferType(path).map(
SingletonCompletions.contribute(path, _, completionPos)
).getOrElse(Nil)
(singletonCompletions ++ advanced, exclusive)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ object NamedArgCompletions:
case _ => false

val argsAndParams = ApplyArgsExtractor.getArgsAndParams(
Some(indexedContext),
indexedContext,
apply,
ident.span
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,48 @@ class InferExpectedTypeSuite extends BasePCSuite:
"""|String
|""".stripMargin
)

@Test def `apply-dynamic` =
check(
"""|object TypedHoleApplyDynamic {
| val obj: reflect.Selectable {
| def method(x: Int): Unit
| } = new reflect.Selectable {
| def method(x: Int): Unit = ()
| }
|
| obj.method(@@)
|}
|""".stripMargin,
"Int"
)

@Test def `apply-dynamic-2` =
check(
"""|object TypedHoleApplyDynamic {
| val obj: reflect.Selectable {
| def method[T](x: Int, y: T): Unit
| } = new reflect.Selectable {
| def method[T](x: Int, y: T): Unit = ()
| }
|
| obj.method[Int](1, @@)
|}
|""".stripMargin,
"Int"
)

@Test def `apply-dynamic-3` =
check(
"""|object TypedHoleApplyDynamic {
| val obj: reflect.Selectable {
| def method[T](a: Int)(x: Int, y: T): Unit
| } = new reflect.Selectable {
| def method[T](a: Int)(x: Int, y: T): Unit = ()
| }
|
| obj.method[String](1)(1, @@)
|}
|""".stripMargin,
"String"
)
Loading