Skip to content

hdl.ir: associate statements with domains. #1094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions amaranth/back/rtlil.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ def __init__(self, state, rhs_compiler, lhs_compiler):
self.rhs_compiler = rhs_compiler
self.lhs_compiler = lhs_compiler

self._domain = None
self._case = None
self._test_cache = {}
self._has_rhs = False
Expand Down Expand Up @@ -865,8 +866,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):

# Register all signals driven in the current fragment. This must be done first, as it
# affects further codegen; e.g. whether \sig$next signals will be generated and used.
for domain, signal in fragment.iter_drivers():
compiler_state.add_driven(signal, sync=domain is not None)
for domain, statements in fragment.statements.items():
for signal in statements._lhs_signals():
compiler_state.add_driven(signal, sync=domain is not None)

# Transform all signals used as ports in the current fragment eagerly and outside of
# any hierarchy, to make sure they get sensible (non-prefixed) names.
Expand Down Expand Up @@ -925,32 +927,32 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# Therefore, we translate the fragment as many times as there are independent groups
# of signals (a group is a transitive closure of signals that appear together on LHS),
# splitting them into many RTLIL (and thus Verilog) processes.
lhs_grouper = _xfrm.LHSGroupAnalyzer()
lhs_grouper.on_statements(fragment.statements)

for group, group_signals in lhs_grouper.groups().items():
lhs_group_filter = _xfrm.LHSGroupFilter(group_signals)
group_stmts = lhs_group_filter(fragment.statements)

with module.process(name=f"$group_{group}") as process:
with process.case() as case:
# For every signal in comb domain, assign \sig$next to the reset value.
# For every signal in sync domains, assign \sig$next to the current
# value (\sig).
for domain, signal in fragment.iter_drivers():
if signal not in group_signals:
continue
if domain is None:
prev_value = _ast.Const(signal.reset, signal.width)
else:
prev_value = signal
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))

# Convert statements into decision trees.
stmt_compiler._case = case
stmt_compiler._has_rhs = False
stmt_compiler._wrap_assign = False
stmt_compiler(group_stmts)
for domain, statements in fragment.statements.items():
lhs_grouper = _xfrm.LHSGroupAnalyzer()
lhs_grouper.on_statements(statements)

for group, group_signals in lhs_grouper.groups().items():
lhs_group_filter = _xfrm.LHSGroupFilter(group_signals)
group_stmts = lhs_group_filter(statements)

with module.process(name=f"$group_{group}") as process:
with process.case() as case:
# For every signal in comb domain, assign \sig$next to the reset value.
# For every signal in sync domains, assign \sig$next to the current
# value (\sig).
for signal in group_signals:
if domain is None:
prev_value = _ast.Const(signal.reset, signal.width)
else:
prev_value = signal
case.assign(lhs_compiler(signal), rhs_compiler(prev_value))

# Convert statements into decision trees.
stmt_compiler._domain = domain
stmt_compiler._case = case
stmt_compiler._has_rhs = False
stmt_compiler._wrap_assign = False
stmt_compiler(group_stmts)

# For every driven signal in the sync domain, create a flop of appropriate type. Which type
# is appropriate depends on the domain: for domains with sync reset, it is a $dff, for
Expand Down Expand Up @@ -998,8 +1000,8 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
# to drive it to reset value arbitrarily) or to replace them with their reset value (which
# removes valuable source location information).
driven = _ast.SignalSet()
for domain, signals in fragment.iter_drivers():
driven.update(flatten(signal._lhs_signals() for signal in signals))
for domain, statements in fragment.statements.items():
driven.update(statements._lhs_signals())
driven.update(fragment.iter_ports(dir="i"))
driven.update(fragment.iter_ports(dir="io"))
for subfragment, sub_name in fragment.subfragments:
Expand Down
13 changes: 8 additions & 5 deletions amaranth/hdl/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,12 @@ class _StatementList(list):
def __repr__(self):
return "({})".format(" ".join(map(repr, self)))

def _lhs_signals(self):
return union((s._lhs_signals() for s in self), start=SignalSet())

def _rhs_signals(self):
return union((s._rhs_signals() for s in self), start=SignalSet())


class Statement:
def __init__(self, *, src_loc_at=0):
Expand Down Expand Up @@ -1837,13 +1843,10 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={})
self.case_src_locs[new_keys] = case_src_locs[orig_keys]

def _lhs_signals(self):
signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
start=SignalSet())
return signals
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())

def _rhs_signals(self):
signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss),
start=SignalSet())
signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet())
return self.test._rhs_signals() | signals

