Skip to content

hdl._dsl: Change FSM codegen to avoid mutating AST nodes. #1164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions amaranth/hdl/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2304,6 +2304,11 @@ def Cover(test, *, name=None, src_loc_at=0):
return Property("cover", test, name=name, src_loc_at=src_loc_at+1)


class _LateBoundStatement(Statement):
def resolve(self):
raise NotImplementedError # :nocov:


@final
class Switch(Statement):
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):
Expand Down
104 changes: 78 additions & 26 deletions amaranth/hdl/_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..utils import bits_for
from .. import tracer
from ._ast import *
from ._ast import _StatementList, _LateBoundStatement, Property
from ._ir import *
from ._cd import *
from ._xfrm import *
Expand Down Expand Up @@ -146,16 +147,51 @@ def helper(*args, **kwds):
return decorator


class FSMNextStatement(_LateBoundStatement):
def __init__(self, ctrl_data, state, *, src_loc_at=0):
self.ctrl_data = ctrl_data
self.state = state
super().__init__(src_loc_at=1 + src_loc_at)

def resolve(self):
return self.ctrl_data["signal"].eq(self.ctrl_data["encoding"][self.state])


class FSM:
def __init__(self, state, encoding, decoding):
self.state = state
self.encoding = encoding
self.decoding = decoding
def __init__(self, data):
self._data = data
self.encoding = data["encoding"]
self.decoding = data["decoding"]

def ongoing(self, name):
if name not in self.encoding:
self.encoding[name] = len(self.encoding)
return Operator("==", [self.state, self.encoding[name]], src_loc_at=0)
fsm_name = self._data["name"]
self._data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
return self._data["ongoing"][name]


def resolve_statement(stmt):
if isinstance(stmt, _LateBoundStatement):
return resolve_statement(stmt.resolve())
elif isinstance(stmt, Switch):
return Switch(
test=stmt.test,
cases=OrderedDict(
(patterns, resolve_statements(stmts))
for patterns, stmts in stmt.cases.items()
),
src_loc=stmt.src_loc,
case_src_locs=stmt.case_src_locs,
)
elif isinstance(stmt, (Assign, Property)):
return stmt
else:
assert False # :nocov:


def resolve_statements(stmts):
return _StatementList(resolve_statement(stmt) for stmt in stmts)


class Module(_ModuleBuilderRoot, Elaboratable):
Expand All @@ -172,6 +208,7 @@ def __init__(self):
self._statements = {}
self._ctrl_context = None
self._ctrl_stack = []
self._top_comb_statements = _StatementList()

self._driving = SignalDict()
self._named_submodules = {}
Expand Down Expand Up @@ -391,17 +428,16 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
init = reset
fsm_data = self._set_ctrl("FSM", {
"name": name,
"signal": Signal(name=f"{name}_state", src_loc_at=2),
"init": init,
"domain": domain,
"encoding": OrderedDict(),
"decoding": OrderedDict(),
"ongoing": {},
"states": OrderedDict(),
"src_loc": tracer.get_src_loc(src_loc_at=1),
"state_src_locs": {},
})
self._generated[name] = fsm = \
FSM(fsm_data["signal"], fsm_data["encoding"], fsm_data["decoding"])
self._generated[name] = fsm = FSM(fsm_data)
try:
self._ctrl_context = "FSM"
self.domain._depth += 1
Expand All @@ -414,6 +450,7 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
self.domain._depth -= 1
self._ctrl_context = None
self._pop_ctrl()
fsm.state = fsm_data["signal"]

@contextmanager
def State(self, name):
Expand All @@ -423,7 +460,9 @@ def State(self, name):
if name in fsm_data["states"]:
raise NameError(f"FSM state '{name}' is already defined")
if name not in fsm_data["encoding"]:
fsm_name = fsm_data["name"]
fsm_data["encoding"][name] = len(fsm_data["encoding"])
fsm_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
try:
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
Expand All @@ -445,9 +484,11 @@ def next(self, name):
for level, (ctrl_name, ctrl_data) in enumerate(reversed(self._ctrl_stack)):
if ctrl_name == "FSM":
if name not in ctrl_data["encoding"]:
fsm_name = ctrl_data["name"]
ctrl_data["encoding"][name] = len(ctrl_data["encoding"])
ctrl_data["ongoing"][name] = Signal(name=f"{fsm_name}_ongoing_{name}")
self._add_statement(
assigns=[ctrl_data["signal"].eq(ctrl_data["encoding"][name])],
assigns=[FSMNextStatement(ctrl_data, name)],
domain=ctrl_data["domain"],
depth=len(self._ctrl_stack))
return
Expand Down Expand Up @@ -500,19 +541,25 @@ def _pop_ctrl(self):
src_loc=src_loc, case_src_locs=switch_case_src_locs))

if name == "FSM":
fsm_signal, fsm_init, fsm_encoding, fsm_decoding, fsm_states = \
data["signal"], data["init"], data["encoding"], data["decoding"], data["states"]
fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
data["name"], data["init"], data["encoding"], data["decoding"], data["states"], data["ongoing"]
fsm_state_src_locs = data["state_src_locs"]
if not fsm_states:
data["signal"] = Signal(0, name=f"{fsm_name}_state", src_loc_at=2)
return
fsm_signal.width = bits_for(len(fsm_encoding) - 1)
if fsm_init is None:
fsm_signal.init = fsm_encoding[next(iter(fsm_states))]
init = fsm_encoding[next(iter(fsm_states))]
else:
fsm_signal.init = fsm_encoding[fsm_init]
init = fsm_encoding[fsm_init]
# The FSM is encoded such that the state with encoding 0 is always the init state.
fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
data["signal"] = fsm_signal = Signal(range(len(fsm_encoding)), init=init,
name=f"{fsm_name}_state", src_loc_at=2,
decoder=lambda n: f"{fsm_decoding[n]}/{n}")

