Skip to content

Commit 30fc02f

Browse files
committed
hdl._dsl: Change FSM codegen to avoid mutating AST nodes.
Fixes #1066.
1 parent f524dd0 commit 30fc02f

File tree

3 files changed

+107
-42
lines changed

3 files changed

+107
-42
lines changed

amaranth/hdl/_ast.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2304,6 +2304,11 @@ def Cover(test, *, name=None, src_loc_at=0):
23042304
return Property("cover", test, name=name, src_loc_at=src_loc_at+1)
23052305

23062306

2307+
class _LateBoundStatement(Statement):
2308+
def resolve(self):
2309+
raise NotImplementedError
2310+
2311+
23072312
@final
23082313
class Switch(Statement):
23092314
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):

amaranth/hdl/_dsl.py

Lines changed: 77 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..utils import bits_for
1010
from .. import tracer
1111
from ._ast import *
12+
from ._ast import _StatementList, _LateBoundStatement, Property
1213
from ._ir import *
1314
from ._cd import *
1415
from ._xfrm import *
@@ -146,16 +147,50 @@ def helper(*args, **kwds):
146147
return decorator
147148

148149

150+
class FSMNextStatement(_LateBoundStatement):
151+
def __init__(self, ctrl_data, state, src_loc_at=0):
152+
self.ctrl_data = ctrl_data
153+
self.state = state
154+
super().__init__(src_loc_at=1 + src_loc_at)
155+
156+
def resolve(self):
157+
return self.ctrl_data["signal"].eq(self.ctrl_data["encoding"][self.state])
158+
159+
149160
class FSM:
150-
def __init__(self, state, encoding, decoding):
151-
self.state = state
152-
self.encoding = encoding
153-
self.decoding = decoding
161+
def __init__(self, data):
162+
self._data = data
163+
self.encoding = data["encoding"]
164+
self.decoding = data["decoding"]
154165

155166
def ongoing(self, name):
156167
if name not in self.encoding:
157168
self.encoding[name] = len(self.encoding)
158-
return Operator("==", [self.state, self.encoding[name]], src_loc_at=0)
169+
fsm_name = self._data["name"]
170+
self._data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
171+
return self._data["ongoing"][name]
172+
173+
174+
def resolve_statement(stmt):
175+
if isinstance(stmt, _LateBoundStatement):
176+
return resolve_statement(stmt.resolve())
177+
elif isinstance(stmt, Switch):
178+
return Switch(
179+
test=stmt.test,
180+
cases=OrderedDict(
181+
(patterns, resolve_statements(stmts))
182+
for patterns, stmts in stmt.cases.items()
183+
),
184+
src_loc=stmt.src_loc,
185+
case_src_locs=stmt.case_src_locs,
186+
)
187+
elif isinstance(stmt, (Assign, Property)):
188+
return stmt
189+
else:
190+
assert False # nocov
191+
192+
def resolve_statements(stmts):
193+
return _StatementList(resolve_statement(stmt) for stmt in stmts)
159194

160195

161196
class Module(_ModuleBuilderRoot, Elaboratable):
@@ -172,6 +207,7 @@ def __init__(self):
172207
self._statements = {}
173208
self._ctrl_context = None
174209
self._ctrl_stack = []
210+
self._top_comb_statements = _StatementList()
175211

176212
self._driving = SignalDict()
177213
self._named_submodules = {}
@@ -391,17 +427,16 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
391427
init = reset
392428
fsm_data = self._set_ctrl("FSM", {
393429
"name": name,
394-
"signal": Signal(name=f"{name}_state", src_loc_at=2),
395430
"init": init,
396431
"domain": domain,
397432
"encoding": OrderedDict(),
398433
"decoding": OrderedDict(),
434+
"ongoing": {},
399435
"states": OrderedDict(),
400436
"src_loc": tracer.get_src_loc(src_loc_at=1),
401437
"state_src_locs": {},
402438
})
403-
self._generated[name] = fsm = \
404-
FSM(fsm_data["signal"], fsm_data["encoding"], fsm_data["decoding"])
439+
self._generated[name] = fsm = FSM(fsm_data)
405440
try:
406441
self._ctrl_context = "FSM"
407442
self.domain._depth += 1
@@ -414,6 +449,7 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
414449
self.domain._depth -= 1
415450
self._ctrl_context = None
416451
self._pop_ctrl()
452+
fsm.state = fsm_data["signal"]
417453

