Skip to content

Commit 2bcd4f6

Browse files
authored
Merge pull request #9018 from dotty-staging/fix-#9011b
Fix #9011: Make single enum values inherit from Product
2 parents 484e3c6 + 74cc1b1 commit 2bcd4f6

File tree

22 files changed

+200
-42
lines changed

22 files changed

+200
-42
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ object DesugarEnums {
124124

125125
/** A creation method for a value of enum type `E`, which is defined as follows:
126126
*
127-
* private def $new(_$ordinal: Int, $name: String) = new E {
127+
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
128128
* def $ordinal = $tag
129129
* override def toString = $name
130130
* $values.register(this)
@@ -135,7 +135,7 @@ object DesugarEnums {
135135
val toStringDef = toStringMeth(Ident(nme.nameDollar))
136136
val creator = New(Template(
137137
constr = emptyConstructor,
138-
parents = enumClassRef :: Nil,
138+
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
139139
derived = Nil,
140140
self = EmptyValDef,
141141
body = List(ordinalDef, toStringDef) ++ registerCall
@@ -286,7 +286,9 @@ object DesugarEnums {
286286
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
287287
val ordinalDef = ordinalMethLit(tag)
288288
val toStringDef = toStringMethLit(name.toString)
289-
val impl1 = cpy.Template(impl)(body = List(ordinalDef, toStringDef) ++ registerCall)
289+
val impl1 = cpy.Template(impl)(
290+
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
291+
body = List(ordinalDef, toStringDef) ++ registerCall)
290292
.withAttachment(ExtendsSingletonMirror, ())
291293
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
292294
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
450450
def rootDot(name: Name)(implicit src: SourceFile): Select = Select(Ident(nme.ROOTPKG), name)
451451
def scalaDot(name: Name)(implicit src: SourceFile): Select = Select(rootDot(nme.scala), name)
452452
def scalaAnnotationDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.annotation), name)
453+
def scalaRuntimeDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.runtime), name)
453454
def scalaUnit(implicit src: SourceFile): Select = scalaDot(tpnme.Unit)
454455
def scalaAny(implicit src: SourceFile): Select = scalaDot(tpnme.Any)
455456
def javaDotLangDot(name: Name)(implicit src: SourceFile): Select = Select(Select(Ident(nme.java), nme.lang), name)

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ trait ConstraintHandling[AbstractContext] {
300300
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
301301
* 2. If `inst` is a union type, approximate the union type from above by an intersection
302302
* of all common base types, provided the result is a subtype of `bound`.
303+
* 3. (currently not enabled, see #9028) If `inst` is an intersection with some restricted base types, drop
304+
* the restricted base types from the intersection, provided the result is a subtype of `bound`.
303305
*
304306
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
305307
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -309,26 +311,48 @@ trait ConstraintHandling[AbstractContext] {
309311
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
310312
* as those could leak the annotation to users (see run/inferred-repeated-result).
311313
*/
312-
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type = {
313-
def widenOr(tp: Type) = {
314+
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =
315+
316+
def isRestricted(tp: Type) = tp.typeSymbol == defn.EnumValueClass // for now, to be generalized later
317+
318+
def dropRestricted(tp: Type): Type = tp.dealias match
319+
case tpd @ AndType(tp1, tp2) =>
320+
if isRestricted(tp1) then tp2
321+
else if isRestricted(tp2) then tp1
322+
else
323+
val tpw = tpd.derivedAndType(dropRestricted(tp1), dropRestricted(tp2))
324+
if tpw ne tpd then tpw else tp
325+
case _ =>
326+
tp
327+
328+
def widenRestricted(tp: Type) =
329+
val tpw = dropRestricted(tp)
330+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
331+
332+
def widenOr(tp: Type) =
314333
val tpw = tp.widenUnion
315334
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
316-
}
317-
def widenSingle(tp: Type) = {
335+
336+
def widenSingle(tp: Type) =
318337
val tpw = tp.widenSingletons
319338
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
320-
}
339+
321340
def isSingleton(tp: Type): Boolean = tp match
322341
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
323342
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)
343+
324344
val wideInst =
325-
if isSingleton(bound) then inst else widenOr(widenSingle(inst))
345+
if isSingleton(bound) then inst
346+
else /*widenRestricted*/(widenOr(widenSingle(inst)))
347+
// widenRestricted is currently not called since it's special cased in `dropEnumValue`
348+
// in `Namer`. It's left in here in case we want to generalize the scheme to other
349+
// "protected inheritance" classes.
326350
wideInst match
327351
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
328352
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
329353
case _ =>
330354
wideInst.dropRepeatedAnnot
331-
}
355+
end widenInferred
332356

