Skip to content

Commit 35561ea

Browse files
committed
lib.data: improve reset value handling for Union.
* Reject union initialization with more than one reset value. * Replace the reset value specified in the class definition with the one provided during initalization instead of merging.
1 parent c7ef05c commit 35561ea

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

amaranth/lib/data.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,15 @@ def as_shape(cls):
431431
class _Aggregate(View, metaclass=_AggregateMeta):
432432
def __init__(self, target=None, *, name=None, reset=None, reset_less=None,
433433
attrs=None, decoder=None, src_loc_at=0):
434+
if self.__class__._AggregateMeta__layout_cls is UnionLayout:
435+
if reset is not None and len(reset) > 1:
436+
raise ValueError("Reset value for at most one field can be provided for "
437+
"a union class (specified: {})"
438+
.format(", ".join(reset.keys())))
434439
if target is None and hasattr(self.__class__, "_AggregateMeta__reset"):
435440
if reset is None:
436441
reset = self.__class__._AggregateMeta__reset
437-
else:
442+
elif self.__class__._AggregateMeta__layout_cls is not UnionLayout:
438443
reset = {**self.__class__._AggregateMeta__reset, **reset}
439444
super().__init__(self.__class__, target, name=name, reset=reset, reset_less=reset_less,
440445
attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1)

tests/test_lib_data.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,24 @@ def test_construct_signal_kwargs(self):
669669
self.assertEqual(s.attrs, {"debug": 1})
670670
self.assertEqual(s.decoder, decoder)
671671

672+
def test_construct_reset_two_wrong(self):
673+
class U(Union):
674+
a: unsigned(1)
675+
b: unsigned(2)
676+
677+
with self.assertRaisesRegex(ValueError,
678+
r"^Reset value for at most one field can be provided for a union class "
679+
r"\(specified: a, b\)$"):
680+
U(reset=dict(a=1, b=2))
681+
682+
def test_construct_reset_override(self):
683+
class U(Union):
684+
a: unsigned(1) = 1
685+
b: unsigned(2)
686+
687+
self.assertEqual(U().as_value().reset, 0b01)
688+
self.assertEqual(U(reset=dict(b=0b10)).as_value().reset, 0b10)
689+
672690

673691
# Examples from https://github.com/amaranth-lang/amaranth/issues/693
674692
class RFCExamplesTestCase(TestCase):

0 commit comments

Comments
 (0)