diff --git a/amaranth/lib/enum.py b/amaranth/lib/enum.py index d6463c3bd..d9d415b5a 100644 --- a/amaranth/lib/enum.py +++ b/amaranth/lib/enum.py @@ -234,16 +234,16 @@ def __init__(self, enum, target): .format(target)) from e if cast_target.shape() != enum.as_shape(): raise TypeError("EnumView target must have the same shape as the enum") - self.enum = enum - self.target = cast_target + self.__enum = enum + self.__target = cast_target def shape(self): """Returns the underlying enum type.""" - return self.enum + return self.__enum def as_value(self): """Returns the underlying value.""" - return self.target + return self.__target def eq(self, other): """Assign to the underlying value. @@ -296,21 +296,25 @@ def __eq__(self, other): :class:`Value` The result of the equality comparison, as a single-bit value. """ - if isinstance(other, self.enum): - other = self.enum(Value.cast(other)) - if not isinstance(other, EnumView) or other.enum is not self.enum: - raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type") - return self.target == other.target + enum_cls = self.shape() + if isinstance(other, enum_cls): + other = enum_cls(Value.cast(other)) + if not isinstance(other, EnumView) or other.shape() is not enum_cls: + raise TypeError("an EnumView can only be compared to value or other EnumView of " + "the same enum type") + return self.as_value() == other.as_value() def __ne__(self, other): - if isinstance(other, self.enum): - other = self.enum(Value.cast(other)) - if not isinstance(other, EnumView) or other.enum is not self.enum: - raise TypeError("an EnumView can only be compared to value or other EnumView of the same enum type") - return self.target != other.target + enum_cls = self.shape() + if isinstance(other, enum_cls): + other = enum_cls(Value.cast(other)) + if not isinstance(other, EnumView) or other.shape() is not enum_cls: + raise TypeError("an EnumView can only be compared to value or other EnumView of " + "the same enum type") + return self.as_value() != other.as_value() def __repr__(self): - return f"{type(self).__qualname__}({self.enum.__qualname__}, {self.target!r})" + return f"{type(self).__qualname__}({self.__enum.__qualname__}, {self.__target!r})" class FlagView(EnumView): @@ -330,21 +334,23 @@ def __invert__(self): ------- :class:`FlagView` """ - if hasattr(self.enum, "_boundary_") and self.enum._boundary_ in (EJECT, KEEP): - return self.enum._amaranth_view_class_(self.enum, ~self.target) + enum_cls = self.shape() + if hasattr(enum_cls, "_boundary_") and enum_cls._boundary_ in (EJECT, KEEP): + return enum_cls._amaranth_view_class_(enum_cls, ~self.as_value()) else: singles_mask = 0 - for flag in self.enum: + for flag in enum_cls: if (flag.value & (flag.value - 1)) == 0: singles_mask |= flag.value - return self.enum._amaranth_view_class_(self.enum, ~self.target & singles_mask) + return enum_cls._amaranth_view_class_(enum_cls, ~self.as_value() & singles_mask) def __bitop(self, other, op): - if isinstance(other, self.enum): - other = self.enum(Value.cast(other)) - if not isinstance(other, FlagView) or other.enum is not self.enum: + enum_cls = self.shape() + if isinstance(other, enum_cls): + other = enum_cls(Value.cast(other)) + if not isinstance(other, FlagView) or other.shape() is not enum_cls: raise TypeError("a FlagView can only perform bitwise operation with a value or other FlagView of the same enum type") - return self.enum._amaranth_view_class_(self.enum, op(self.target, other.target)) + return enum_cls._amaranth_view_class_(enum_cls, op(self.as_value(), other.as_value())) def __and__(self, other): """Performs a bitwise AND and returns another :class:`FlagView`.