Skip to content

Commit a7c649e

Browse files
committed
hdl._dsl: Change FSM codegen to avoid mutating AST nodes.
Fixes #1066.
1 parent f524dd0 commit a7c649e

File tree

3 files changed

+108
-42
lines changed

3 files changed

+108
-42
lines changed

amaranth/hdl/_ast.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2304,6 +2304,11 @@ def Cover(test, *, name=None, src_loc_at=0):
23042304
return Property("cover", test, name=name, src_loc_at=src_loc_at+1)
23052305

23062306

2307+
class _LateBoundStatement(Statement):
2308+
def resolve(self):
2309+
raise NotImplementedError # :nocov:
2310+
2311+
23072312
@final
23082313
class Switch(Statement):
23092314
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):

amaranth/hdl/_dsl.py

Lines changed: 78 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..utils import bits_for
1010
from .. import tracer
1111
from ._ast import *
12+
from ._ast import _StatementList, _LateBoundStatement, Property
1213
from ._ir import *
1314
from ._cd import *
1415
from ._xfrm import *
@@ -146,16 +147,51 @@ def helper(*args, **kwds):
146147
return decorator
147148

148149

150+
class FSMNextStatement(_LateBoundStatement):
151+
def __init__(self, ctrl_data, state, *, src_loc_at=0):
152+
self.ctrl_data = ctrl_data
153+
self.state = state
154+
super().__init__(src_loc_at=1 + src_loc_at)
155+
156+
def resolve(self):
157+
return self.ctrl_data["signal"].eq(self.ctrl_data["encoding"][self.state])
158+
159+
149160
class FSM:
150-
def __init__(self, state, encoding, decoding):
151-
self.state = state
152-
self.encoding = encoding
153-
self.decoding = decoding
161+
def __init__(self, data):
162+
self._data = data
163+
self.encoding = data["encoding"]
164+
self.decoding = data["decoding"]
154165

155166
def ongoing(self, name):
156167
if name not in self.encoding:
157168
self.encoding[name] = len(self.encoding)
158-
return Operator("==", [self.state, self.encoding[name]], src_loc_at=0)
169+
fsm_name = self._data["name"]
170+
self._data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
171+
return self._data["ongoing"][name]
172+
173+
174+
def resolve_statement(stmt):
175+
if isinstance(stmt, _LateBoundStatement):
176+
return resolve_statement(stmt.resolve())
177+
elif isinstance(stmt, Switch):
178+
return Switch(
179+
test=stmt.test,
180+
cases=OrderedDict(
181+
(patterns, resolve_statements(stmts))
182+
for patterns, stmts in stmt.cases.items()
183+
),
184+
src_loc=stmt.src_loc,
185+
case_src_locs=stmt.case_src_locs,
186+
)
187+
elif isinstance(stmt, (Assign, Property)):
188+
return stmt
189+
else:
190+
assert False # :nocov:
191+
192+
193+
def resolve_statements(stmts):
194+
return _StatementList(resolve_statement(stmt) for stmt in stmts)
159195

160196

161197
class Module(_ModuleBuilderRoot, Elaboratable):
@@ -172,6 +208,7 @@ def __init__(self):
172208
self._statements = {}
173209
self._ctrl_context = None
174210
self._ctrl_stack = []
211+
self._top_comb_statements = _StatementList()
175212

176213
self._driving = SignalDict()
177214
self._named_submodules = {}
@@ -391,17 +428,16 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
391428
init = reset
392429
fsm_data = self._set_ctrl("FSM", {
393430
"name": name,
394-
"signal": Signal(name=f"{name}_state", src_loc_at=2),
395431
"init": init,
396432
"domain": domain,
397433
"encoding": OrderedDict(),
398434
"decoding": OrderedDict(),
435+
"ongoing": {},
399436
"states": OrderedDict(),
400437
"src_loc": tracer.get_src_loc(src_loc_at=1),
401438
"state_src_locs": {},
402439
})
403-
self._generated[name] = fsm = \
404-
FSM(fsm_data["signal"], fsm_data["encoding"], fsm_data["decoding"])
440+
self._generated[name] = fsm = FSM(fsm_data)
405441
try:
406442
self._ctrl_context = "FSM"
407443
self.domain._depth += 1
@@ -414,6 +450,7 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
414450
self.domain._depth -= 1
415451
self._ctrl_context = None
416452
self._pop_ctrl()
453+
fsm.state = fsm_data["signal"]
417454

