Skip to content

Commit 055d6a4

Browse files
wanda-phiwhitequark
authored andcommitted
sim: make driving parts of a signal from distinct modules possible.
Fixes (part of) #1454.
1 parent ad76186 commit 055d6a4

File tree

6 files changed

+178
-17
lines changed

6 files changed

+178
-17
lines changed

amaranth/hdl/_xfrm.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"FragmentTransformer",
1818
"TransformedElaboratable",
1919
"DomainCollector", "DomainRenamer", "DomainLowerer",
20+
"LHSMaskCollector",
2021
"ResetInserter", "EnableInserter"]
2122

2223

@@ -575,6 +576,71 @@ def on_fragment(self, fragment):
575576
return super().on_fragment(fragment)
576577

577578

579+
class LHSMaskCollector:
580+
def __init__(self):
581+
self.lhs = SignalDict()
582+
583+
def visit_stmt(self, stmt):
584+
if type(stmt) is Assign:
585+
self.visit_value(stmt.lhs, ~0)
586+
elif type(stmt) is Switch:
587+
for (_, substmt, _) in stmt.cases:
588+
self.visit_stmt(substmt)
589+
elif type(stmt) in (Property, Print):
590+
pass
591+
elif isinstance(stmt, Iterable):
592+
for substmt in stmt:
593+
self.visit_stmt(substmt)
594+
else:
595+
assert False # :nocov:
596+
597+
def visit_value(self, value, mask):
598+
if type(value) in (Signal, ClockSignal, ResetSignal):
599+
mask &= (1 << len(value)) - 1
600+
self.lhs.setdefault(value, 0)
601+
self.lhs[value] |= mask
602+
elif type(value) is Operator:
603+
assert value.operator in ("s", "u")
604+
self.visit_value(value.operands[0], mask)
605+
elif type(value) is Slice:
606+
slice_mask = (1 << value.stop) - (1 << value.start)
607+
mask <<= value.start
608+
mask &= slice_mask
609+
self.visit_value(value.value, mask)
610+
elif type(value) is Part:
611+
# Could be more accurate, but if you're relying on such details, you're not seeing
612+
# the Light of Heaven.
613+
self.visit_value(value.value, ~0)
614+
elif type(value) is Concat:
615+
for part in value.parts:
616+
self.visit_value(part, mask)
617+
mask >>= len(part)
618+
elif type(value) is SwitchValue:
619+
for (_, subvalue) in value.cases:
620+
self.visit_value(subvalue, mask)
621+
else:
622+
assert False # :nocov:
623+
624+
def chunks(self):
625+
for signal, mask in self.lhs.items():
626+
if mask == (1 << len(signal)) - 1:
627+
yield signal, 0, None
628+
else:
629+
start = 0
630+
while start < len(signal):
631+
if ((mask >> start) & 1) == 0:
632+
start += 1
633+
else:
634+
stop = start
635+
while stop < len(signal) and ((mask >> stop) & 1) == 1:
636+
stop += 1
637+
yield (signal, start, stop)
638+
start = stop
639+
640+
def masks(self):
641+
yield from self.lhs.items()
642+
643+
578644
class _ControlInserter(FragmentTransformer):
579645
def __init__(self, controls):
580646
self.src_loc = None
@@ -589,10 +655,9 @@ def on_fragment(self, fragment):
589655
for domain, statements in fragment.statements.items():
590656
if domain == "comb" or domain not in self.controls:
591657
continue
592-
signals = SignalSet()
593-
for stmt in statements:
594-
signals |= stmt._lhs_signals()
595-
self._insert_control(new_fragment, domain, signals)
658+
lhs_masks = LHSMaskCollector()
659+
lhs_masks.visit_stmt(statements)
660+
self._insert_control(new_fragment, domain, lhs_masks)
596661
return new_fragment
597662

598663
def _insert_control(self, fragment, domain, signals):
@@ -604,13 +669,20 @@ def __call__(self, value, *, src_loc_at=0):
604669

605670

606671
class ResetInserter(_ControlInserter):
607-
def _insert_control(self, fragment, domain, signals):
608-
stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less]
672+
def _insert_control(self, fragment, domain, lhs_masks):
673+
stmts = []
674+
for signal, start, stop in lhs_masks.chunks():
675+
if signal.reset_less:
676+
continue
677+
if start == 0 and stop is None:
678+
stmts.append(signal.eq(Const(signal.init, signal.shape())))
679+
else:
680+
stmts.append(signal[start:stop].eq(Const(signal.init, signal.shape())[start:stop]))
609681
fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc))
610682

