diff --git a/presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala b/presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala index 9384a0b43e8b..4617b76181a4 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala @@ -20,6 +20,7 @@ import scala.annotation.tailrec import dotty.tools.dotc.core.Denotations.SingleDenotation import dotty.tools.dotc.core.Denotations.MultiDenotation import dotty.tools.dotc.util.Spans.Span +import dotty.tools.dotc.core.Symbols object ApplyExtractor: def unapply(path: List[Tree])(using Context): Option[Apply] = @@ -44,8 +45,10 @@ object ApplyExtractor: object ApplyArgsExtractor: + // normally symbol but for refinment types method type + type Method = Symbol | Type def getArgsAndParams( - optIndexedContext: Option[IndexedContext], + indexedContext: IndexedContext, apply: Apply, span: Span )(using Context): List[(List[Tree], List[ParamSymbol])] = @@ -78,47 +81,56 @@ object ApplyArgsExtractor: // fallback for when multiple overloaded methods match the supplied args def fallbackFindMatchingMethods() = - def matchingMethodsSymbols( - indexedContext: IndexedContext, - method: Tree - ): List[Symbol] = + def matchingMethodsSymbols(method: Tree): List[Method] = method match case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil) - case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil) + case Select(t @ This(_), name) => + val res = indexedContext.findSymbol(name).getOrElse(Nil).filter(_.exists) + res ++ findRefinments(t.symbol.info, name) case sel @ Select(from, name) => val symbol = from.symbol val ownerSymbol = if symbol.is(Method) && symbol.owner.isClass then Some(symbol.owner) else Try(symbol.info.classSymbol).toOption - ownerSymbol.map(sym => sym.info.member(name)).collect{ + val res = ownerSymbol.map(sym => sym.info.member(name)).collect{ case single: SingleDenotation => List(single.symbol) case multi: MultiDenotation => multi.allSymbols }.getOrElse(Nil) - case Apply(fun, _) => matchingMethodsSymbols(indexedContext, fun) + res ++ findRefinments(symbol.info, name) + case Apply(fun, _) => matchingMethodsSymbols(fun) + case TypeApply(fun, args) => + matchingMethodsSymbols(fun).map { + case t: PolyType => t.appliedTo(args.map(_.tpe)) + case s => s + } case _ => Nil val matchingMethods = for - indexedContext <- optIndexedContext.toList - potentialMatch <- matchingMethodsSymbols(indexedContext, method) - if potentialMatch.is(Flags.Method) && - potentialMatch.vparamss.length >= argss.length && - Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption - .getOrElse(false) && - potentialMatch.vparamss + potentialMatch <- matchingMethodsSymbols(method) + if potentialMatch match + case s: Symbol => s.is(Flags.Method) && Try(s.isAccessibleFrom(apply.symbol.info)).toOption.getOrElse(false) + case _ => true + if potentialMatch.vparamss.length >= argss.length && + (potentialMatch match { + case s: Symbol => + s.symVparamss .zip(argss) .reverse .zipWithIndex .forall { case (pair, index) => - FuzzyArgMatcher(potentialMatch.tparams) + FuzzyArgMatcher(s.symTparams) .doMatch(allArgsProvided = index != 0, span) .tupled(pair) } + case _ => true + }) + yield potentialMatch matchingMethods end fallbackFindMatchingMethods - val matchingMethods: List[Symbol] = + val matchingMethods: List[Method] = if method.symbol.paramSymss.nonEmpty then val allArgsAreSupplied = val vparamss = method.symbol.vparamss @@ -157,11 +169,10 @@ object ApplyArgsExtractor: // def curry(x: Int)(apple: String, banana: String) = ??? // curry(1)(apple = "test", b@@) // ``` - val (baseParams0, baseArgs) = + val (defaultBaseParams, baseArgs) = vparamss.zip(argss).lastOption.getOrElse((Nil, Nil)) val baseParams: List[ParamSymbol] = - def defaultBaseParams = baseParams0.map(JustSymbol(_)) @tailrec def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] = if level > 0 then @@ -176,8 +187,8 @@ object ApplyArgsExtractor: else refinedType match case RefinedType(AppliedType(_, args), _, MethodType(ri)) => - baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) => - RefinedSymbol(sym, name, arg) + defaultBaseParams.zip(ri).zip(args).map { case ((sym, name), arg) => + RefinedSymbol(sym.symbol, name, arg) } case _ => defaultBaseParams // finds param refinements for lambda expressions @@ -198,11 +209,35 @@ object ApplyArgsExtractor: (baseArgs, baseParams) } - extension (method: Symbol) - def vparamss(using Context) = method.filteredParamss(_.isTerm) - def tparams(using Context) = method.filteredParamss(_.isType).flatten - def filteredParamss(f: Symbol => Boolean)(using Context) = - method.paramSymss.filter(params => params.forall(f)) + @tailrec + private def findRefinments(tpe: Type, name: Name, acc: List[Method] = Nil): List[Method] = + tpe match + case RefinedType(parent, `name`, refinedInfo) => + findRefinments(parent, name, refinedInfo :: acc) + case RefinedType(parent, _, s) => findRefinments(parent, name, acc) + case _ => acc.reverse + + + extension (method: Method) + def vparamss(using Context): List[List[ParamSymbol]] = + method match + case s: Symbol => s.symVparamss.map(_.map(JustSymbol(_))) + case m: MethodType => + m.paramInfoss.zipWithIndex.map { + case (params, idx) => + params.zip(m.paramNamess.get(idx).getOrElse(Nil)).map{ + case (tpe, name) => RefinedSymbol(Symbols.NoSymbol, name, tpe) + } + } + case _ => Nil + + extension (sym: Symbol) + def symVparamss(using Context): List[List[Symbol]] = filteredParamss(sym, _.isTerm) + + def symTparams(using Context): List[Symbol] = filteredParamss(sym, _.isType).flatten + + private def filteredParamss(s: Symbol, f: Symbol => Boolean)(using Context): List[List[Symbol]] = + s.paramSymss.filter(params => params.forall(f)) sealed trait ParamSymbol: def name: Name def info: Type diff --git a/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala b/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala index 2e6c7b39ba65..3a314228f264 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala @@ -50,12 +50,12 @@ class InferExpectedType( val indexedCtx = IndexedContext(pos)(using locatedCtx) val printer = ShortenedTypePrinter(search, IncludeDefaultParam.ResolveLater)(using indexedCtx) - InterCompletionType.inferType(path)(using newctx).map{ + InferCompletionType.inferType(path)(using newctx).map{ tpe => printer.tpe(tpe) } case None => None -object InterCompletionType: +object InferCompletionType: def inferType(path: List[Tree])(using Context): Option[Type] = path match case (lit: Literal) :: Select(Literal(_), _) :: Apply(Select(Literal(_), _), List(s: Select)) :: rest if s.symbol == defn.Predef_undefined => inferType(rest, lit.span) @@ -94,7 +94,7 @@ object InterCompletionType: else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind)) // f(@@) case ApplyExtractor(app) => - val argsAndParams = ApplyArgsExtractor.getArgsAndParams(None, app, span).headOption + val argsAndParams = ApplyArgsExtractor.getArgsAndParams(IndexedContext.Empty, app, span).headOption argsAndParams.flatMap: case (args, params) => val idx = args.indexWhere(_.span.contains(span)) diff --git a/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala b/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala index 6e79f5a293e5..0086f2cf5488 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/completions/Completions.scala @@ -520,7 +520,7 @@ class Completions( config.isCompletionSnippetsEnabled() ) (args, false) - val singletonCompletions = InterCompletionType.inferType(path).map( + val singletonCompletions = InferCompletionType.inferType(path).map( SingletonCompletions.contribute(path, _, completionPos) ).getOrElse(Nil) (singletonCompletions ++ advanced, exclusive) diff --git a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala index faf6d715d8cf..88b5abc171fc 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala @@ -74,7 +74,7 @@ object NamedArgCompletions: case _ => false val argsAndParams = ApplyArgsExtractor.getArgsAndParams( - Some(indexedContext), + indexedContext, apply, ident.span ) diff --git a/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala index ba96488471b6..42297de62560 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala @@ -335,3 +335,48 @@ class InferExpectedTypeSuite extends BasePCSuite: """|String |""".stripMargin ) + + @Test def `apply-dynamic` = + check( + """|object TypedHoleApplyDynamic { + | val obj: reflect.Selectable { + | def method(x: Int): Unit + | } = new reflect.Selectable { + | def method(x: Int): Unit = () + | } + | + | obj.method(@@) + |} + |""".stripMargin, + "Int" + ) + + @Test def `apply-dynamic-2` = + check( + """|object TypedHoleApplyDynamic { + | val obj: reflect.Selectable { + | def method[T](x: Int, y: T): Unit + | } = new reflect.Selectable { + | def method[T](x: Int, y: T): Unit = () + | } + | + | obj.method[Int](1, @@) + |} + |""".stripMargin, + "Int" + ) + + @Test def `apply-dynamic-3` = + check( + """|object TypedHoleApplyDynamic { + | val obj: reflect.Selectable { + | def method[T](a: Int)(x: Int, y: T): Unit + | } = new reflect.Selectable { + | def method[T](a: Int)(x: Int, y: T): Unit = () + | } + | + | obj.method[String](1)(1, @@) + |} + |""".stripMargin, + "String" + )