for name, sig in fsm_ongoing.items():
self._top_comb_statements.append(
sig.eq(Operator("==", [fsm_signal, fsm_encoding[name]], src_loc_at=0)))

domains = set()
for stmts in fsm_states.values():
Expand All @@ -533,20 +580,21 @@ def _add_statement(self, assigns, domain, depth):
self._pop_ctrl()

for stmt in Statement.cast(assigns):
if not isinstance(stmt, (Assign, Property)):
if not isinstance(stmt, (Assign, Property, _LateBoundStatement)):
raise SyntaxError(
f"Only assignments and property checks may be appended to d.{domain}")

stmt._MustUse__used = True

for signal in stmt._lhs_signals():
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")
if isinstance(stmt, Assign):
for signal in stmt._lhs_signals():
if signal not in self._driving:
self._driving[signal] = domain
elif self._driving[signal] != domain:
cd_curr = self._driving[signal]
raise SyntaxError(
f"Driver-driver conflict: trying to drive {signal!r} from d.{domain}, but it is "
f"already driven from d.{cd_curr}")

self._statements.setdefault(domain, []).append(stmt)

Expand Down Expand Up @@ -586,9 +634,13 @@ def elaborate(self, platform):
for submodule, src_loc in self._anon_submodules:
fragment.add_subfragment(Fragment.get(submodule, platform), None, src_loc=src_loc)
for domain, statements in self._statements.items():
statements = resolve_statements(statements)
fragment.add_statements(domain, statements)
for signal, domain in self._driving.items():
fragment.add_driver(signal, domain)
for signal in statements._lhs_signals():
fragment.add_driver(signal, domain)
fragment.add_statements("comb", self._top_comb_statements)
for signal in self._top_comb_statements._lhs_signals():
fragment.add_driver(signal, "comb")
fragment.add_domains(self._domains.values())
fragment.generated.update(self._generated)
return fragment
41 changes: 25 additions & 16 deletions tests/test_hdl_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,18 +593,21 @@ def test_FSM_basic(self):
m.d.sync += b.eq(~b)
with m.If(c):
m.next = "FIRST"
m._flush()
self.assertRepr(m._statements["comb"], """

frag = m.elaborate(platform=None)
self.assertRepr(frag.statements["comb"], """
(
(switch (sig fsm_state)
(case 0
(eq (sig a) (const 1'd1))
)
(case 1 )
)
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
)
""")
self.assertRepr(m._statements["sync"], """
self.assertRepr(frag.statements["sync"], """
(
(switch (sig fsm_state)
(case 0
Expand All @@ -620,13 +623,13 @@ def test_FSM_basic(self):
)
)
""")
self.assertEqual({repr(k): v for k, v in m._driving.items()}, {
self.assertEqual({repr(sig): k for k, v in frag.drivers.items() for sig in v}, {
"(sig a)": "comb",
"(sig fsm_state)": "sync",
"(sig b)": "sync",
"(sig fsm_ongoing_FIRST)": "comb",
"(sig fsm_ongoing_SECOND)": "comb",
})

frag = m.elaborate(platform=None)
fsm = frag.find_generated("fsm")
self.assertIsInstance(fsm.state, Signal)
self.assertEqual(fsm.encoding, OrderedDict({
Expand All @@ -647,18 +650,20 @@ def test_FSM_init(self):
m.next = "SECOND"
with m.State("SECOND"):
m.next = "FIRST"
m._flush()
self.assertRepr(m._statements["comb"], """
frag = m.elaborate(platform=None)
self.assertRepr(frag.statements["comb"], """
(
(switch (sig fsm_state)
(case 0
(eq (sig a) (const 1'd0))
)
(case 1 )
)
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
)
""")
self.assertRepr(m._statements["sync"], """
self.assertRepr(frag.statements["sync"], """
(
(switch (sig fsm_state)
(case 0
Expand All @@ -683,18 +688,20 @@ def test_FSM_reset(self):
m.next = "SECOND"
with m.State("SECOND"):
m.next = "FIRST"
m._flush()
self.assertRepr(m._statements["comb"], """
frag = m.elaborate(platform=None)
self.assertRepr(frag.statements["comb"], """
(
(switch (sig fsm_state)
(case 0
(eq (sig a) (const 1'd0))
)
(case 1 )
)
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd1)))
)
""")
self.assertRepr(m._statements["sync"], """
self.assertRepr(frag.statements["sync"], """
(
(switch (sig fsm_state)
(case 0
Expand Down Expand Up @@ -731,13 +738,15 @@ def test_FSM_ongoing(self):
m.d.comb += a.eq(fsm.ongoing("FIRST"))
with m.State("SECOND"):
pass
m._flush()
frag = m.elaborate(platform=None)
self.assertEqual(m._generated["fsm"].state.init, 1)
self.maxDiff = 10000
self.assertRepr(m._statements["comb"], """
self.assertRepr(frag.statements["comb"], """
(
(eq (sig b) (== (sig fsm_state) (const 1'd0)))
(eq (sig a) (== (sig fsm_state) (const 1'd1)))
(eq (sig b) (sig fsm_ongoing_SECOND))
(eq (sig a) (sig fsm_ongoing_FIRST))
(eq (sig fsm_ongoing_SECOND) (== (sig fsm_state) (const 1'd0)))
(eq (sig fsm_ongoing_FIRST) (== (sig fsm_state) (const 1'd1)))
)
""")

Expand Down