From a7c649e9d315d91d8a8380bb7ed17cfcf780a08d Mon Sep 17 00:00:00 2001 From: Wanda Date: Tue, 27 Feb 2024 15:03:48 +0100 Subject: [PATCH] hdl._dsl: Change FSM codegen to avoid mutating AST nodes. Fixes #1066. --- amaranth/hdl/_ast.py | 5 ++ amaranth/hdl/_dsl.py | 104 +++++++++++++++++++++++++++++++----------- tests/test_hdl_dsl.py | 41 ++++++++++------- 3 files changed, 108 insertions(+), 42 deletions(-) diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index cee94b8cc..c9353f759 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -2304,6 +2304,11 @@ def Cover(test, *, name=None, src_loc_at=0): return Property("cover", test, name=name, src_loc_at=src_loc_at+1) +class _LateBoundStatement(Statement): + def resolve(self): + raise NotImplementedError # :nocov: + + @final class Switch(Statement): def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}): diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index 0201ea7c0..506ba8df2 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -9,6 +9,7 @@ from ..utils import bits_for from .. import tracer from ._ast import * +from ._ast import _StatementList, _LateBoundStatement, Property from ._ir import * from ._cd import * from ._xfrm import * @@ -146,16 +147,51 @@ def helper(*args, **kwds): return decorator +class FSMNextStatement(_LateBoundStatement): + def __init__(self, ctrl_data, state, *, src_loc_at=0): + self.ctrl_data = ctrl_data + self.state = state + super().__init__(src_loc_at=1 + src_loc_at) + + def resolve(self): + return self.ctrl_data["signal"].eq(self.ctrl_data["encoding"][self.state]) + + class FSM: - def __init__(self, state, encoding, decoding): - self.state = state - self.encoding = encoding - self.decoding = decoding + def __init__(self, data): + self._data = data + self.encoding = data["encoding"] + self.decoding = data["decoding"] def ongoing(self, name): if name not in self.encoding: self.encoding[name] = len(self.encoding) - return Operator("==", [self.state, self.encoding[name]], src_loc_at=0) + fsm_name = self._data["name"] + self._data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}") + return self._data["ongoing"][name] + + +def resolve_statement(stmt): + if isinstance(stmt, _LateBoundStatement): + return resolve_statement(stmt.resolve()) + elif isinstance(stmt, Switch): + return Switch( + test=stmt.test, + cases=OrderedDict( + (patterns, resolve_statements(stmts)) + for patterns, stmts in stmt.cases.items() + ), + src_loc=stmt.src_loc, + case_src_locs=stmt.case_src_locs, + ) + elif isinstance(stmt, (Assign, Property)): + return stmt + else: + assert False # :nocov: + + +def resolve_statements(stmts): + return _StatementList(resolve_statement(stmt) for stmt in stmts) class Module(_ModuleBuilderRoot, Elaboratable): @@ -172,6 +208,7 @@ def __init__(self): self._statements = {} self._ctrl_context = None self._ctrl_stack = [] + self._top_comb_statements = _StatementList() self._driving = SignalDict() self._named_submodules = {} @@ -391,17 +428,16 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None): init = reset fsm_data = self._set_ctrl("FSM", { "name": name, - "signal": Signal(name=f"{name}_state", src_loc_at=2), "init": init, "domain": domain, "encoding": OrderedDict(), "decoding": OrderedDict(), + "ongoing": {}, "states": OrderedDict(), "src_loc": tracer.get_src_loc(src_loc_at=1), "state_src_locs": {}, }) - self._generated[name] = fsm = \ - FSM(fsm_data["signal"], fsm_data["encoding"], fsm_data["decoding"]) + self._generated[name] = fsm = FSM(fsm_data) try: self._ctrl_context = "FSM" self.domain._depth += 1 @@ -414,6 +450,7 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None): self.domain._depth -= 1 self._ctrl_context = None self._pop_ctrl() + fsm.state = fsm_data["signal"] @contextmanager def State(self, name): @@ -423,7 +460,9 @@ def State(self, name): if name in fsm_data["states"]: raise NameError(f"FSM state '{name}' is already defined") if name not in fsm_data["encoding"]: + fsm_name = fsm_data["name"] fsm_data["encoding"][name] = len(fsm_data["encoding"]) + fsm_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}") try: _outer_case, self._statements = self._statements, {} self._ctrl_context = None @@ -445,9 +484,11 @@ def next(self, name): for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)): if ctrl_name == "FSM": if name not in ctrl_data["encoding"]: + fsm_name = ctrl_data["name"] ctrl_data["encoding"][name] = len(ctrl_data["encoding"]) + ctrl_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}") self._add_statement( - assigns=[ctrl_data["signal"].eq(ctrl_data["encoding"][name])], + assigns=[FSMNextStatement(ctrl_data, name)], domain=ctrl_data["domain"], depth=len(self._ctrl_stack)) return @@ -500,19 +541,25 @@ def _pop_ctrl(self): src_loc=src_loc, case_src_locs=switch_case_src_locs)) if name == "FSM": - fsm_signal, fsm_init, fsm_encoding, fsm_decoding, fsm_states = \ - data["signal"], data["init"], data["encoding"], data["decoding"], data["states"] + fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \ + data["name"], data["init"], data["encoding"], data["decoding"], data["states"], data["ongoing"] fsm_state_src_locs = data["state_src_locs"] if not fsm_states: + data["signal"] = Signal(0, name=f"{fsm_name}_state", src_loc_at=2) return - fsm_signal.width = bits_for(len(fsm_encoding) - 1) if fsm_init is None: - fsm_signal.init = fsm_encoding[next(iter(fsm_states))] + init = fsm_encoding[next(iter(fsm_states))] else: - fsm_signal.init = fsm_encoding[fsm_init] + init = fsm_encoding[fsm_init] # The FSM is encoded such that the state with encoding 0 is always the init state. fsm_decoding.update((n, s) for s, n in fsm_encoding.items()) - fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}" + data["signal"] = fsm_signal = Signal(range(len(fsm_encoding)), init=init, + name=f"{fsm_name}_state", src_loc_at=2, + decoder=lambda n: f"{fsm_decoding[n]}/{n}") + + for name, sig in fsm_ongoing.items(): + self._top_comb_statements.append( + sig.eq(Operator("==", [fsm_signal, fsm_encoding[name]], src_loc_at=0))) domains = set() for stmts in fsm_states.values(): @@ -533,20 +580,21 @@ def _add_statement(self, assigns, domain, depth): self._pop_ctrl() for stmt in Statement.cast(assigns): - if not isinstance(stmt, (Assign, Property)): + if not isinstance(stmt, (Assign, Property, _LateBoundStatement)): raise SyntaxError( f"Only assignments and property checks may be appended to d.{domain}") stmt._MustUse__used = True - for signal in stmt._lhs_signals(): - if signal not in self._driving: - self._driving[signal] = domain - elif self._driving[signal] != domain: - cd_curr = self._driving[signal] - raise SyntaxError( - f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is " - f"already driven from d.{cd_curr}") + if isinstance(stmt, Assign): + for signal in stmt._lhs_signals(): + if signal not in self._driving: + self._driving[signal] = domain + elif self._driving[signal] != domain: + cd_curr = self._driving[signal] + raise SyntaxError( + f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is " + f"already driven from d.{cd_curr}") self._statements.setdefault(domain, []).append(stmt) @@ -586,9 +634,13 @@ def elaborate(self, platform): for submodule, src_loc in self._anon_submodules: fragment.add_subfragment(Fragment.get(submodule, platform), None, src_loc=src_loc) for domain, statements in self._statements.items(): + statements = resolve_statements(statements) fragment.add_statements(domain, statements) - for signal, domain in self._driving.items(): - fragment.add_driver(signal, domain) + for signal in statements._lhs_signals(): + fragment.add_driver(signal, domain) + fragment.add_statements("comb", self._top_comb_statements) + for signal in self._top_comb_statements._lhs_signals(): + fragment.add_driver(signal, "comb") fragment.add_domains(self._domains.values()) fragment.generated.update(self._generated) return fragment diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index 9d03c88a8..e4fc85174 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -593,8 +593,9 @@ def test_FSM_basic(self): m.d.sync += b.eq(~b) with m.If(c): m.next = "FIRST" - m._flush() - self.assertRepr(m._statements["comb"], """ + + frag = m.elaborate(platform=None) + self.assertRepr(frag.statements["comb"], """ ( (switch (sig fsm_state) (case 0 @@ -602,9 +603,11 @@ def test_FSM_basic(self): ) (case 1 ) ) + (eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0))) + (eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1))) ) """) - self.assertRepr(m._statements["sync"], """ + self.assertRepr(frag.statements["sync"], """ ( (switch (sig fsm_state) (case 0 @@ -620,13 +623,13 @@ def test_FSM_basic(self): ) ) """) - self.assertEqual({repr(k): v for k, v in m._driving.items()}, { + self.assertEqual({repr(sig): k for k, v in frag.drivers.items() for sig in v}, { "(sig a)": "comb", "(sig fsm_state)": "sync", "(sig b)": "sync", + "(sig fsm_ongoing_FIRST)": "comb", + "(sig fsm_ongoing_SECOND)": "comb", }) - - frag = m.elaborate(platform=None) fsm = frag.find_generated("fsm") self.assertIsInstance(fsm.state, Signal) self.assertEqual(fsm.encoding, OrderedDict({ @@ -647,8 +650,8 @@ def test_FSM_init(self): m.next = "SECOND" with m.State("SECOND"): m.next = "FIRST" - m._flush() - self.assertRepr(m._statements["comb"], """ + frag = m.elaborate(platform=None) + self.assertRepr(frag.statements["comb"], """ ( (switch (sig fsm_state) (case 0 @@ -656,9 +659,11 @@ def test_FSM_init(self): ) (case 1 ) ) + (eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0))) + (eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1))) ) """) - self.assertRepr(m._statements["sync"], """ + self.assertRepr(frag.statements["sync"], """ ( (switch (sig fsm_state) (case 0 @@ -683,8 +688,8 @@ def test_FSM_reset(self): m.next = "SECOND" with m.State("SECOND"): m.next = "FIRST" - m._flush() - self.assertRepr(m._statements["comb"], """ + frag = m.elaborate(platform=None) + self.assertRepr(frag.statements["comb"], """ ( (switch (sig fsm_state) (case 0 @@ -692,9 +697,11 @@ def test_FSM_reset(self): ) (case 1 ) ) + (eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0))) + (eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1))) ) """) - self.assertRepr(m._statements["sync"], """ + self.assertRepr(frag.statements["sync"], """ ( (switch (sig fsm_state) (case 0 @@ -731,13 +738,15 @@ def test_FSM_ongoing(self): m.d.comb += a.eq(fsm.ongoing("FIRST")) with m.State("SECOND"): pass - m._flush() + frag = m.elaborate(platform=None) self.assertEqual(m._generated["fsm"].state.init, 1) self.maxDiff = 10000 - self.assertRepr(m._statements["comb"], """ + self.assertRepr(frag.statements["comb"], """ ( - (eq (sig b) (== (sig fsm_state) (const 1'd0))) - (eq (sig a) (== (sig fsm_state) (const 1'd1))) + (eq (sig b) (sig fsm_ongoing_SECOND)) + (eq (sig a) (sig fsm_ongoing_FIRST)) + (eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd0))) + (eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd1))) ) """)