Skip to content

Commit 87fbced

Browse files
committed
lib.wiring: implement Signature.flatten.
1 parent f135226 commit 87fbced

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

amaranth/lib/wiring.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,33 @@ def freeze(self):
371371
self.members.freeze()
372372
return self
373373

374+
def flatten(self, obj):
375+
for name, member in self.members.items():
376+
path = (name,)
377+
value = getattr(obj, name)
378+
379+
def iter_member(value, *, path):
380+
if member.is_port:
381+
yield path, Member(member.flow, member.shape, reset=member.reset), value
382+
elif member.is_signature:
383+
for sub_path, sub_member, sub_value in member.signature.flatten(value):
384+
if member.flow == In:
385+
sub_member = sub_member.flip()
386+
yield ((*path, *sub_path), sub_member, sub_value)
387+
else:
388+
assert False # :nocov:
389+
390+
def iter_dimensions(value, dimensions, *, path):
391+
if not dimensions:
392+
yield from iter_member(value, path=path)
393+
else:
394+
dimension, *rest_of_dimensions = dimensions
395+
for index in range(dimension):
396+
yield from iter_dimensions(value[index], rest_of_dimensions,
397+
path=(path, index))
398+
399+
yield from iter_dimensions(value, dimensions=member.dimensions, path=path)
400+
374401
def is_compliant(self, obj, *, reasons=None, path=("obj",)):
375402
def check_attr_value(member, attr_value, *, path):
376403
if member.is_port:

tests/test_lib_wiring.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,36 @@ def test_freeze(self):
372372
r"^Cannot add members to a frozen signature$"):
373373
sig.members += {"b": Out(1)}
374374

375+
def assertFlattenedSignature(self, actual, expected):
376+
for (a_path, a_member, a_value), (b_path, b_member, b_value) in zip(actual, expected):
377+
self.assertEqual(a_path, b_path)
378+
self.assertEqual(a_member, b_member)
379+
self.assertIs(a_value, b_value)
380+
381+
def test_flatten(self):
382+
sig = Signature({"a": In(1), "b": Out(2).array(2)})
383+
intf = sig.create()
384+
self.assertFlattenedSignature(sig.flatten(intf), [
385+
(("a",), In(1), intf.a),
386+
((("b",), 0), Out(2), intf.b[0]),
387+
((("b",), 1), Out(2), intf.b[1])
388+
])
389+
390+
def test_flatten_sig(self):
391+
sig = Signature({
392+
"a": Out(Signature({"p": Out(1)})),
393+
"b": Out(Signature({"q": In (1)})),
394+
"c": In( Signature({"r": Out(1)})),
395+
"d": In( Signature({"s": In (1)})),
396+
})
397+
intf = sig.create()
398+
self.assertFlattenedSignature(sig.flatten(intf), [
399+
(("a", "p"), Out(1), intf.a.p),
400+
(("b", "q"), In (1), intf.b.q),
401+
(("c", "r"), Out(1), intf.c.r),
402+
(("d", "s"), In (1), intf.d.s),
403+
])
404+
375405
def assertNotCompliant(self, reason_regex, sig, obj):
376406
self.assertFalse(sig.is_compliant(obj))
377407
reasons = []

0 commit comments

Comments
 (0)