611683

612684
class EnableInserter(_ControlInserter):
613-
def _insert_control(self, fragment, domain, signals):
685+
def _insert_control(self, fragment, domain, _lhs_masks):
614686
if domain in fragment.statements:
615687
fragment.statements[domain] = _StatementList([Switch(
616688
self.controls[domain],

amaranth/sim/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class BaseSignalState:
2323
curr = NotImplemented
2424
next = NotImplemented
2525

26-
def update(self, value):
26+
def update(self, value, mask=~0):
2727
raise NotImplementedError # :nocov:
2828

2929

amaranth/sim/_pyrtl.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ..hdl import *
77
from ..hdl._ast import SignalSet, _StatementList, Property
8-
from ..hdl._xfrm import ValueVisitor, StatementVisitor
8+
from ..hdl._xfrm import ValueVisitor, StatementVisitor, LHSMaskCollector
99
from ..hdl._mem import MemoryInstance
1010
from ._base import BaseProcess
1111
from ._pyeval import value_to_string
@@ -487,19 +487,20 @@ def __call__(self, fragment):
487487
for domain_name in domains:
488488
domain_stmts = fragment.statements.get(domain_name, _StatementList())
489489
domain_process = PyRTLProcess(is_comb=domain_name == "comb")
490-
domain_signals = domain_stmts._lhs_signals()
490+
lhs_masks = LHSMaskCollector()
491+
lhs_masks.visit_stmt(domain_stmts)
491492

492493
if isinstance(fragment, MemoryInstance):
493494
for port in fragment._read_ports:
494495
if port._domain == domain_name:
495-
domain_signals.update(port._data._lhs_signals())
496+
lhs_masks.visit_value(port._data, ~0)
496497

497498
emitter = _PythonEmitter()
498499
emitter.append(f"def run():")
499500
emitter._level += 1
500501

501502
if domain_name == "comb":
502-
for signal in domain_signals:
503+
for (signal, _) in lhs_masks.masks():
503504
signal_index = self.state.get_signal(signal)
504505
self.state.slots[signal_index].is_comb = True
505506
emitter.append(f"next_{signal_index} = {signal.init}")
@@ -533,7 +534,7 @@ def __call__(self, fragment):
533534
if domain.async_reset and domain.rst is not None:
534535
self.state.add_signal_waker(domain.rst, edge_waker(domain_process, 1))
535536

536-
for signal in domain_signals:
537+
for (signal, _) in lhs_masks.masks():
537538
signal_index = self.state.get_signal(signal)
538539
emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
539540

@@ -546,7 +547,7 @@ def __call__(self, fragment):
546547
emitter.append(f"if {rst}:")
547548
with emitter.indent():
548549
emitter.append("pass")
549-
for signal in domain_signals:
550+
for (signal, _) in lhs_masks.masks():
550551
if not signal.reset_less:
551552
signal_index = self.state.get_signal(signal)
552553
emitter.append(f"next_{signal_index} = {signal.init}")
@@ -592,9 +593,11 @@ def __call__(self, fragment):
592593

593594
lhs(port._data)(data)
594595

595-
for signal in domain_signals:
596+
for (signal, mask) in lhs_masks.masks():
597+
if signal.shape().signed and (mask & 1 << (len(signal) - 1)):
598+
mask |= -1 << len(signal)
596599
signal_index = self.state.get_signal(signal)
597-
emitter.append(f"slots[{signal_index}].update(next_{signal_index})")
600+
emitter.append(f"slots[{signal_index}].update(next_{signal_index}, {mask})")
598601

599602
# There shouldn't be any exceptions raised by the generated code, but if there are
600603
# (almost certainly due to a bug in the code generator), use this environment variable

amaranth/sim/pysim.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,8 @@ def add_waker(self, waker):
369369
assert waker not in self.wakers
370370
self.wakers.append(waker)
371371

372-
def update(self, value):
372+
def update(self, value, mask=~0):
373+
value = (self.next & ~mask) | (value & mask)
373374
if self.next != value:
374375
self.next = value
375376
self.pending.add(self)

tests/test_hdl_xfrm.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def setUp(self):
227227
self.s1 = Signal()
228228
self.s2 = Signal(init=1)
229229
self.s3 = Signal(init=1, reset_less=True)
230+
self.s4 = Signal(8, init=0x3a)
230231
self.c1 = Signal()
231232

232233
def test_reset_default(self):
@@ -281,6 +282,40 @@ def test_reset_value(self):
281282
)
282283
""")
283284

285+
def test_reset_mask(self):
286+
f = Fragment()
287+
f.add_statements("sync", self.s4[2:4].eq(0))
288+
289+
f = ResetInserter(self.c1)(f)
290+
self.assertRepr(f.statements["sync"], """
291+
(
292+
(eq (slice (sig s4) 2:4) (const 1'd0))
293+
(switch (sig c1)
294+
(case 1 (eq (slice (sig s4) 2:4) (slice (const 8'd58) 2:4)))
295+
)
296+
)
297+
""")
298+
299+
f = Fragment()
300+
f.add_statements("sync", self.s4[2:4].eq(0))
301+
f.add_statements("sync", self.s4[3:5].eq(0))
302+
f.add_statements("sync", self.s4[6:10].eq(0))
303+
304+
f = ResetInserter(self.c1)(f)
305+
self.assertRepr(f.statements["sync"], """
306+
(
307+
(eq (slice (sig s4) 2:4) (const 1'd0))
308+
(eq (slice (sig s4) 3:5) (const 1'd0))
309+
(eq (slice (sig s4) 6:8) (const 1'd0))
310+
(switch (sig c1)
311+
(case 1
312+
(eq (slice (sig s4) 2:5) (slice (const 8'd58) 2:5))
313+
(eq (slice (sig s4) 6:8) (slice (const 8'd58) 6:8))
314+
)
315+
)
316+
)
317+
""")
318+
284319
def test_reset_less(self):
285320
f = Fragment()
286321
f.add_statements("sync", self.s3.eq(0))
@@ -423,3 +458,31 @@ def test_composition(self):
423458
)
424459
)
425460
""")
461+
462+
class LHSMaskCollectorTestCase(FHDLTestCase):
463+
def test_slice(self):
464+
s = Signal(8)
465+
lhs = LHSMaskCollector()
466+
lhs.visit_value(s[2:5], ~0)
467+
self.assertEqual(lhs.lhs[s], 0x1c)
468+
469+
def test_slice_slice(self):
470+
s = Signal(8)
471+
lhs = LHSMaskCollector()
472+
lhs.visit_value(s[2:7][1:3], ~0)
473+
self.assertEqual(lhs.lhs[s], 0x18)
474+
475+
def test_slice_concat(self):
476+
s1 = Signal(8)
477+
s2 = Signal(8)
478+
lhs = LHSMaskCollector()
479+
lhs.visit_value(Cat(s1, s2)[4:11], ~0)
480+
self.assertEqual(lhs.lhs[s1], 0xf0)
481+
self.assertEqual(lhs.lhs[s2], 0x07)
482+
483+
def test_slice_part(self):
484+
s = Signal(8)
485+
idx = Signal(8)
486+
lhs = LHSMaskCollector()
487+
lhs.visit_value(s.bit_select(idx, 5)[1:3], ~0)
488+
self.assertEqual(lhs.lhs[s], 0xff)

tests/test_sim.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,28 @@ async def testbench(ctx):
14021402
with self.assertSimulation(Module(), traces=[mem1, mem2, mem3]) as sim:
14031403
sim.add_testbench(testbench)
14041404

1405+
def test_multiple_modules(self):
1406+
m = Module()
1407+
m.submodules.m1 = m1 = Module()
1408+
m.submodules.m2 = m2 = Module()
1409+
a = Signal(8)
1410+
b = Signal(8)
1411+
m1.d.comb += b[0:2].eq(a[0:2])
1412+
m1.d.comb += b[4:6].eq(a[4:6])
1413+
m2.d.comb += b[2:4].eq(a[2:4])
1414+
m2.d.comb += b[6:8].eq(a[6:8])
1415+
with self.assertSimulation(m) as sim:
1416+
async def testbench(ctx):
1417+
ctx.set(a, 0)
1418+
self.assertEqual(ctx.get(b), 0)
1419+
ctx.set(a, 0x12)
1420+
self.assertEqual(ctx.get(b), 0x12)
1421+
ctx.set(a, 0x34)
1422+
self.assertEqual(ctx.get(b), 0x34)
1423+
ctx.set(a, 0xdb)
1424+
self.assertEqual(ctx.get(b), 0xdb)
1425+
sim.add_testbench(testbench)
1426+
14051427

14061428
class SimulatorTracesTestCase(FHDLTestCase):
14071429
def assertDef(self, traces, flat_traces):

0 commit comments

Comments
 (0)