Skip to content

Commit 6e06fc0

Browse files
wanda-phiwhitequark
authored andcommitted
hdl.ir: associate statements with domains.
Fixes #1079.
1 parent 09854fa commit 6e06fc0

File tree

12 files changed

+315
-198
lines changed

12 files changed

+315
-198
lines changed

amaranth/back/rtlil.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,7 @@ def __init__(self, state, rhs_compiler, lhs_compiler):
667667
self.rhs_compiler = rhs_compiler
668668
self.lhs_compiler = lhs_compiler
669669

670+
self._domain = None
670671
self._case = None
671672
self._test_cache = {}
672673
self._has_rhs = False
@@ -865,8 +866,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
865866

866867
# Register all signals driven in the current fragment. This must be done first, as it
867868
# affects further codegen; e.g. whether \sig$next signals will be generated and used.
868-
for domain, signal in fragment.iter_drivers():
869-
compiler_state.add_driven(signal, sync=domain is not None)
869+
for domain, statements in fragment.statements.items():
870+
for signal in statements._lhs_signals():
871+
compiler_state.add_driven(signal, sync=domain is not None)
870872

871873
# Transform all signals used as ports in the current fragment eagerly and outside of
872874
# any hierarchy, to make sure they get sensible (non-prefixed) names.
@@ -925,32 +927,32 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
925927
# Therefore, we translate the fragment as many times as there are independent groups
926928
# of signals (a group is a transitive closure of signals that appear together on LHS),
927929
# splitting them into many RTLIL (and thus Verilog) processes.
928-
lhs_grouper = _xfrm.LHSGroupAnalyzer()
929-
lhs_grouper.on_statements(fragment.statements)
930-
931-
for group, group_signals in lhs_grouper.groups().items():
932-
lhs_group_filter = _xfrm.LHSGroupFilter(group_signals)
933-
group_stmts = lhs_group_filter(fragment.statements)
934-
935-
with module.process(name=f"$group_{group}") as process:
936-
with process.case() as case:
937-
# For every signal in comb domain, assign \sig$next to the reset value.
938-
# For every signal in sync domains, assign \sig$next to the current
939-
# value (\sig).
940-
for domain, signal in fragment.iter_drivers():
941-
if signal not in group_signals:
942-
continue
943-
if domain is None:
944-
prev_value = _ast.Const(signal.reset, signal.width)
945-
else:
946-
prev_value = signal
947-
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
948-
949-
# Convert statements into decision trees.
950-
stmt_compiler._case = case
951-
stmt_compiler._has_rhs = False
952-
stmt_compiler._wrap_assign = False
953-
stmt_compiler(group_stmts)
930+
for domain, statements in fragment.statements.items():
931+
lhs_grouper = _xfrm.LHSGroupAnalyzer()
932+
lhs_grouper.on_statements(statements)
933+
934+
for group, group_signals in lhs_grouper.groups().items():
935+
lhs_group_filter = _xfrm.LHSGroupFilter(group_signals)
936+
group_stmts = lhs_group_filter(statements)
937+
938+
with module.process(name=f"$group_{group}") as process:
939+
with process.case() as case:
940+
# For every signal in comb domain, assign \sig$next to the reset value.
941+
# For every signal in sync domains, assign \sig$next to the current
942+
# value (\sig).
943+
for signal in group_signals:
944+
if domain is None:
945+
prev_value = _ast.Const(signal.reset, signal.width)
946+
else:
947+
prev_value = signal
948+
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
949+
950+
# Convert statements into decision trees.
951+
stmt_compiler._domain = domain
952+
stmt_compiler._case = case
953+
stmt_compiler._has_rhs = False
954+
stmt_compiler._wrap_assign = False
955+
stmt_compiler(group_stmts)
954956