418454
@contextmanager
419455
def State(self, name):
@@ -423,7 +459,9 @@ def State(self, name):
423459
if name in fsm_data["states"]:
424460
raise NameError(f"FSM state '{name}' is already defined")
425461
if name not in fsm_data["encoding"]:
462+
fsm_name = fsm_data["name"]
426463
fsm_data["encoding"][name] = len(fsm_data["encoding"])
464+
fsm_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
427465
try:
428466
_outer_case, self._statements = self._statements, {}
429467
self._ctrl_context = None
@@ -445,9 +483,11 @@ def next(self, name):
445483
for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)):
446484
if ctrl_name == "FSM":
447485
if name not in ctrl_data["encoding"]:
486+
fsm_name = ctrl_data["name"]
448487
ctrl_data["encoding"][name] = len(ctrl_data["encoding"])
488+
ctrl_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
449489
self._add_statement(
450-
assigns=[ctrl_data["signal"].eq(ctrl_data["encoding"][name])],
490+
assigns=[FSMNextStatement(ctrl_data, name)],
451491
domain=ctrl_data["domain"],
452492
depth=len(self._ctrl_stack))
453493
return
@@ -500,19 +540,25 @@ def _pop_ctrl(self):
500540
src_loc=src_loc, case_src_locs=switch_case_src_locs))
501541

502542
if name == "FSM":
503-
fsm_signal, fsm_init, fsm_encoding, fsm_decoding, fsm_states = \
504-
data["signal"], data["init"], data["encoding"], data["decoding"], data["states"]
543+
fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
544+
data["name"], data["init"], data["encoding"], data["decoding"], data["states"], data["ongoing"]
505545
fsm_state_src_locs = data["state_src_locs"]
506546
if not fsm_states:
547+
data["signal"] = Signal(0, name=f"{fsm_name}_state", src_loc_at=2)
507548
return
508-
fsm_signal.width = bits_for(len(fsm_encoding) - 1)
509549
if fsm_init is None:
510-
fsm_signal.init = fsm_encoding[next(iter(fsm_states))]
550+
init = fsm_encoding[next(iter(fsm_states))]
511551
else:
512-
fsm_signal.init = fsm_encoding[fsm_init]
552+
init = fsm_encoding[fsm_init]
513553
# The FSM is encoded such that the state with encoding 0 is always the init state.
514554
fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
515-
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
555+
data["signal"] = fsm_signal = Signal(range(len(fsm_encoding)), init=init,
556+
name=f"{fsm_name}_state", src_loc_at=2,
557+
decoder=lambda n: f"{fsm_decoding[n]}/{n}")
558+
559+
for name, sig in fsm_ongoing.items():
560+
self._top_comb_statements.append(
561+
sig.eq(Operator("==", [fsm_signal, fsm_encoding[name]], src_loc_at=0)))
516562

517563
domains = set()
518564
for stmts in fsm_states.values():
@@ -533,20 +579,21 @@ def _add_statement(self, assigns, domain, depth):
533579
self._pop_ctrl()
534580

535581
for stmt in Statement.cast(assigns):
536-
if not isinstance(stmt, (Assign, Property)):
582+
if not isinstance(stmt, (Assign, Property, _LateBoundStatement)):
537583
raise SyntaxError(
538584
f"Only assignments and property checks may be appended to d.{domain}")
539585

540586
stmt._MustUse__used = True
541587

542-
for signal in stmt._lhs_signals():
543-
if signal not in self._driving:
544-
self._driving[signal] = domain
545-
elif self._driving[signal] != domain:
546-
cd_curr = self._driving[signal]
547-
raise SyntaxError(
548-
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
549-
f"already driven from d.{cd_curr}")
588+
if isinstance(stmt, Assign):
589+
for signal in stmt._lhs_signals():
590+
if signal not in self._driving:
591+
self._driving[signal] = domain
592+
elif self._driving[signal] != domain:
593+
cd_curr = self._driving[signal]
594+
raise SyntaxError(
595+
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
596+
f"already driven from d.{cd_curr}")
550597

551598
self._statements.setdefault(domain, []).append(stmt)
552599

@@ -586,9 +633,13 @@ def elaborate(self, platform):
586633
for submodule, src_loc in self._anon_submodules:
587634
fragment.add_subfragment(Fragment.get(submodule, platform), None, src_loc=src_loc)
588635
for domain, statements in self._statements.items():
636+
statements = resolve_statements(statements)
589637
fragment.add_statements(domain, statements)
590-
for signal, domain in self._driving.items():
591-
fragment.add_driver(signal, domain)
638+
for signal in statements._lhs_signals():
639+
fragment.add_driver(signal, domain)
640+
fragment.add_statements("comb", self._top_comb_statements)
641+
for signal in self._top_comb_statements._lhs_signals():
642+
fragment.add_driver(signal, "comb")
592643
fragment.add_domains(self._domains.values())
593644
fragment.generated.update(self._generated)
594645
return fragment