def __repr__(self):
Expand Down
81 changes: 53 additions & 28 deletions amaranth/hdl/_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(self):
self.submodules = _ModuleBuilderSubmodules(self)
self.domains = _ModuleBuilderDomainSet(self)

self._statements = Statement.cast([])
self._statements = {}
self._ctrl_context = None
self._ctrl_stack = []

Expand Down Expand Up @@ -234,7 +234,7 @@ def If(self, cond):
"src_locs": [],
})
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self.domain._depth += 1
yield
self._flush_ctrl()
Expand All @@ -254,7 +254,7 @@ def Elif(self, cond):
if if_data is None or if_data["depth"] != self.domain._depth:
raise SyntaxError("Elif without preceding If")
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self.domain._depth += 1
yield
self._flush_ctrl()
Expand All @@ -273,7 +273,7 @@ def Else(self):
if if_data is None or if_data["depth"] != self.domain._depth:
raise SyntaxError("Else without preceding If/Elif")
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self.domain._depth += 1
yield
self._flush_ctrl()
Expand Down Expand Up @@ -341,7 +341,7 @@ def Case(self, *patterns):
continue
new_patterns = (*new_patterns, pattern.value)
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
yield
self._flush_ctrl()
Expand All @@ -364,7 +364,7 @@ def Default(self):
warnings.warn("A case defined after the default case will never be active",
SyntaxWarning, stacklevel=3)
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
yield
self._flush_ctrl()
Expand Down Expand Up @@ -416,7 +416,7 @@ def State(self, name):
if name not in fsm_data["encoding"]:
fsm_data["encoding"][name] = len(fsm_data["encoding"])
try:
_outer_case, self._statements = self._statements, []
_outer_case, self._statements = self._statements, {}
self._ctrl_context = None
yield
self._flush_ctrl()
Expand Down Expand Up @@ -453,28 +453,42 @@ def _pop_ctrl(self):
if_tests, if_bodies = data["tests"], data["bodies"]
if_src_locs = data["src_locs"]

tests, cases = [], OrderedDict()
for if_test, if_case in zip(if_tests + [None], if_bodies):
if if_test is not None:
if len(if_test) != 1:
if_test = if_test.bool()
tests.append(if_test)
domains = set()
for if_case in if_bodies:
domains |= set(if_case)

if if_test is not None:
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
else:
match = None
cases[match] = if_case
for domain in domains:
tests, cases = [], OrderedDict()
for if_test, if_case in zip(if_tests + [None], if_bodies):
if if_test is not None:
if len(if_test) != 1:
if_test = if_test.bool()
tests.append(if_test)

self._statements.append(Switch(Cat(tests), cases,
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
if if_test is not None:
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
else:
match = None
cases[match] = if_case.get(domain, [])

self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases,
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))

if name == "Switch":
switch_test, switch_cases = data["test"], data["cases"]
switch_case_src_locs = data["case_src_locs"]

self._statements.append(Switch(switch_test, switch_cases,
src_loc=src_loc, case_src_locs=switch_case_src_locs))
domains = set()
for stmts in switch_cases.values():
domains |= set(stmts)

for domain in domains:
domain_cases = OrderedDict()
for patterns, stmts in switch_cases.items():
domain_cases[patterns] = stmts.get(domain, [])

self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases,
src_loc=src_loc, case_src_locs=switch_case_src_locs))

if name == "FSM":
fsm_signal, fsm_reset, fsm_encoding, fsm_decoding, fsm_states = \
Expand All @@ -490,10 +504,20 @@ def _pop_ctrl(self):
# The FSM is encoded such that the state with encoding 0 is always the reset state.
fsm_decoding.update((n, s) for s, n in fsm_encoding.items())
fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}"
self._statements.append(Switch(fsm_signal,
OrderedDict((fsm_encoding[name], stmts) for name, stmts in fsm_states.items()),
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
for name in fsm_states}))

domains = set()
for stmts in fsm_states.values():
domains |= set(stmts)

for domain in domains:
domain_states = OrderedDict()
for state, stmts in fsm_states.items():
domain_states[state] = stmts.get(domain, [])

self._statements.setdefault(domain, []).append(Switch(fsm_signal,
OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()),
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
for name in fsm_states}))

def _add_statement(self, assigns, domain, depth):
def domain_name(domain):
Expand Down Expand Up @@ -523,7 +547,7 @@ def domain_name(domain):
"already driven from d.{}"
.format(signal, domain_name(domain), domain_name(cd_curr)))

self._statements.append(stmt)
self._statements.setdefault(domain, []).append(stmt)