955957
# For every driven signal in the sync domain, create a flop of appropriate type. Which type
956958
# is appropriate depends on the domain: for domains with sync reset, it is a $dff, for
@@ -998,8 +1000,8 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
9981000
# to drive it to reset value arbitrarily) or to replace them with their reset value (which
9991001
# removes valuable source location information).
10001002
driven = _ast.SignalSet()
1001-
for domain, signals in fragment.iter_drivers():
1002-
driven.update(flatten(signal._lhs_signals() for signal in signals))
1003+
for domain, statements in fragment.statements.items():
1004+
driven.update(statements._lhs_signals())
10031005
driven.update(fragment.iter_ports(dir="i"))
10041006
driven.update(fragment.iter_ports(dir="io"))
10051007
for subfragment, sub_name in fragment.subfragments:

amaranth/hdl/_ast.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,12 @@ class _StatementList(list):
17181718
def __repr__(self):
17191719
return "({})".format(" ".join(map(repr, self)))
17201720

1721+
def _lhs_signals(self):
1722+
return union((s._lhs_signals() for s in self), start=SignalSet())
1723+
1724+
def _rhs_signals(self):
1725+
return union((s._rhs_signals() for s in self), start=SignalSet())
1726+
17211727

17221728
class Statement:
17231729
def __init__(self, *, src_loc_at=0):
@@ -1849,13 +1855,10 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={})
18491855
self.case_src_locs[new_keys] = case_src_locs[orig_keys]
18501856

18511857
def _lhs_signals(self):
1852-
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
1853-
start=SignalSet())
1854-
return signals
1858+
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())
18551859

18561860
def _rhs_signals(self):
1857-
signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss),
1858-
start=SignalSet())
1861+
signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet())
18591862
return self.test._rhs_signals() | signals
18601863

18611864
def __repr__(self):

amaranth/hdl/_dsl.py

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def __init__(self):
170170
self.submodules = _ModuleBuilderSubmodules(self)
171171
self.domains = _ModuleBuilderDomainSet(self)
172172

173-
self._statements = Statement.cast([])
173+
self._statements = {}
174174
self._ctrl_context = None
175175
self._ctrl_stack = []
176176

@@ -234,7 +234,7 @@ def If(self, cond):
234234
"src_locs": [],
235235
})
236236
try:
237-
_outer_case, self._statements = self._statements, []
237+
_outer_case, self._statements = self._statements, {}
238238
self.domain._depth += 1
239239
yield
240240
self._flush_ctrl()
@@ -254,7 +254,7 @@ def Elif(self, cond):
254254
if if_data is None or if_data["depth"] != self.domain._depth:
255255
raise SyntaxError("Elif without preceding If")
256256
try:
257-
_outer_case, self._statements = self._statements, []
257+
_outer_case, self._statements = self._statements, {}
258258
self.domain._depth += 1
259259
yield
260260
self._flush_ctrl()
@@ -273,7 +273,7 @@ def Else(self):
273273
if if_data is None or if_data["depth"] != self.domain._depth:
274274
raise SyntaxError("Else without preceding If/Elif")
275275
try:
276-
_outer_case, self._statements = self._statements, []
276+
_outer_case, self._statements = self._statements, {}
277277
self.domain._depth += 1
278278
yield
279279
self._flush_ctrl()
@@ -341,7 +341,7 @@ def Case(self, *patterns):
341341
continue
342342
new_patterns = (*new_patterns, pattern.value)
343343
try:
344-
_outer_case, self._statements = self._statements, []
344+
_outer_case, self._statements = self._statements, {}
345345
self._ctrl_context = None
346346
yield
347347
self._flush_ctrl()
@@ -364,7 +364,7 @@ def Default(self):
364364
warnings.warn("A case defined after the default case will never be active",
365365
SyntaxWarning, stacklevel=3)
366366
try:
367-
_outer_case, self._statements = self._statements, []
367+
_outer_case, self._statements = self._statements, {}
368368
self._ctrl_context = None
369369
yield
370370
self._flush_ctrl()
@@ -416,7 +416,7 @@ def State(self, name):
416416
if name not in fsm_data["encoding"]:
417417
fsm_data["encoding"][name] = len(fsm_data["encoding"])
418418
try:
419-
_outer_case, self._statements = self._statements, []
419+
_outer_case, self._statements = self._statements, {}
420420
self._ctrl_context = None
421421
yield
422422
self._flush_ctrl()
@@ -453,28 +453,42 @@ def _pop_ctrl(self):
453453
if_tests, if_bodies = data["tests"], data["bodies"]
454454
if_src_locs = data["src_locs"]
455455