tests/test_hdl_dsl.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -593,18 +593,21 @@ def test_FSM_basic(self):
593593
m.d.sync += b.eq(~b)
594594
with m.If(c):
595595
m.next = "FIRST"
596-
m._flush()
597-
self.assertRepr(m._statements["comb"], """
596+
597+
frag = m.elaborate(platform=None)
598+
self.assertRepr(frag.statements["comb"], """
598599
(
599600
(switch (sig fsm_state)
600601
(case 0
601602
(eq (sig a) (const 1'd1))
602603
)
603604
(case 1 )
604605
)
606+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
607+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
605608
)
606609
""")
607-
self.assertRepr(m._statements["sync"], """
610+
self.assertRepr(frag.statements["sync"], """
608611
(
609612
(switch (sig fsm_state)
610613
(case 0
@@ -620,13 +623,13 @@ def test_FSM_basic(self):
620623
)
621624
)
622625
""")
623-
self.assertEqual({repr(k): v for k, v in m._driving.items()}, {
626+
self.assertEqual({repr(sig): k for k, v in frag.drivers.items() for sig in v}, {
624627
"(sig a)": "comb",
625628
"(sig fsm_state)": "sync",
626629
"(sig b)": "sync",
630+
"(sig fsm_ongoing_FIRST)": "comb",
631+
"(sig fsm_ongoing_SECOND)": "comb",
627632
})
628-
629-
frag = m.elaborate(platform=None)
630633
fsm = frag.find_generated("fsm")
631634
self.assertIsInstance(fsm.state, Signal)
632635
self.assertEqual(fsm.encoding, OrderedDict({
@@ -647,18 +650,20 @@ def test_FSM_init(self):
647650
m.next = "SECOND"
648651
with m.State("SECOND"):
649652
m.next = "FIRST"
650-
m._flush()
651-
self.assertRepr(m._statements["comb"], """
653+
frag = m.elaborate(platform=None)
654+
self.assertRepr(frag.statements["comb"], """
652655
(
653656
(switch (sig fsm_state)
654657
(case 0
655658
(eq (sig a) (const 1'd0))
656659
)
657660
(case 1 )
658661
)
662+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
663+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
659664
)
660665
""")
661-
self.assertRepr(m._statements["sync"], """
666+
self.assertRepr(frag.statements["sync"], """
662667
(
663668
(switch (sig fsm_state)
664669
(case 0
@@ -683,18 +688,20 @@ def test_FSM_reset(self):
683688
m.next = "SECOND"
684689
with m.State("SECOND"):
685690
m.next = "FIRST"
686-
m._flush()
687-
self.assertRepr(m._statements["comb"], """
691+
frag = m.elaborate(platform=None)
692+
self.assertRepr(frag.statements["comb"], """
688693
(
689694
(switch (sig fsm_state)
690695
(case 0
691696
(eq (sig a) (const 1'd0))
692697
)
693698
(case 1 )
694699
)
700+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
701+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
695702
)
696703
""")
697-
self.assertRepr(m._statements["sync"], """
704+
self.assertRepr(frag.statements["sync"], """
698705
(
699706
(switch (sig fsm_state)
700707
(case 0
@@ -731,13 +738,15 @@ def test_FSM_ongoing(self):
731738
m.d.comb += a.eq(fsm.ongoing("FIRST"))
732739
with m.State("SECOND"):
733740
pass
734-
m._flush()
741+
frag = m.elaborate(platform=None)
735742
self.assertEqual(m._generated["fsm"].state.init, 1)
736743
self.maxDiff = 10000
737-
self.assertRepr(m._statements["comb"], """
744+
self.assertRepr(frag.statements["comb"], """
738745
(
739-
(eq (sig b) (== (sig fsm_state) (const 1'd0)))
740-
(eq (sig a) (== (sig fsm_state) (const 1'd1)))
746+
(eq (sig b) (sig fsm_ongoing_SECOND))
747+
(eq (sig a) (sig fsm_ongoing_FIRST))
748+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd0)))
749+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd1)))
741750
)
742751
""")
743752

0 commit comments

Comments
 (0)