From fec1fbf2480be552f9fb1c5ebc7548f051413bef Mon Sep 17 00:00:00 2001 From: Wanda Date: Fri, 9 Feb 2024 01:53:45 +0100 Subject: [PATCH] hdl.ir: associate statements with domains. Fixes #1079. --- amaranth/back/rtlil.py | 62 ++++++++++--------- amaranth/hdl/_ast.py | 13 ++-- amaranth/hdl/_dsl.py | 81 +++++++++++++++--------- amaranth/hdl/_ir.py | 27 ++++---- amaranth/hdl/_mem.py | 7 ++- amaranth/hdl/_xfrm.py | 24 +++++--- amaranth/sim/_pyrtl.py | 6 +- tests/test_hdl_dsl.py | 130 ++++++++++++++++++++++++++------------- tests/test_hdl_ir.py | 51 ++++++++++----- tests/test_hdl_xfrm.py | 95 +++++++++++++++------------- tests/test_lib_wiring.py | 12 ++-- tests/test_sim.py | 5 +- 12 files changed, 315 insertions(+), 198 deletions(-) diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index 23f6b6c41..657b1d2fa 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -667,6 +667,7 @@ def __init__(self, state, rhs_compiler, lhs_compiler): self.rhs_compiler = rhs_compiler self.lhs_compiler = lhs_compiler + self._domain = None self._case = None self._test_cache = {} self._has_rhs = False @@ -865,8 +866,9 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): # Register all signals driven in the current fragment. This must be done first, as it # affects further codegen; e.g. whether \sig$next signals will be generated and used. - for domain, signal in fragment.iter_drivers(): - compiler_state.add_driven(signal, sync=domain is not None) + for domain, statements in fragment.statements.items(): + for signal in statements._lhs_signals(): + compiler_state.add_driven(signal, sync=domain is not None) # Transform all signals used as ports in the current fragment eagerly and outside of # any hierarchy, to make sure they get sensible (non-prefixed) names. @@ -925,32 +927,32 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): # Therefore, we translate the fragment as many times as there are independent groups # of signals (a group is a transitive closure of signals that appear together on LHS), # splitting them into many RTLIL (and thus Verilog) processes. - lhs_grouper = _xfrm.LHSGroupAnalyzer() - lhs_grouper.on_statements(fragment.statements) - - for group, group_signals in lhs_grouper.groups().items(): - lhs_group_filter = _xfrm.LHSGroupFilter(group_signals) - group_stmts = lhs_group_filter(fragment.statements) - - with module.process(name=f"$group_{group}") as process: - with process.case() as case: - # For every signal in comb domain, assign \sig$next to the reset value. - # For every signal in sync domains, assign \sig$next to the current - # value (\sig). - for domain, signal in fragment.iter_drivers(): - if signal not in group_signals: - continue - if domain is None: - prev_value = _ast.Const(signal.reset, signal.width) - else: - prev_value = signal - case.assign(lhs_compiler(signal), rhs_compiler(prev_value)) - - # Convert statements into decision trees. - stmt_compiler._case = case - stmt_compiler._has_rhs = False - stmt_compiler._wrap_assign = False - stmt_compiler(group_stmts) + for domain, statements in fragment.statements.items(): + lhs_grouper = _xfrm.LHSGroupAnalyzer() + lhs_grouper.on_statements(statements) + + for group, group_signals in lhs_grouper.groups().items(): + lhs_group_filter = _xfrm.LHSGroupFilter(group_signals) + group_stmts = lhs_group_filter(statements) + + with module.process(name=f"$group_{group}") as process: + with process.case() as case: + # For every signal in comb domain, assign \sig$next to the reset value. + # For every signal in sync domains, assign \sig$next to the current + # value (\sig). + for signal in group_signals: + if domain is None: + prev_value = _ast.Const(signal.reset, signal.width) + else: + prev_value = signal + case.assign(lhs_compiler(signal), rhs_compiler(prev_value)) + + # Convert statements into decision trees. + stmt_compiler._domain = domain + stmt_compiler._case = case + stmt_compiler._has_rhs = False + stmt_compiler._wrap_assign = False + stmt_compiler(group_stmts) # For every driven signal in the sync domain, create a flop of appropriate type. Which type # is appropriate depends on the domain: for domains with sync reset, it is a $dff, for @@ -998,8 +1000,8 @@ def _convert_fragment(builder, fragment, name_map, hierarchy): # to drive it to reset value arbitrarily) or to replace them with their reset value (which # removes valuable source location information). driven = _ast.SignalSet() - for domain, signals in fragment.iter_drivers(): - driven.update(flatten(signal._lhs_signals() for signal in signals)) + for domain, statements in fragment.statements.items(): + driven.update(statements._lhs_signals()) driven.update(fragment.iter_ports(dir="i")) driven.update(fragment.iter_ports(dir="io")) for subfragment, sub_name in fragment.subfragments: diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index c2fe98fa2..ae80e860b 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -1706,6 +1706,12 @@ class _StatementList(list): def __repr__(self): return "({})".format(" ".join(map(repr, self))) + def _lhs_signals(self): + return union((s._lhs_signals() for s in self), start=SignalSet()) + + def _rhs_signals(self): + return union((s._rhs_signals() for s in self), start=SignalSet()) + class Statement: def __init__(self, *, src_loc_at=0): @@ -1837,13 +1843,10 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}) self.case_src_locs[new_keys] = case_src_locs[orig_keys] def _lhs_signals(self): - signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss), - start=SignalSet()) - return signals + return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet()) def _rhs_signals(self): - signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss), - start=SignalSet()) + signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet()) return self.test._rhs_signals() | signals def __repr__(self): diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index 481f468ee..3fdb5809a 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -170,7 +170,7 @@ def __init__(self): self.submodules = _ModuleBuilderSubmodules(self) self.domains = _ModuleBuilderDomainSet(self) - self._statements = Statement.cast([]) + self._statements = {} self._ctrl_context = None self._ctrl_stack = [] @@ -234,7 +234,7 @@ def If(self, cond): "src_locs": [], }) try: - _outer_case, self._statements = self._statements, [] + _outer_case, self._statements = self._statements, {} self.domain._depth += 1 yield self._flush_ctrl() @@ -254,7 +254,7 @@ def Elif(self, cond): if if_data is None or if_data["depth"] != self.domain._depth: raise SyntaxError("Elif without preceding If") try: - _outer_case, self._statements = self._statements, [] + _outer_case, self._statements = self._statements, {} self.domain._depth += 1 yield self._flush_ctrl() @@ -273,7 +273,7 @@ def Else(self): if if_data is None or if_data["depth"] != self.domain._depth: raise SyntaxError("Else without preceding If/Elif") try: - _outer_case, self._statements = self._statements, [] + _outer_case, self._statements = self._statements, {} self.domain._depth += 1 yield self._flush_ctrl() @@ -341,7 +341,7 @@ def Case(self, *patterns): continue new_patterns = (*new_patterns, pattern.value) try: - _outer_case, self._statements = self._statements, [] + _outer_case, self._statements = self._statements, {} self._ctrl_context = None yield self._flush_ctrl() @@ -364,7 +364,7 @@ def Default(self): warnings.warn("A case defined after the default case will never be active", SyntaxWarning, stacklevel=3) try: - _outer_case, self._statements = self._statements, [] + _outer_case, self._statements = self._statements, {} self._ctrl_context = None yield self._flush_ctrl() @@ -416,7 +416,7 @@ def State(self, name): if name not in fsm_data["encoding"]: fsm_data["encoding"][name] = len(fsm_data["encoding"]) try: - _outer_case, self._statements = self._statements, [] + _outer_case, self._statements = self._statements, {} self._ctrl_context = None yield self._flush_ctrl() @@ -453,28 +453,42 @@ def _pop_ctrl(self): if_tests, if_bodies = data["tests"], data["bodies"] if_src_locs = data["src_locs"] - tests, cases = [], OrderedDict() - for if_test, if_case in zip(if_tests + [None], if_bodies): - if if_test is not None: - if len(if_test) != 1: - if_test = if_test.bool() - tests.append(if_test) + domains = set() + for if_case in if_bodies: + domains |= set(if_case) - if if_test is not None: - match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-") - else: - match = None - cases[match] = if_case + for domain in domains: + tests, cases = [], OrderedDict() + for if_test, if_case in zip(if_tests + [None], if_bodies): + if if_test is not None: + if len(if_test) != 1: + if_test = if_test.bool() + tests.append(if_test) - self._statements.append(Switch(Cat(tests), cases, - src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs)))) + if if_test is not None: + match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-") + else: + match = None + cases[match] = if_case.get(domain, []) + + self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases, + src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs)))) if name == "Switch": switch_test, switch_cases = data["test"], data["cases"] switch_case_src_locs = data["case_src_locs"] - self._statements.append(Switch(switch_test, switch_cases, - src_loc=src_loc, case_src_locs=switch_case_src_locs)) + domains = set() + for stmts in switch_cases.values(): + domains |= set(stmts) + + for domain in domains: + domain_cases = OrderedDict() + for patterns, stmts in switch_cases.items(): + domain_cases[patterns] = stmts.get(domain, []) + + self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases, + src_loc=src_loc, case_src_locs=switch_case_src_locs)) if name == "FSM": fsm_signal, fsm_reset, fsm_encoding, fsm_decoding, fsm_states = \ @@ -490,10 +504,20 @@ def _pop_ctrl(self): # The FSM is encoded such that the state with encoding 0 is always the reset state. fsm_decoding.update((n, s) for s, n in fsm_encoding.items()) fsm_signal.decoder = lambda n: f"{fsm_decoding[n]}/{n}" - self._statements.append(Switch(fsm_signal, - OrderedDict((fsm_encoding[name], stmts) for name, stmts in fsm_states.items()), - src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name] - for name in fsm_states})) + + domains = set() + for stmts in fsm_states.values(): + domains |= set(stmts) + + for domain in domains: + domain_states = OrderedDict() + for state, stmts in fsm_states.items(): + domain_states[state] = stmts.get(domain, []) + + self._statements.setdefault(domain, []).append(Switch(fsm_signal, + OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()), + src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name] + for name in fsm_states})) def _add_statement(self, assigns, domain, depth): def domain_name(domain): @@ -523,7 +547,7 @@ def domain_name(domain): "already driven from d.{}" .format(signal, domain_name(domain), domain_name(cd_curr))) - self._statements.append(stmt) + self._statements.setdefault(domain, []).append(stmt) def _add_submodule(self, submodule, name=None): if not hasattr(submodule, "elaborate"): @@ -559,7 +583,8 @@ def elaborate(self, platform): fragment.add_subfragment(Fragment.get(self._named_submodules[name], platform), name) for submodule in self._anon_submodules: fragment.add_subfragment(Fragment.get(submodule, platform), None) - fragment.add_statements(self._statements) + for domain, statements in self._statements.items(): + fragment.add_statements(domain, statements) for signal, domain in self._driving.items(): fragment.add_driver(signal, domain) fragment.add_domains(self._domains.values()) diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index cbdbfed42..bb671248a 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -7,6 +7,7 @@ from .._utils import * from .._unused import * from ._ast import * +from ._ast import _StatementList from ._cd import * @@ -65,7 +66,7 @@ def get(obj, platform): def __init__(self): self.ports = SignalDict() self.drivers = OrderedDict() - self.statements = [] + self.statements = {} self.domains = OrderedDict() self.subfragments = [] self.attrs = OrderedDict() @@ -127,10 +128,11 @@ def add_domains(self, *domains): def iter_domains(self): yield from self.domains - def add_statements(self, *stmts): + def add_statements(self, domain, *stmts): + assert domain is None or isinstance(domain, str) for stmt in Statement.cast(stmts): stmt._MustUse__used = True - self.statements.append(stmt) + self.statements.setdefault(domain, _StatementList()).append(stmt) def add_subfragment(self, subfragment, name=None): assert isinstance(subfragment, Fragment) @@ -166,7 +168,8 @@ def _merge_subfragment(self, subfragment): self.ports.update(subfragment.ports) for domain, signal in subfragment.iter_drivers(): self.add_driver(signal, domain) - self.statements += subfragment.statements + for domain, statements in subfragment.statements.items(): + self.statements.setdefault(domain, []).extend(statements) self.subfragments += subfragment.subfragments # Remove the merged subfragment. @@ -387,9 +390,10 @@ def add_io(*sigs): # Collect all signals we're driving (on LHS of statements), and signals we're using # (on RHS of statements, or in clock domains). - for stmt in self.statements: - add_uses(stmt._rhs_signals()) - add_defs(stmt._lhs_signals()) + for stmts in self.statements.values(): + for stmt in stmts: + add_uses(stmt._rhs_signals()) + add_defs(stmt._lhs_signals()) for domain, _ in self.iter_sync(): cd = self.domains[domain] @@ -572,10 +576,11 @@ def add_signal_name(signal): if domain.rst is not None: add_signal_name(domain.rst) - for statement in self.statements: - for signal in statement._lhs_signals() | statement._rhs_signals(): - if not isinstance(signal, (ClockSignal, ResetSignal)): - add_signal_name(signal) + for statements in self.statements.values(): + for statement in statements: + for signal in statement._lhs_signals() | statement._rhs_signals(): + if not isinstance(signal, (ClockSignal, ResetSignal)): + add_signal_name(signal) return signal_names diff --git a/amaranth/hdl/_mem.py b/amaranth/hdl/_mem.py index a52cec5b1..36c66720c 100644 --- a/amaranth/hdl/_mem.py +++ b/amaranth/hdl/_mem.py @@ -124,7 +124,7 @@ def elaborate(self, platform): port._MustUse__used = True if port.domain == "comb": # Asynchronous port - f.add_statements(port.data.eq(self._array[port.addr])) + f.add_statements(None, port.data.eq(self._array[port.addr])) f.add_driver(port.data) else: # Synchronous port @@ -143,6 +143,7 @@ def elaborate(self, platform): cond = write_port.en & (port.addr == write_port.addr) data = Mux(cond, write_port.data, data) f.add_statements( + port.domain, Switch(port.en, { 1: port.data.eq(data) }) @@ -155,10 +156,10 @@ def elaborate(self, platform): offset = index * port.granularity bits = slice(offset, offset + port.granularity) write_data = self._array[port.addr][bits].eq(port.data[bits]) - f.add_statements(Switch(en_bit, { 1: write_data })) + f.add_statements(port.domain, Switch(en_bit, { 1: write_data })) else: write_data = self._array[port.addr].eq(port.data) - f.add_statements(Switch(port.en, { 1: write_data })) + f.add_statements(port.domain, Switch(port.en, { 1: write_data })) for signal in self._array: f.add_driver(signal, port.domain) return f diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index f23bcba04..ea17af87b 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -228,9 +228,11 @@ def map_domains(self, fragment, new_fragment): def map_statements(self, fragment, new_fragment): if hasattr(self, "on_statement"): - new_fragment.add_statements(map(self.on_statement, fragment.statements)) + for domain, statements in fragment.statements.items(): + new_fragment.add_statements(domain, map(self.on_statement, statements)) else: - new_fragment.add_statements(fragment.statements) + for domain, statements in fragment.statements.items(): + new_fragment.add_statements(domain, statements) def map_drivers(self, fragment, new_fragment): for domain, signal in fragment.iter_drivers(): @@ -397,9 +399,9 @@ def on_fragment(self, fragment): else: self.defined_domains.add(domain_name) - self.on_statements(fragment.statements) - for domain_name in fragment.drivers: + for domain_name, statements in fragment.statements.items(): self._add_used_domain(domain_name) + self.on_statements(statements) for subfragment, name in fragment.subfragments: self.on_fragment(subfragment) @@ -442,6 +444,13 @@ def map_domains(self, fragment, new_fragment): assert cd.name == self.domain_map[domain] new_fragment.add_domains(cd) + def map_statements(self, fragment, new_fragment): + for domain, statements in fragment.statements.items(): + new_fragment.add_statements( + self.domain_map.get(domain, domain), + map(self.on_statement, statements) + ) + def map_drivers(self, fragment, new_fragment): for domain, signals in fragment.drivers.items(): if domain in self.domain_map: @@ -499,7 +508,7 @@ def _insert_resets(self, fragment): continue stmts = [signal.eq(Const(signal.reset, signal.width)) for signal in signals if not signal.reset_less] - fragment.add_statements(Switch(domain.rst, {1: stmts})) + fragment.add_statements(domain_name, Switch(domain.rst, {1: stmts})) def on_fragment(self, fragment): self.domains = fragment.domains @@ -571,6 +580,7 @@ def on_Switch(self, stmt): self.on_statements(case_stmts) def on_statements(self, stmts): + assert not isinstance(stmts, str) for stmt in stmts: self.on_statement(stmt) @@ -624,13 +634,13 @@ def __call__(self, value, *, src_loc_at=0): class ResetInserter(_ControlInserter): def _insert_control(self, fragment, domain, signals): stmts = [s.eq(Const(s.reset, s.width)) for s in signals if not s.reset_less] - fragment.add_statements(Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc)) + fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc)) class EnableInserter(_ControlInserter): def _insert_control(self, fragment, domain, signals): stmts = [s.eq(s) for s in signals] - fragment.add_statements(Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc)) + fragment.add_statements(domain, Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc)) def on_fragment(self, fragment): new_fragment = super().on_fragment(fragment) diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 121804d39..30f5362c7 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -5,7 +5,7 @@ from ..hdl import * from ..hdl._ast import SignalSet -from ..hdl._xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter +from ..hdl._xfrm import ValueVisitor, StatementVisitor from ._base import BaseProcess @@ -409,9 +409,9 @@ def __init__(self, state): def __call__(self, fragment): processes = set() - for domain_name, domain_signals in fragment.drivers.items(): - domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements) + for domain_name, domain_stmts in fragment.statements.items(): domain_process = PyRTLProcess(is_comb=domain_name is None) + domain_signals = domain_stmts._lhs_signals() emitter = _PythonEmitter() emitter.append(f"def run():") diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index 9914da4cb..e094e20fa 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -34,7 +34,7 @@ def test_d_comb(self): m.d.comb += self.c1.eq(1) m._flush() self.assertEqual(m._driving[self.c1], None) - self.assertRepr(m._statements, """( + self.assertRepr(m._statements[None], """( (eq (sig c1) (const 1'd1)) )""") @@ -43,7 +43,7 @@ def test_d_sync(self): m.d.sync += self.c1.eq(1) m._flush() self.assertEqual(m._driving[self.c1], "sync") - self.assertRepr(m._statements, """( + self.assertRepr(m._statements["sync"], """( (eq (sig c1) (const 1'd1)) )""") @@ -52,7 +52,7 @@ def test_d_pix(self): m.d.pix += self.c1.eq(1) m._flush() self.assertEqual(m._driving[self.c1], "pix") - self.assertRepr(m._statements, """( + self.assertRepr(m._statements["pix"], """( (eq (sig c1) (const 1'd1)) )""") @@ -61,7 +61,7 @@ def test_d_index(self): m.d["pix"] += self.c1.eq(1) m._flush() self.assertEqual(m._driving[self.c1], "pix") - self.assertRepr(m._statements, """( + self.assertRepr(m._statements["pix"], """( (eq (sig c1) (const 1'd1)) )""") @@ -118,7 +118,7 @@ def test_d_suspicious(self): def test_clock_signal(self): m = Module() m.d.comb += ClockSignal("pix").eq(ClockSignal()) - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (eq (clk pix) (clk sync)) ) @@ -127,7 +127,7 @@ def test_clock_signal(self): def test_reset_signal(self): m = Module() m.d.comb += ResetSignal("pix").eq(1) - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (eq (rst pix) (const 1'd1)) ) @@ -138,7 +138,7 @@ def test_If(self): with m.If(self.s1): m.d.comb += self.c1.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (sig s1)) (case 1 (eq (sig c1) (const 1'd1))) @@ -147,16 +147,40 @@ def test_If(self): """) def test_If_Elif(self): + m = Module() + with m.If(self.s1): + m.d.comb += self.c1.eq(1) + with m.Elif(self.s2): + m.d.comb += self.c2.eq(0) + m._flush() + self.assertRepr(m._statements[None], """ + ( + (switch (cat (sig s1) (sig s2)) + (case -1 (eq (sig c1) (const 1'd1))) + (case 1- (eq (sig c2) (const 1'd0))) + ) + ) + """) + + def test_If_Elif_multi(self): m = Module() with m.If(self.s1): m.d.comb += self.c1.eq(1) with m.Elif(self.s2): m.d.sync += self.c2.eq(0) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (sig s1) (sig s2)) (case -1 (eq (sig c1) (const 1'd1))) + (case 1- ) + ) + ) + """) + self.assertRepr(m._statements["sync"], """ + ( + (switch (cat (sig s1) (sig s2)) + (case -1 ) (case 1- (eq (sig c2) (const 1'd0))) ) ) @@ -171,15 +195,24 @@ def test_If_Elif_Else(self): with m.Else(): m.d.comb += self.c3.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (sig s1) (sig s2)) (case -1 (eq (sig c1) (const 1'd1))) - (case 1- (eq (sig c2) (const 1'd0))) + (case 1- ) (default (eq (sig c3) (const 1'd1))) ) ) """) + self.assertRepr(m._statements["sync"], """ + ( + (switch (cat (sig s1) (sig s2)) + (case -1 ) + (case 1- (eq (sig c2) (const 1'd0))) + (default ) + ) + ) + """) def test_If_If(self): m = Module() @@ -188,7 +221,7 @@ def test_If_If(self): with m.If(self.s2): m.d.comb += self.c2.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (sig s1)) (case 1 (eq (sig c1) (const 1'd1))) @@ -206,7 +239,7 @@ def test_If_nested_If(self): with m.If(self.s2): m.d.comb += self.c2.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (sig s1)) (case 1 (eq (sig c1) (const 1'd1)) @@ -227,7 +260,7 @@ def test_If_dangling_Else(self): with m.Else(): m.d.comb += self.c3.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (sig s1)) (case 1 @@ -298,7 +331,7 @@ def test_If_wide(self): with m.If(self.w1): m.d.comb += self.c1.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (b (sig w1))) (case 1 (eq (sig c1) (const 1'd1))) @@ -356,7 +389,7 @@ def test_Switch(self): with m.Case("1 0--"): m.d.comb += self.c2.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (sig w1) (case 0011 (eq (sig c1) (const 1'd1))) @@ -374,7 +407,7 @@ def test_Switch_empty_Case(self): with m.Case(): m.d.comb += self.c2.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (sig w1) (case 0011 (eq (sig c1) (const 1'd1))) @@ -390,7 +423,7 @@ def test_Switch_default_Default(self): with m.Default(): m.d.comb += self.c2.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (sig w1) (case 0011 (eq (sig c1) (const 1'd1))) @@ -405,7 +438,7 @@ def test_Switch_const_test(self): with m.Case(1): m.d.comb += self.c1.eq(1) m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (const 1'd1) (case 1 (eq (sig c1) (const 1'd1))) @@ -422,7 +455,7 @@ class Color(Enum): with m.Switch(se): with m.Case(Color.RED): m.d.comb += self.c1.eq(1) - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (sig se) (case 01 (eq (sig c1) (const 1'd1))) @@ -439,7 +472,7 @@ class Color(Enum, shape=1): with m.Switch(se): with m.Case(Cat(Color.RED, Color.BLUE)): m.d.comb += self.c1.eq(1) - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (sig se) (case 10 (eq (sig c1) (const 1'd1))) @@ -451,26 +484,23 @@ def test_Case_width_wrong(self): class Color(Enum): RED = 0b10101010 m = Module() + dummy = Signal() with m.Switch(self.w1): with self.assertRaisesRegex(SyntaxError, r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"): with m.Case("--"): - pass + m.d.comb += dummy.eq(0) with self.assertWarnsRegex(SyntaxWarning, r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has " r"width 4\); comparison will never be true$"): with m.Case(0b10110): - pass + m.d.comb += dummy.eq(0) with self.assertWarnsRegex(SyntaxWarning, r"^Case pattern '' \(8'10101010\) is wider than switch value " r"\(which has width 4\); comparison will never be true$"): with m.Case(Color.RED): - pass - self.assertRepr(m._statements, """ - ( - (switch (sig w1) ) - ) - """) + m.d.comb += dummy.eq(0) + self.assertEqual(m._statements, {}) def test_Case_bits_wrong(self): m = Module() @@ -549,11 +579,20 @@ def test_FSM_basic(self): with m.If(c): m.next = "FIRST" m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (sig fsm_state) (case 0 (eq (sig a) (const 1'd1)) + ) + (case 1 ) + ) + ) + """) + self.assertRepr(m._statements["sync"], """ + ( + (switch (sig fsm_state) + (case 0 (eq (sig fsm_state) (const 1'd1)) ) (case 1 @@ -594,11 +633,20 @@ def test_FSM_reset(self): with m.State("SECOND"): m.next = "FIRST" m._flush() - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (sig fsm_state) (case 0 (eq (sig a) (const 1'd0)) + ) + (case 1 ) + ) + ) + """) + self.assertRepr(m._statements["sync"], """ + ( + (switch (sig fsm_state) + (case 0 (eq (sig fsm_state) (const 1'd1)) ) (case 1 @@ -622,16 +670,10 @@ def test_FSM_ongoing(self): m._flush() self.assertEqual(m._generated["fsm"].state.reset, 1) self.maxDiff = 10000 - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (eq (sig b) (== (sig fsm_state) (const 1'd0))) (eq (sig a) (== (sig fsm_state) (const 1'd1))) - (switch (sig fsm_state) - (case 1 - ) - (case 0 - ) - ) ) """) @@ -639,9 +681,7 @@ def test_FSM_empty(self): m = Module() with m.FSM(): pass - self.assertRepr(m._statements, """ - () - """) + self.assertEqual(m._statements, {}) def test_FSM_wrong_domain(self): m = Module() @@ -713,7 +753,7 @@ def test_auto_pop_ctrl(self): with m.If(self.w1): m.d.comb += self.c1.eq(1) m.d.comb += self.c2.eq(1) - self.assertRepr(m._statements, """ + self.assertRepr(m._statements[None], """ ( (switch (cat (b (sig w1))) (case 1 (eq (sig c1) (const 1'd1))) @@ -830,7 +870,7 @@ def test_lower(self): m1.submodules.foo = m2 f1 = m1.elaborate(platform=None) - self.assertRepr(f1.statements, """ + self.assertRepr(f1.statements[None], """ ( (eq (sig c1) (sig s1)) ) @@ -841,9 +881,13 @@ def test_lower(self): self.assertEqual(len(f1.subfragments), 1) (f2, f2_name), = f1.subfragments self.assertEqual(f2_name, "foo") - self.assertRepr(f2.statements, """ + self.assertRepr(f2.statements[None], """ ( (eq (sig c2) (sig s2)) + ) + """) + self.assertRepr(f2.statements["sync"], """ + ( (eq (sig c3) (sig s3)) ) """) diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index 0b6abc954..229259b43 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -100,6 +100,7 @@ def test_iter_signals(self): def test_self_contained(self): f = Fragment() f.add_statements( + None, self.c1.eq(self.s1), self.s1.eq(self.c1) ) @@ -110,6 +111,7 @@ def test_self_contained(self): def test_infer_input(self): f = Fragment() f.add_statements( + None, self.c1.eq(self.s1) ) @@ -121,6 +123,7 @@ def test_infer_input(self): def test_request_output(self): f = Fragment() f.add_statements( + None, self.c1.eq(self.s1) ) @@ -133,10 +136,12 @@ def test_request_output(self): def test_input_in_subfragment(self): f1 = Fragment() f1.add_statements( + None, self.c1.eq(self.s1) ) f2 = Fragment() f2.add_statements( + None, self.s1.eq(0) ) f1.add_subfragment(f2) @@ -150,6 +155,7 @@ def test_input_only_in_subfragment(self): f1 = Fragment() f2 = Fragment() f2.add_statements( + None, self.c1.eq(self.s1) ) f1.add_subfragment(f2) @@ -164,10 +170,12 @@ def test_input_only_in_subfragment(self): def test_output_from_subfragment(self): f1 = Fragment() f1.add_statements( + None, self.c1.eq(0) ) f2 = Fragment() f2.add_statements( + None, self.c2.eq(1) ) f1.add_subfragment(f2) @@ -183,15 +191,18 @@ def test_output_from_subfragment(self): def test_output_from_subfragment_2(self): f1 = Fragment() f1.add_statements( + None, self.c1.eq(self.s1) ) f2 = Fragment() f2.add_statements( + None, self.c2.eq(self.s1) ) f1.add_subfragment(f2) f3 = Fragment() f3.add_statements( + None, self.s1.eq(0) ) f2.add_subfragment(f3) @@ -205,11 +216,13 @@ def test_input_output_sibling(self): f1 = Fragment() f2 = Fragment() f2.add_statements( + None, self.c1.eq(self.c2) ) f1.add_subfragment(f2) f3 = Fragment() f3.add_statements( + None, self.c2.eq(0) ) f3.add_driver(self.c2) @@ -222,12 +235,14 @@ def test_output_input_sibling(self): f1 = Fragment() f2 = Fragment() f2.add_statements( + None, self.c2.eq(0) ) f2.add_driver(self.c2) f1.add_subfragment(f2) f3 = Fragment() f3.add_statements( + None, self.c1.eq(self.c2) ) f1.add_subfragment(f3) @@ -239,6 +254,7 @@ def test_input_cd(self): sync = ClockDomain() f = Fragment() f.add_statements( + "sync", self.c1.eq(self.s1) ) f.add_domains(sync) @@ -255,6 +271,7 @@ def test_input_cd_reset_less(self): sync = ClockDomain(reset_less=True) f = Fragment() f.add_statements( + "sync", self.c1.eq(self.s1) ) f.add_domains(sync) @@ -490,7 +507,7 @@ def test_propagate(self): def test_propagate_missing(self): s1 = Signal() f1 = Fragment() - f1.add_driver(s1, "sync") + f1.add_statements("sync", s1.eq(1)) with self.assertRaisesRegex(DomainError, r"^Domain 'sync' is used but not defined$"): @@ -499,7 +516,7 @@ def test_propagate_missing(self): def test_propagate_create_missing(self): s1 = Signal() f1 = Fragment() - f1.add_driver(s1, "sync") + f1.add_statements("sync", s1.eq(1)) f2 = Fragment() f1.add_subfragment(f2) @@ -512,7 +529,7 @@ def test_propagate_create_missing(self): def test_propagate_create_missing_fragment(self): s1 = Signal() f1 = Fragment() - f1.add_driver(s1, "sync") + f1.add_statements("sync", s1.eq(1)) cd = ClockDomain("sync") f2 = Fragment() @@ -529,7 +546,7 @@ def test_propagate_create_missing_fragment(self): def test_propagate_create_missing_fragment_many_domains(self): s1 = Signal() f1 = Fragment() - f1.add_driver(s1, "sync") + f1.add_statements("sync", s1.eq(1)) cd_por = ClockDomain("por") cd_sync = ClockDomain("sync") @@ -548,7 +565,7 @@ def test_propagate_create_missing_fragment_many_domains(self): def test_propagate_create_missing_fragment_wrong(self): s1 = Signal() f1 = Fragment() - f1.add_driver(s1, "sync") + f1.add_statements("sync", s1.eq(1)) f2 = Fragment() f2.add_domains(ClockDomain("foo")) @@ -566,7 +583,7 @@ def setUp_self_sub(self): self.c2 = Signal() self.f1 = Fragment() - self.f1.add_statements(self.c1.eq(0)) + self.f1.add_statements("sync", self.c1.eq(0)) self.f1.add_driver(self.s1) self.f1.add_driver(self.c1, "sync") @@ -574,7 +591,7 @@ def setUp_self_sub(self): self.f1.add_subfragment(self.f1a, "f1a") self.f2 = Fragment() - self.f2.add_statements(self.c2.eq(1)) + self.f2.add_statements("sync", self.c2.eq(1)) self.f2.add_driver(self.s1) self.f2.add_driver(self.c2, "sync") self.f1.add_subfragment(self.f2) @@ -594,7 +611,7 @@ def test_conflict_self_sub(self): (self.f1b, "f1b"), (self.f2a, "f2a"), ]) - self.assertRepr(self.f1.statements, """ + self.assertRepr(self.f1.statements["sync"], """ ( (eq (sig c1) (const 1'd0)) (eq (sig c2) (const 1'd1)) @@ -629,12 +646,12 @@ def setUp_sub_sub(self): self.f2 = Fragment() self.f2.add_driver(self.s1) - self.f2.add_statements(self.c1.eq(0)) + self.f2.add_statements(None, self.c1.eq(0)) self.f1.add_subfragment(self.f2) self.f3 = Fragment() self.f3.add_driver(self.s1) - self.f3.add_statements(self.c2.eq(1)) + self.f3.add_statements(None, self.c2.eq(1)) self.f1.add_subfragment(self.f3) def test_conflict_sub_sub(self): @@ -642,7 +659,7 @@ def test_conflict_sub_sub(self): self.f1._resolve_hierarchy_conflicts(mode="silent") self.assertEqual(self.f1.subfragments, []) - self.assertRepr(self.f1.statements, """ + self.assertRepr(self.f1.statements[None], """ ( (eq (sig c1) (const 1'd0)) (eq (sig c2) (const 1'd1)) @@ -658,12 +675,12 @@ def setUp_self_subsub(self): self.f1.add_driver(self.s1) self.f2 = Fragment() - self.f2.add_statements(self.c1.eq(0)) + self.f2.add_statements(None, self.c1.eq(0)) self.f1.add_subfragment(self.f2) self.f3 = Fragment() self.f3.add_driver(self.s1) - self.f3.add_statements(self.c2.eq(1)) + self.f3.add_statements(None, self.c2.eq(1)) self.f2.add_subfragment(self.f3) def test_conflict_self_subsub(self): @@ -671,7 +688,7 @@ def test_conflict_self_subsub(self): self.f1._resolve_hierarchy_conflicts(mode="silent") self.assertEqual(self.f1.subfragments, []) - self.assertRepr(self.f1.statements, """ + self.assertRepr(self.f1.statements[None], """ ( (eq (sig c1) (const 1'd0)) (eq (sig c2) (const 1'd1)) @@ -848,11 +865,11 @@ def test_assign_names_to_signals(self): f.add_domains(cd_sync_norst := ClockDomain(reset_less=True)) f.add_ports((i, rst), dir="i") f.add_ports((o1, o2, o3), dir="o") - f.add_statements([o1.eq(0)]) + f.add_statements(None, [o1.eq(0)]) f.add_driver(o1, domain=None) - f.add_statements([o2.eq(i1)]) + f.add_statements("sync", [o2.eq(i1)]) f.add_driver(o2, domain="sync") - f.add_statements([o3.eq(i1)]) + f.add_statements("sync_norst", [o3.eq(i1)]) f.add_driver(o3, domain="sync_norst") names = f._assign_names_to_signals() diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 1f9658907..a8724cd28 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -26,26 +26,35 @@ def setUp(self): def test_rename_signals(self): f = Fragment() f.add_statements( + None, self.s1.eq(ClockSignal()), ResetSignal().eq(self.s2), - self.s3.eq(0), self.s4.eq(ClockSignal("other")), self.s5.eq(ResetSignal("other")), ) + f.add_statements( + "sync", + self.s3.eq(0), + ) f.add_driver(self.s1, None) f.add_driver(self.s2, None) f.add_driver(self.s3, "sync") f = DomainRenamer("pix")(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements[None], """ ( (eq (sig s1) (clk pix)) (eq (rst pix) (sig s2)) - (eq (sig s3) (const 1'd0)) (eq (sig s4) (clk other)) (eq (sig s5) (rst other)) ) """) + self.assertRepr(f.statements["pix"], """ + ( + (eq (sig s3) (const 1'd0)) + ) + """) + self.assertFalse("sync" in f.statements) self.assertEqual(f.drivers, { None: SignalSet((self.s1, self.s2)), "pix": SignalSet((self.s3,)), @@ -54,12 +63,13 @@ def test_rename_signals(self): def test_rename_multi(self): f = Fragment() f.add_statements( + None, self.s1.eq(ClockSignal()), self.s2.eq(ResetSignal("other")), ) f = DomainRenamer({"sync": "pix", "other": "pix2"})(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements[None], """ ( (eq (sig s1) (clk pix)) (eq (sig s2) (rst pix2)) @@ -86,12 +96,13 @@ def test_rename_cd_preserves_allow_reset_less(self): f = Fragment() f.add_domains(cd_pix) f.add_statements( + None, self.s1.eq(ResetSignal(allow_reset_less=True)), ) f = DomainRenamer("pix")(f) f = DomainLowerer()(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements[None], """ ( (eq (sig s1) (const 1'd0)) ) @@ -151,11 +162,12 @@ def test_lower_clk(self): f = Fragment() f.add_domains(sync) f.add_statements( + None, self.s.eq(ClockSignal("sync")) ) f = DomainLowerer()(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements[None], """ ( (eq (sig s) (sig clk)) ) @@ -166,11 +178,12 @@ def test_lower_rst(self): f = Fragment() f.add_domains(sync) f.add_statements( + None, self.s.eq(ResetSignal("sync")) ) f = DomainLowerer()(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements[None], """ ( (eq (sig s) (sig rst)) ) @@ -181,11 +194,12 @@ def test_lower_rst_reset_less(self): f = Fragment() f.add_domains(sync) f.add_statements( + None, self.s.eq(ResetSignal("sync", allow_reset_less=True)) ) f = DomainLowerer()(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements[None], """ ( (eq (sig s) (const 1'd0)) ) @@ -208,6 +222,7 @@ def test_lower_drivers(self): def test_lower_wrong_domain(self): f = Fragment() f.add_statements( + None, self.s.eq(ClockSignal("xxx")) ) @@ -220,6 +235,7 @@ def test_lower_wrong_reset_less_domain(self): f = Fragment() f.add_domains(sync) f.add_statements( + None, self.s.eq(ResetSignal("sync")) ) @@ -368,12 +384,13 @@ def setUp(self): def test_reset_default(self): f = Fragment() f.add_statements( + "sync", self.s1.eq(1) ) f.add_driver(self.s1, "sync") f = ResetInserter(self.c1)(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements["sync"], """ ( (eq (sig s1) (const 1'd1)) (switch (sig c1) @@ -384,18 +401,20 @@ def test_reset_default(self): def test_reset_cd(self): f = Fragment() - f.add_statements( - self.s1.eq(1), - self.s2.eq(0), - ) + f.add_statements("sync", self.s1.eq(1)) + f.add_statements("pix", self.s2.eq(0)) f.add_domains(ClockDomain("sync")) f.add_driver(self.s1, "sync") f.add_driver(self.s2, "pix") f = ResetInserter({"pix": self.c1})(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements["sync"], """ ( (eq (sig s1) (const 1'd1)) + ) + """) + self.assertRepr(f.statements["pix"], """ + ( (eq (sig s2) (const 1'd0)) (switch (sig c1) (case 1 (eq (sig s2) (const 1'd1))) @@ -405,13 +424,11 @@ def test_reset_cd(self): def test_reset_value(self): f = Fragment() - f.add_statements( - self.s2.eq(0) - ) + f.add_statements("sync", self.s2.eq(0)) f.add_driver(self.s2, "sync") f = ResetInserter(self.c1)(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements["sync"], """ ( (eq (sig s2) (const 1'd0)) (switch (sig c1) @@ -422,13 +439,11 @@ def test_reset_value(self): def test_reset_less(self): f = Fragment() - f.add_statements( - self.s3.eq(0) - ) + f.add_statements("sync", self.s3.eq(0)) f.add_driver(self.s3, "sync") f = ResetInserter(self.c1)(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements["sync"], """ ( (eq (sig s3) (const 1'd0)) (switch (sig c1) @@ -447,13 +462,11 @@ def setUp(self): def test_enable_default(self): f = Fragment() - f.add_statements( - self.s1.eq(1) - ) + f.add_statements("sync", self.s1.eq(1)) f.add_driver(self.s1, "sync") f = EnableInserter(self.c1)(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements["sync"], """ ( (eq (sig s1) (const 1'd1)) (switch (sig c1) @@ -464,17 +477,19 @@ def test_enable_default(self): def test_enable_cd(self): f = Fragment() - f.add_statements( - self.s1.eq(1), - self.s2.eq(0), - ) + f.add_statements("sync", self.s1.eq(1)) + f.add_statements("pix", self.s2.eq(0)) f.add_driver(self.s1, "sync") f.add_driver(self.s2, "pix") f = EnableInserter({"pix": self.c1})(f) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements["sync"], """ ( (eq (sig s1) (const 1'd1)) + ) + """) + self.assertRepr(f.statements["pix"], """ + ( (eq (sig s2) (const 1'd0)) (switch (sig c1) (case 0 (eq (sig s2) (sig s2))) @@ -484,21 +499,17 @@ def test_enable_cd(self): def test_enable_subfragment(self): f1 = Fragment() - f1.add_statements( - self.s1.eq(1) - ) + f1.add_statements("sync", self.s1.eq(1)) f1.add_driver(self.s1, "sync") f2 = Fragment() - f2.add_statements( - self.s2.eq(1) - ) + f2.add_statements("sync", self.s2.eq(1)) f2.add_driver(self.s2, "sync") f1.add_subfragment(f2) f1 = EnableInserter(self.c1)(f1) (f2, _), = f1.subfragments - self.assertRepr(f1.statements, """ + self.assertRepr(f1.statements["sync"], """ ( (eq (sig s1) (const 1'd1)) (switch (sig c1) @@ -506,7 +517,7 @@ def test_enable_subfragment(self): ) ) """) - self.assertRepr(f2.statements, """ + self.assertRepr(f2.statements["sync"], """ ( (eq (sig s2) (const 1'd1)) (switch (sig c1) @@ -542,9 +553,7 @@ def __init__(self): def elaborate(self, platform): f = Fragment() - f.add_statements( - self.s1.eq(1) - ) + f.add_statements("sync", self.s1.eq(1)) f.add_driver(self.s1, "sync") return f @@ -569,7 +578,7 @@ def test_composition(self): self.assertIs(te1, te2) f = Fragment.get(te2, None) - self.assertRepr(f.statements, """ + self.assertRepr(f.statements["sync"], """ ( (eq (sig s1) (const 1'd1)) (switch (sig c1) diff --git a/tests/test_lib_wiring.py b/tests/test_lib_wiring.py index 26c507679..78cb22bef 100644 --- a/tests/test_lib_wiring.py +++ b/tests/test_lib_wiring.py @@ -889,7 +889,7 @@ class Cycle(enum.Enum): m = Module() connect(m, src=src, snk=snk) - self.assertEqual([repr(stmt) for stmt in m._statements], [ + self.assertEqual([repr(stmt) for stmt in m._statements[None]], [ '(eq (sig snk__addr) (sig src__addr))', '(eq (sig snk__cycle) (sig src__cycle))', '(eq (sig src__r_data) (sig snk__r_data))', @@ -903,7 +903,7 @@ def test_const_in_out(self): a=Const(1)), q=NS(signature=Signature({"a": In(1)}), a=Const(1))) - self.assertEqual(m._statements, []) + self.assertEqual(m._statements, {}) def test_nested(self): m = Module() @@ -912,7 +912,7 @@ def test_nested(self): a=NS(signature=Signature({"f": Out(1)}), f=Signal(name='p__a'))), q=NS(signature=Signature({"a": In(Signature({"f": Out(1)}))}), a=NS(signature=Signature({"f": Out(1)}).flip(), f=Signal(name='q__a')))) - self.assertEqual([repr(stmt) for stmt in m._statements], [ + self.assertEqual([repr(stmt) for stmt in m._statements[None]], [ '(eq (sig q__a) (sig p__a))' ]) @@ -931,7 +931,7 @@ def test_unordered(self): g=Signal(name="q__b__g"), f=Signal(name="q__b__f")), a=Signal(name="q__a"))) - self.assertEqual([repr(stmt) for stmt in m._statements], [ + self.assertEqual([repr(stmt) for stmt in m._statements[None]], [ '(eq (sig q__a) (sig p__a))', '(eq (sig q__b__f) (sig p__b__f))', '(eq (sig q__b__g) (sig p__b__g))', @@ -942,7 +942,7 @@ def test_dimension(self): m = Module() connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',))) - self.assertEqual([repr(stmt) for stmt in m._statements], [ + self.assertEqual([repr(stmt) for stmt in m._statements[None]], [ '(eq (sig q__a__0) (sig p__a__0))', '(eq (sig q__a__1) (sig p__a__1))' ]) @@ -952,7 +952,7 @@ def test_dimension_multi(self): m = Module() connect(m, p=sig.create(path=('p',)), q=sig.flip().create(path=('q',))) - self.assertEqual([repr(stmt) for stmt in m._statements], [ + self.assertEqual([repr(stmt) for stmt in m._statements[None]], [ '(eq (sig q__a__0__0) (sig p__a__0__0))', ]) diff --git a/tests/test_sim.py b/tests/test_sim.py index ce776b067..5c24ae058 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -27,7 +27,7 @@ def assertStatement(self, stmt, inputs, output, reset=0): stmt = stmt(osig, *isigs) frag = Fragment() - frag.add_statements(stmt) + frag.add_statements(None, stmt) for signal in flatten(s._lhs_signals() for s in Statement.cast(stmt)): frag.add_driver(signal) @@ -1045,9 +1045,10 @@ def process(): def test_bug_595(self): dut = Module() + dummy = Signal() with dut.FSM(name="name with space"): with dut.State(0): - pass + dut.d.comb += dummy.eq(1) sim = Simulator(dut) with self.assertRaisesRegex(NameError, r"^Signal 'bench\.top\.name with space_state' contains a whitespace character$"):