456-
tests, cases = [], OrderedDict()
457-
for if_test, if_case in zip(if_tests + [None], if_bodies):
458-
if if_test is not None:
459-
if len(if_test) != 1:
460-
if_test = if_test.bool()
461-
tests.append(if_test)
456+
domains = set()
457+
for if_case in if_bodies:
458+
domains |= set(if_case)
462459

463-
if if_test is not None:
464-
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
465-
else:
466-
match = None
467-
cases[match] = if_case
460+
for domain in domains:
461+
tests, cases = [], OrderedDict()
462+
for if_test, if_case in zip(if_tests + [None], if_bodies):
463+
if if_test is not None:
464+
if len(if_test) != 1:
465+
if_test = if_test.bool()
466+
tests.append(if_test)
468467

469-
self._statements.append(Switch(Cat(tests), cases,
470-
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
468+
if if_test is not None:
469+
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
470+
else:
471+
match = None
472+
cases[match] = if_case.get(domain, [])
473+
474+
self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases,
475+
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
471476

472477
if name == "Switch":
473478
switch_test, switch_cases = data["test"], data["cases"]
474479
switch_case_src_locs = data["case_src_locs"]
475480

476-
self._statements.append(Switch(switch_test, switch_cases,
477-
src_loc=src_loc, case_src_locs=switch_case_src_locs))
481+
domains = set()
482+
for stmts in switch_cases.values():
483+
domains |= set(stmts)
484+
485+
for domain in domains:
486+
domain_cases = OrderedDict()
487+
for patterns, stmts in switch_cases.items():
488+
domain_cases[patterns] = stmts.get(domain, [])
489+
490+
self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases,
491+
src_loc=src_loc, case_src_locs=switch_case_src_locs))
478492

479493
if name == "FSM":
480494
fsm_signal, fsm_reset, fsm_encoding, fsm_decoding, fsm_states = \
@@ -490,10 +504,20 @@ def _pop_ctrl(self):
490504
# The FSM is encoded such that the state with encoding 0 is always the reset state.
491505
fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
492506
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
493-
self._statements.append(Switch(fsm_signal,
494-
OrderedDict((fsm_encoding[name], stmts) for name, stmts in fsm_states.items()),
495-
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
496-
for name in fsm_states}))
507+
508+
domains = set()
509+
for stmts in fsm_states.values():
510+
domains |= set(stmts)
511+
512+
for domain in domains:
513+
domain_states = OrderedDict()
514+
for state, stmts in fsm_states.items():
515+
domain_states[state] = stmts.get(domain, [])
516+
517+
self._statements.setdefault(domain, []).append(Switch(fsm_signal,
518+
OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()),
519+
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
520+
for name in fsm_states}))
497521

498522
def _add_statement(self, assigns, domain, depth):
499523
def domain_name(domain):
@@ -523,7 +547,7 @@ def domain_name(domain):
523547
"already driven from d.{}"
524548
.format(signal, domain_name(domain), domain_name(cd_curr)))
525549

526-
self._statements.append(stmt)
550+
self._statements.setdefault(domain, []).append(stmt)
527551

528552
def _add_submodule(self, submodule, name=None):
529553
if not hasattr(submodule, "elaborate"):
@@ -559,7 +583,8 @@ def elaborate(self, platform):
559583
fragment.add_subfragment(Fragment.get(self._named_submodules[name], platform), name)
560584
for submodule in self._anon_submodules:
561585
fragment.add_subfragment(Fragment.get(submodule, platform), None)
562-
fragment.add_statements(self._statements)
586+
for domain, statements in self._statements.items():
587+
fragment.add_statements(domain, statements)
563588
for signal, domain in self._driving.items():
564589
fragment.add_driver(signal, domain)
565590
fragment.add_domains(self._domains.values())

amaranth/hdl/_ir.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .._utils import *
88
from .._unused import *
99
from ._ast import *
10+
from ._ast import _StatementList
1011
from ._cd import *
1112

1213