418455
@contextmanager
419456
def State(self, name):
@@ -423,7 +460,9 @@ def State(self, name):
423460
if name in fsm_data["states"]:
424461
raise NameError(f"FSM state '{name}' is already defined")
425462
if name not in fsm_data["encoding"]:
463+
fsm_name = fsm_data["name"]
426464
fsm_data["encoding"][name] = len(fsm_data["encoding"])
465+
fsm_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
427466
try:
428467
_outer_case, self._statements = self._statements, {}
429468
self._ctrl_context = None
@@ -445,9 +484,11 @@ def next(self, name):
445484
for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)):
446485
if ctrl_name == "FSM":
447486
if name not in ctrl_data["encoding"]:
487+
fsm_name = ctrl_data["name"]
448488
ctrl_data["encoding"][name] = len(ctrl_data["encoding"])
489+
ctrl_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
449490
self._add_statement(
450-
assigns=[ctrl_data["signal"].eq(ctrl_data["encoding"][name])],
491+
assigns=[FSMNextStatement(ctrl_data, name)],
451492
domain=ctrl_data["domain"],
452493
depth=len(self._ctrl_stack))
453494
return
@@ -500,19 +541,25 @@ def _pop_ctrl(self):
500541
src_loc=src_loc, case_src_locs=switch_case_src_locs))
501542

502543
if name == "FSM":
503-
fsm_signal, fsm_init, fsm_encoding, fsm_decoding, fsm_states = \
504-
data["signal"], data["init"], data["encoding"], data["decoding"], data["states"]
544+
fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
545+
data["name"], data["init"], data["encoding"], data["decoding"], data["states"], data["ongoing"]
505546
fsm_state_src_locs = data["state_src_locs"]
506547
if not fsm_states:
548+
data["signal"] = Signal(0, name=f"{fsm_name}_state", src_loc_at=2)
507549
return
508-
fsm_signal.width = bits_for(len(fsm_encoding) - 1)
509550
if fsm_init is None:
510-
fsm_signal.init = fsm_encoding[next(iter(fsm_states))]
551+
init = fsm_encoding[next(iter(fsm_states))]
511552
else:
512-
fsm_signal.init = fsm_encoding[fsm_init]
553+
init = fsm_encoding[fsm_init]
513554
# The FSM is encoded such that the state with encoding 0 is always the init state.
514555
fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
515-
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
556+
data["signal"] = fsm_signal = Signal(range(len(fsm_encoding)), init=init,
557+
name=f"{fsm_name}_state", src_loc_at=2,
558+
decoder=lambda n: f"{fsm_decoding[n]}/{n}")
559+
560+
for name, sig in fsm_ongoing.items():
561+
self._top_comb_statements.append(
562+
sig.eq(Operator("==", [fsm_signal, fsm_encoding[name]], src_loc_at=0)))
516563

517564
domains = set()
518565
for stmts in fsm_states.values():
@@ -533,20 +580,21 @@ def _add_statement(self, assigns, domain, depth):
533580
self._pop_ctrl()
534581

535582
for stmt in Statement.cast(assigns):
536-
if not isinstance(stmt, (Assign, Property)):
583+
if not isinstance(stmt, (Assign, Property, _LateBoundStatement)):
537584
raise SyntaxError(
538585
f"Only assignments and property checks may be appended to d.{domain}")
539586

540587
stmt._MustUse__used = True
541588

542-
for signal in stmt._lhs_signals():
543-
if signal not in self._driving:
544-
self._driving[signal] = domain
545-
elif self._driving[signal] != domain:
546-
cd_curr = self._driving[signal]
547-
raise SyntaxError(
548-
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
549-
f"already driven from d.{cd_curr}")
589+
if isinstance(stmt, Assign):
590+
for signal in stmt._lhs_signals():
591+
if signal not in self._driving:
592+
self._driving[signal] = domain
593+
elif self._driving[signal] != domain:
594+
cd_curr = self._driving[signal]
595+
raise SyntaxError(
596+
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
597+
f"already driven from d.{cd_curr}")
550598

551599
self._statements.setdefault(domain, []).append(stmt)
552600

@@ -586,9 +634,13 @@ def elaborate(self, platform):
586634
for submodule, src_loc in self._anon_submodules:
587635
fragment.add_subfragment(Fragment.get(submodule, platform), None, src_loc=src_loc)
588636
for domain, statements in self._statements.items():
637+
statements = resolve_statements(statements)
589638
fragment.add_statements(domain, statements)
590-
for signal, domain in self._driving.items():
591-
fragment.add_driver(signal, domain)
639+
for signal in statements._lhs_signals():
640+
fragment.add_driver(signal, domain)
641+
fragment.add_statements("comb", self._top_comb_statements)
642+
for signal in self._top_comb_statements._lhs_signals():
643+
fragment.add_driver(signal, "comb")
592644
fragment.add_domains(self._domains.values())
593645
fragment.generated.update(self._generated)
594646
return fragment

