diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index ef863a67a..a490707bf 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -1580,6 +1580,10 @@ def cast(obj): def __init__(self, value, shape=None, *, src_loc_at=0): # We deliberately do not call Value.__init__ here. + if isinstance(value, Enum): + if shape is None: + shape = Shape.cast(type(value)) + value = value.value value = int(operator.index(value)) if shape is None: shape = Shape(bits_for(value), signed=value < 0) diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index a7ba972b4..d74ad773d 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -502,6 +502,20 @@ def test_hash(self): with self.assertRaises(TypeError): hash(Const(0)) + def test_enum(self): + e1 = Const(UnsignedEnum.FOO) + self.assertIsInstance(e1, Const) + self.assertEqual(e1.shape(), unsigned(2)) + e2 = Const(SignedEnum.FOO) + self.assertIsInstance(e2, Const) + self.assertEqual(e2.shape(), signed(2)) + e3 = Const(TypedEnum.FOO) + self.assertIsInstance(e3, Const) + self.assertEqual(e3.shape(), unsigned(2)) + e4 = Const(UnsignedEnum.FOO, 4) + self.assertIsInstance(e4, Const) + self.assertEqual(e4.shape(), unsigned(4)) + def test_shape_castable(self): class MockConstValue(ValueCastable): def __init__(self, value):