def _add_submodule(self, submodule, name=None):
if not hasattr(submodule, "elaborate"):
Expand Down Expand Up @@ -559,7 +583,8 @@ def elaborate(self, platform):
fragment.add_subfragment(Fragment.get(self._named_submodules[name], platform), name)
for submodule in self._anon_submodules:
fragment.add_subfragment(Fragment.get(submodule, platform), None)
fragment.add_statements(self._statements)
for domain, statements in self._statements.items():
fragment.add_statements(domain, statements)
for signal, domain in self._driving.items():
fragment.add_driver(signal, domain)
fragment.add_domains(self._domains.values())
Expand Down
27 changes: 16 additions & 11 deletions amaranth/hdl/_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .._utils import *
from .._unused import *
from ._ast import *
from ._ast import _StatementList
from ._cd import *


Expand Down Expand Up @@ -65,7 +66,7 @@ def get(obj, platform):
def __init__(self):
self.ports = SignalDict()
self.drivers = OrderedDict()
self.statements = []
self.statements = {}
self.domains = OrderedDict()
self.subfragments = []
self.attrs = OrderedDict()
Expand Down Expand Up @@ -127,10 +128,11 @@ def add_domains(self, *domains):
def iter_domains(self):
yield from self.domains

def add_statements(self, *stmts):
def add_statements(self, domain, *stmts):
assert domain is None or isinstance(domain, str)
for stmt in Statement.cast(stmts):
stmt._MustUse__used = True
self.statements.append(stmt)
self.statements.setdefault(domain, _StatementList()).append(stmt)

def add_subfragment(self, subfragment, name=None):
assert isinstance(subfragment, Fragment)
Expand Down Expand Up @@ -166,7 +168,8 @@ def _merge_subfragment(self, subfragment):
self.ports.update(subfragment.ports)
for domain, signal in subfragment.iter_drivers():
self.add_driver(signal, domain)
self.statements += subfragment.statements
for domain, statements in subfragment.statements.items():
self.statements.setdefault(domain, []).extend(statements)
self.subfragments += subfragment.subfragments

# Remove the merged subfragment.
Expand Down Expand Up @@ -387,9 +390,10 @@ def add_io(*sigs):

# Collect all signals we're driving (on LHS of statements), and signals we're using
# (on RHS of statements, or in clock domains).
for stmt in self.statements:
add_uses(stmt._rhs_signals())
add_defs(stmt._lhs_signals())
for stmts in self.statements.values():
for stmt in stmts:
add_uses(stmt._rhs_signals())
add_defs(stmt._lhs_signals())

for domain, _ in self.iter_sync():
cd = self.domains[domain]
Expand Down Expand Up @@ -572,10 +576,11 @@ def add_signal_name(signal):
if domain.rst is not None:
add_signal_name(domain.rst)

for statement in self.statements:
for signal in statement._lhs_signals() | statement._rhs_signals():
if not isinstance(signal, (ClockSignal, ResetSignal)):
add_signal_name(signal)
for statements in self.statements.values():
for statement in statements:
for signal in statement._lhs_signals() | statement._rhs_signals():
if not isinstance(signal, (ClockSignal, ResetSignal)):
add_signal_name(signal)

return signal_names

Expand Down
7 changes: 4 additions & 3 deletions amaranth/hdl/_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def elaborate(self, platform):
port._MustUse__used = True
if port.domain == "comb":
# Asynchronous port
f.add_statements(port.data.eq(self._array[port.addr]))
f.add_statements(None, port.data.eq(self._array[port.addr]))
f.add_driver(port.data)
else:
# Synchronous port
Expand All @@ -143,6 +143,7 @@ def elaborate(self, platform):
cond = write_port.en & (port.addr == write_port.addr)
data = Mux(cond, write_port.data, data)
f.add_statements(
port.domain,
Switch(port.en, {
1: port.data.eq(data)
})
Expand All @@ -155,10 +156,10 @@ def elaborate(self, platform):
offset = index * port.granularity
bits = slice(offset, offset + port.granularity)
write_data = self._array[port.addr][bits].eq(port.data[bits])
f.add_statements(Switch(en_bit, { 1: write_data }))
f.add_statements(port.domain, Switch(en_bit, { 1: write_data }))
else:
write_data = self._array[port.addr].eq(port.data)
f.add_statements(Switch(port.en, { 1: write_data }))
f.add_statements(port.domain, Switch(port.en, { 1: write_data }))
for signal in self._array:
f.add_driver(signal, port.domain)
return f
Expand Down
Loading