333357
/** The instance type of `param` in the current constraint (which contains `param`).
334358
* If `fromBelow` is true, the instance type is the lub of the parameter's

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ class Definitions {
639639
@tu lazy val EnumClass: ClassSymbol = ctx.requiredClass("scala.Enum")
640640
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
641641

642+
@tu lazy val EnumValueClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValue")
642643
@tu lazy val EnumValuesClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValues")
643644
@tu lazy val ProductClass: ClassSymbol = ctx.requiredClass("scala.Product")
644645
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)

compiler/src/dotty/tools/dotc/core/Flags.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ object Flags {
438438
* TODO: Should check that FromStartFlags do not change in completion
439439
*/
440440
val FromStartFlags: FlagSet = commonFlags(
441-
Module, Package, Deferred, Method, Case,
441+
Module, Package, Deferred, Method, Case, Enum,
442442
HigherKinded, Param, ParamAccessor,
443443
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
444444
OuterOrCovariant, LabelOrContravariant, CaseAccessor,

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,15 +357,14 @@ object StdNames {
357357
val CAP: N = "CAP"
358358
val Constant: N = "Constant"
359359
val ConstantType: N = "ConstantType"
360-
val doubleHash: N = "doubleHash"
360+
val EnumValue: N = "EnumValue"
361361
val ExistentialTypeTree: N = "ExistentialTypeTree"
362362
val Flag : N = "Flag"
363363
val floatHash: N = "floatHash"
364364
val Ident: N = "Ident"
365365
val Import: N = "Import"
366366
val Literal: N = "Literal"
367367
val LiteralAnnotArg: N = "LiteralAnnotArg"
368-
val longHash: N = "longHash"
369368
val MatchCase: N = "MatchCase"
370369
val MirroredElemTypes: N = "MirroredElemTypes"
371370
val MirroredElemLabels: N = "MirroredElemLabels"
@@ -443,6 +442,7 @@ object StdNames {
443442
val delayedInitArg: N = "delayedInit$body"
444443
val derived: N = "derived"
445444
val derives: N = "derives"
445+
val doubleHash: N = "doubleHash"
446446
val drop: N = "drop"
447447
val dynamics: N = "dynamics"
448448
val elem: N = "elem"
@@ -505,6 +505,7 @@ object StdNames {
505505
val language: N = "language"
506506
val length: N = "length"
507507
val lengthCompare: N = "lengthCompare"
508+
val longHash: N = "longHash"
508509
val macroThis : N = "_this"
509510
val macroContext : N = "c"
510511
val main: N = "main"

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,19 @@ class Namer { typer: Typer =>
14391439
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
14401440
def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable)
14411441

1442+
def isEnumValue(tp: Type) = tp.typeSymbol == defn.EnumValueClass
1443+
1444+
// Drop EnumValue parents from inferred types of enum constants
1445+
def dropEnumValue(tp: Type): Type = tp.dealias match
1446+
case tpd @ AndType(tp1, tp2) =>
1447+
if isEnumValue(tp1) then tp2
1448+
else if isEnumValue(tp2) then tp1
1449+
else
1450+
val tpw = tpd.derivedAndType(dropEnumValue(tp1), dropEnumValue(tp2))
1451+
if tpw ne tpd then tpw else tp
1452+
case _ =>
1453+
tp
1454+
14421455
// Widen rhs type and eliminate `|' but keep ConstantTypes if
14431456
// definition is inline (i.e. final in Scala2) and keep module singleton types
14441457
// instead of widening to the underlying module class types.
@@ -1447,7 +1460,9 @@ class Namer { typer: Typer =>
14471460
def widenRhs(tp: Type): Type =
14481461
tp.widenTermRefExpr.simplified match
14491462
case ctp: ConstantType if isInlineVal => ctp
1450-
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)
1463+
case tp =>
1464+
val tp1 = ctx.typeComparer.widenInferred(tp, rhsProto)
1465+
if sym.is(Enum) then dropEnumValue(tp1) else tp1
14511466

14521467
// Replace aliases to Unit by Unit itself. If we leave the alias in
14531468
// it would be erased to BoxedUnit.

docs/docs/reference/enums/desugarEnums.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ map into `case class`es or `val`s.
126126
where `n` is the ordinal number of the case in the companion object,
127127
starting from 0. The statement `$values.register(this)` registers the value
128128
as one of the `values` of the enumeration (see below). `$values` is a
129-
compiler-defined private value in the companion object.
129+
compiler-defined private value in the companion object. The anonymous class also
130+
implements the abstract `Product` methods that it inherits from `Enum`.
131+
130132

131133
It is an error if a value case refers to a type parameter of the enclosing `enum`
132134
in a type argument of `<parents>`.
@@ -178,6 +180,7 @@ Companion objects of enumerations that contain at least one simple case define i
178180
}
179181
```
180182

183+
The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
181184
The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it.
182185

183186
### Scopes for Enum Cases

docs/docs/reference/enums/enums.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ If you want to use the Scala-defined enums as Java enums, you can do so by exten
9595
enum Color extends java.lang.Enum[Color] { case Red, Green, Blue }
9696
```
9797

98-
The type parameter comes from the Java enum [definition](https://docs.oracle.com/javase/8/docs/api/index.html?java/lang/Enum.html) and should be the same as the type of the enum.
98+
The type parameter comes from the Java enum [definition](https://docs.oracle.com/javase/8/docs/api/index.html?java/lang/Enum.html) and should be the same as the type of the enum.
9999
There is no need to provide constructor arguments (as defined in the Java API docs) to `java.lang.Enum` when extending it – the compiler will generate them automatically.
100100

101101
After defining `Color` like that, you can use it like you would a Java enum:
@@ -116,7 +116,7 @@ This trait defines a single public method, `ordinal`:
116116
package scala
117117

118118
/** A base trait of all enum classes */
119-
trait Enum {
119+
trait Enum extends Product with Serializable {
120120

121121
/** A number uniquely identifying a case of an enum */
122122
def ordinal: Int
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package scala
2+
3+
/** A base trait of all enum classes */
4+
trait Enum extends Product, Serializable:
5+
6+
/** A number uniquely identifying a case of an enum */
7+
def ordinal: Int
8+
protected def $ordinal: Int
9+

0 commit comments

Comments
 (0)