@@ -65,7 +66,7 @@ def get(obj, platform):
6566
def __init__(self):
6667
self.ports = SignalDict()
6768
self.drivers = OrderedDict()
68-
self.statements = []
69+
self.statements = {}
6970
self.domains = OrderedDict()
7071
self.subfragments = []
7172
self.attrs = OrderedDict()
@@ -127,10 +128,11 @@ def add_domains(self, *domains):
127128
def iter_domains(self):
128129
yield from self.domains
129130

130-
def add_statements(self, *stmts):
131+
def add_statements(self, domain, *stmts):
132+
assert domain is None or isinstance(domain, str)
131133
for stmt in Statement.cast(stmts):
132134
stmt._MustUse__used = True
133-
self.statements.append(stmt)
135+
self.statements.setdefault(domain, _StatementList()).append(stmt)
134136

135137
def add_subfragment(self, subfragment, name=None):
136138
assert isinstance(subfragment, Fragment)
@@ -166,7 +168,8 @@ def _merge_subfragment(self, subfragment):
166168
self.ports.update(subfragment.ports)
167169
for domain, signal in subfragment.iter_drivers():
168170
self.add_driver(signal, domain)
169-
self.statements += subfragment.statements
171+
for domain, statements in subfragment.statements.items():
172+
self.statements.setdefault(domain, []).extend(statements)
170173
self.subfragments += subfragment.subfragments
171174

172175
# Remove the merged subfragment.
@@ -387,9 +390,10 @@ def add_io(*sigs):
387390

388391
# Collect all signals we're driving (on LHS of statements), and signals we're using
389392
# (on RHS of statements, or in clock domains).
390-
for stmt in self.statements:
391-
add_uses(stmt._rhs_signals())
392-
add_defs(stmt._lhs_signals())
393+
for stmts in self.statements.values():
394+
for stmt in stmts:
395+
add_uses(stmt._rhs_signals())
396+
add_defs(stmt._lhs_signals())
393397

394398
for domain, _ in self.iter_sync():
395399
cd = self.domains[domain]
@@ -572,10 +576,11 @@ def add_signal_name(signal):
572576
if domain.rst is not None:
573577
add_signal_name(domain.rst)
574578

575-
for statement in self.statements:
576-
for signal in statement._lhs_signals() | statement._rhs_signals():
577-
if not isinstance(signal, (ClockSignal, ResetSignal)):
578-
add_signal_name(signal)
579+
for statements in self.statements.values():
580+
for statement in statements:
581+
for signal in statement._lhs_signals() | statement._rhs_signals():
582+
if not isinstance(signal, (ClockSignal, ResetSignal)):
583+
add_signal_name(signal)
579584

580585
return signal_names
581586

amaranth/hdl/_mem.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def elaborate(self, platform):
124124
port._MustUse__used = True
125125
if port.domain == "comb":
126126
# Asynchronous port
127-
f.add_statements(port.data.eq(self._array[port.addr]))
127+
f.add_statements(None, port.data.eq(self._array[port.addr]))
128128
f.add_driver(port.data)
129129
else:
130130
# Synchronous port
@@ -143,6 +143,7 @@ def elaborate(self, platform):
143143
cond = write_port.en & (port.addr == write_port.addr)
144144
data = Mux(cond, write_port.data, data)
145145
f.add_statements(
146+
port.domain,
146147
Switch(port.en, {
147148
1: port.data.eq(data)
148149
})
@@ -155,10 +156,10 @@ def elaborate(self, platform):
155156
offset = index * port.granularity
156157
bits = slice(offset, offset + port.granularity)
157158
write_data = self._array[port.addr][bits].eq(port.data[bits])
158-
f.add_statements(Switch(en_bit, { 1: write_data }))
159+
f.add_statements(port.domain, Switch(en_bit, { 1: write_data }))
159160
else:
160161
write_data = self._array[port.addr].eq(port.data)
161-
f.add_statements(Switch(port.en, { 1: write_data }))
162+
f.add_statements(port.domain, Switch(port.en, { 1: write_data }))
162163
for signal in self._array:
163164
f.add_driver(signal, port.domain)
164165
return f

0 commit comments

Comments
 (0)