Skip to content

Commit 63aef37

Browse files
Merge pull request #7979 from dotty-staging/fix-#7964
Fix #7964: Allow inline enum parameters
2 parents c79ffa1 + f210ca1 commit 63aef37

File tree

5 files changed

+44
-4
lines changed

5 files changed

+44
-4
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package dotty.tools.dotc
22
package transform
33

44
import java.io.{PrintWriter, StringWriter}
5-
import java.lang.reflect.{InvocationTargetException, Method}
5+
import java.lang.reflect.{InvocationTargetException, Method => JLRMethod}
66

77
import dotty.tools.dotc.ast.tpd
88
import dotty.tools.dotc.ast.Trees._
@@ -194,10 +194,13 @@ object Splicer {
194194
interpretNew(fn.symbol, args.flatten.map(interpretTree))
195195
else if (fn.symbol.is(Module))
196196
interpretModuleAccess(fn.symbol)
197-
else if (fn.symbol.isStatic) {
197+
else if (fn.symbol.is(Method) && fn.symbol.isStatic) {
198198
val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
199199
staticMethodCall(args.flatten.map(interpretTree))
200200
}
201+
else if (fn.symbol.isStatic)
202+
assert(args.isEmpty)
203+
interpretedStaticFieldAccess(fn.symbol)
201204
else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic)
202205
if (fn.name == nme.asInstanceOfPM)
203206
interpretModuleAccess(fn.qualifier.symbol)
@@ -277,6 +280,12 @@ object Splicer {
277280
(args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method)
278281
}
279282

283+
private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = {
284+
val clazz = loadClass(sym.owner.fullName.toString)
285+
val field = clazz.getField(sym.name.toString)
286+
field.get(null)
287+
}
288+
280289
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
281290
loadModule(fn.moduleClass)
282291

@@ -319,15 +328,15 @@ object Splicer {
319328
throw new StopInterpretation(msg, pos)
320329
}
321330

322-
private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): Method =
331+
private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod =
323332
try clazz.getMethod(name.toString, paramClasses: _*)
324333
catch {
325334
case _: NoSuchMethodException =>
326335
val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)"
327336
throw new StopInterpretation(msg, pos)
328337
}
329338

330-
private def stopIfRuntimeException[T](thunk: => T, method: Method): T =
339+
private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T =
331340
try thunk
332341
catch {
333342
case ex: RuntimeException =>

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,10 +835,13 @@ trait Checking {
835835
def isCaseObject(sym: Symbol): Boolean =
836836
// TODO add alias to Nil in scala package
837837
sym.is(Case) && sym.is(Module)
838+
def isStaticEnumCase(sym: Symbol): Boolean =
839+
sym.is(Enum) && sym.is(JavaStatic) && sym.is(Case)
838840
val allow =
839841
ctx.erasedTypes ||
840842
ctx.inInlineMethod ||
841843
(tree.symbol.isStatic && isCaseObject(tree.symbol) || isCaseClassApply(tree.symbol)) ||
844+
isStaticEnumCase(tree.symbol) ||
842845
isCaseClassNew(tree.symbol)
843846

844847
if (!allow) ctx.error(em"$what must be a known value", tree.sourcePos)

tests/pos/i7964.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
enum Nums { case One; case Two }
2+
3+
def fok(that: Nums) = ???
4+
val ok = fok(Nums.One)
5+
6+
inline def fko(inline that: Nums) = inline that match {
7+
case Nums.One => "fff(one)"
8+
}
9+
val ko = fko(Nums.One)

tests/run-macros/i7964/Macro_1.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.quoted._
2+
3+
enum Num {
4+
case One
5+
case Two
6+
}
7+
8+
inline def foo(inline num: Num): Int = ${ fooExpr(num) }
9+
10+
private def fooExpr(num: Num)(given QuoteContext): Expr[Int] = Expr(toInt(num))
11+
12+
private def toInt(num: Num): Int = num match {
13+
case Num.One => 1
14+
case Num.Two => 2
15+
}

tests/run-macros/i7964/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
object Test extends App {
2+
assert(foo(Num.One) == 1)
3+
assert(foo(Num.Two) == 2)
4+
}

0 commit comments

Comments
 (0)