tests/test_hdl_dsl.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -593,18 +593,21 @@ def test_FSM_basic(self):
593593
m.d.sync += b.eq(~b)
594594
with m.If(c):
595595
m.next = "FIRST"
596-
m._flush()
597-
self.assertRepr(m._statements["comb"], """
596+
597+
frag = m.elaborate(platform=None)
598+
self.assertRepr(frag.statements["comb"], """
598599
(
599600
(switch (sig fsm_state)
600601
(case 0
601602
(eq (sig a) (const 1'd1))
602603
)
603604
(case 1 )
604605
)
606+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
607+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
605608
)
606609
""")
607-
self.assertRepr(m._statements["sync"], """
610+
self.assertRepr(frag.statements["sync"], """
608611
(
609612
(switch (sig fsm_state)
610613
(case 0
@@ -620,13 +623,13 @@ def test_FSM_basic(self):
620623
)
621624
)
622625
""")
623-
self.assertEqual({repr(k): v for k, v in m._driving.items()}, {
626+
self.assertEqual({repr(sig): k for k, v in frag.drivers.items() for sig in v}, {
624627
"(sig a)": "comb",
625628
"(sig fsm_state)": "sync",
626629
"(sig b)": "sync",
630+
"(sig fsm_ongoing_FIRST)": "comb",
631+
"(sig fsm_ongoing_SECOND)": "comb",
627632
})
628-
629-
frag = m.elaborate(platform=None)
630633
fsm = frag.find_generated("fsm")
631634
self.assertIsInstance(fsm.state, Signal)
632635
self.assertEqual(fsm.encoding, OrderedDict({
@@ -647,18 +650,20 @@ def test_FSM_init(self):
647650
m.next = "SECOND"
648651
with m.State("SECOND"):
649652
m.next = "FIRST"
650-
m._flush()
651-
self.assertRepr(m._statements["comb"], """
653+
frag = m.elaborate(platform=None)
654+
self.assertRepr(frag.statements["comb"], """
652655
(
653656
(switch (sig fsm_state)
654657
(case 0
655658
(eq (sig a) (const 1'd0))
656659
)
657660
(case 1 )
658661
)
662+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
663+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
659664
)
660665
""")
661-
self.assertRepr(m._statements["sync"], """
666+
self.assertRepr(frag.statements["sync"], """
662667
(
663668
(switch (sig fsm_state)
664669
(case 0
@@ -683,18 +688,20 @@ def test_FSM_reset(self):
683688
m.next = "SECOND"
684689
with m.State("SECOND"):
685690
m.next = "FIRST"
686-
m._flush()
687-
self.assertRepr(m._statements["comb"], """
691+
frag = m.elaborate(platform=None)
692+
self.assertRepr(frag.statements["comb"], """
688693
(
689694
(switch (sig fsm_state)
690695
(case 0
691696
(eq (sig a) (const 1'd0))
692697
)
693698
(case 1 )
694699
)
700+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
701+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
695702
)
696703
""")
697-
self.assertRepr(m._statements["sync"], """
704+
self.assertRepr(frag.statements["sync"], """
698705
(
699706
(switch (sig fsm_state)
700707
(case 0
@@ -731,13 +738,15 @@ def test_FSM_ongoing(self):
731738
m.d.comb += a.eq(fsm.ongoing("FIRST"))
732739
with m.State("SECOND"):
733740
pass
734-
m._flush()
741+
frag = m.elaborate(platform=None)
735742
self.assertEqual(m._generated["fsm"].state.init, 1)
736743
self.maxDiff = 10000
737-
self.assertRepr(m._statements["comb"], """
744+
self.assertRepr(frag.statements["comb"], """
738745
(
739-
(eq (sig b) (== (sig fsm_state) (const 1'd0)))
740-
(eq (sig a) (== (sig fsm_state) (const 1'd1)))
746+
(eq (sig b) (sig fsm_ongoing_SECOND))
747+
(eq (sig a) (sig fsm_ongoing_FIRST))
748+
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd0)))
749+
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd1)))
741750
)
742751
""")
743752

0 commit comments

Comments
 (0)