Skip to content

Commit 2eb62a8

Browse files
wanda-phiwhitequark
authored andcommitted
hdl._ast: change Switch to operate on list of cases.
1 parent cd6cbd7 commit 2eb62a8

File tree

8 files changed

+95
-90
lines changed

8 files changed

+95
-90
lines changed

amaranth/back/rtlil.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,8 @@ def emit_assignments(case, cond):
638638
assert isinstance(matches_cell, _nir.Matches)
639639
assert test == matches_cell.value
640640
patterns = matches_cell.patterns
641+
# RTLIL cannot support empty pattern sets.
642+
assert patterns
641643
with switch.case(*patterns) as subcase:
642644
emit_assignments(subcase, subcond)
643645
emitted_switch = True

amaranth/hdl/_ast.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2763,38 +2763,34 @@ def resolve(self):
27632763

27642764
@final
27652765
class Switch(Statement):
2766-
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={}):
2766+
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
27672767
if src_loc is None:
27682768
super().__init__(src_loc_at=src_loc_at)
27692769
else:
27702770
# Switch is a bit special in terms of location tracking because it is usually created
27712771
# long after the control has left the statement that directly caused its creation.
27722772
self.src_loc = src_loc
2773-
# Switch is also a bit special in that its parts also have location information. It can't
2774-
# be automatically traced, so whatever constructs a Switch may optionally provide it.
2775-
self.case_src_locs = {}
27762773

27772774
self._test = Value.cast(test)
2778-
self._cases = OrderedDict()
2779-
for orig_keys, stmts in cases.items():
2780-
# Map: None -> (); key -> (key,); (key...) -> (key...)
2781-
keys = orig_keys
2782-
if keys is None:
2783-
keys = ()
2784-
if not isinstance(keys, tuple):
2785-
keys = (keys,)
2786-
# Map: 2 -> "0010"; "0010" -> "0010"
2787-
new_keys = ()
2788-
key_mask = (1 << len(self.test)) - 1
2789-
for key in _normalize_patterns(keys, self._test.shape()):
2790-
if isinstance(key, int):
2791-
key = to_binary(key & key_mask, len(self.test))
2792-
new_keys = (*new_keys, key)
2775+
self._cases = []
2776+
for patterns, stmts, case_src_loc in cases:
2777+
if patterns is not None:
2778+
# Map: key -> (key,); (key...) -> (key...)
2779+
if not isinstance(patterns, tuple):
2780+
patterns = (patterns,)
2781+
# Map: 2 -> "0010"; "0010" -> "0010"
2782+
new_patterns = ()
2783+
key_mask = (1 << len(self.test)) - 1
2784+
for key in _normalize_patterns(patterns, self._test.shape()):
2785+
if isinstance(key, int):
2786+
key = to_binary(key & key_mask, len(self.test))
2787+
new_patterns = (*new_patterns, key)
2788+
else:
2789+
new_patterns = None
27932790
if not isinstance(stmts, Iterable):
27942791
stmts = [stmts]
2795-
self._cases[new_keys] = Statement.cast(stmts)
2796-
if orig_keys in case_src_locs:
2797-
self.case_src_locs[new_keys] = case_src_locs[orig_keys]
2792+
self._cases.append((new_patterns, Statement.cast(stmts), case_src_loc))
2793+
self._cases = tuple(self._cases)
27982794

27992795
@property
28002796
def test(self):
@@ -2805,22 +2801,22 @@ def cases(self):
28052801
return self._cases
28062802

28072803
def _lhs_signals(self):
2808-
return union((s._lhs_signals() for s in self.cases.values()), start=SignalSet())
2804+
return union((stmts._lhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet())
28092805

28102806
def _rhs_signals(self):
2811-
signals = union((s._rhs_signals() for s in self.cases.values()), start=SignalSet())
2807+
signals = union((stmts._rhs_signals() for _patterns, stmts, _src_loc in self.cases), start=SignalSet())
28122808
return self.test._rhs_signals() | signals
28132809

28142810
def __repr__(self):
2815-
def case_repr(keys, stmts):
2811+
def case_repr(patterns, stmts):
28162812
stmts_repr = " ".join(map(repr, stmts))
2817-
if keys == ():
2813+
if patterns is None:
28182814
return f"(default {stmts_repr})"
2819-
elif len(keys) == 1:
2820-
return f"(case {keys[0]} {stmts_repr})"
2815+
elif len(patterns) == 1:
2816+
return f"(case {patterns[0]} {stmts_repr})"
28212817
else:
2822-
return "(case ({}) {})".format(" ".join(keys), stmts_repr)
2823-
case_reprs = [case_repr(keys, stmts) for keys, stmts in self.cases.items()]
2818+
return "(case ({}) {})".format(" ".join(patterns), stmts_repr)
2819+
case_reprs = [case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases]
28242820
return "(switch {!r} {})".format(self.test, " ".join(case_reprs))
28252821

28262822

amaranth/hdl/_dsl.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,11 @@ def resolve_statement(stmt):
169169
elif isinstance(stmt, Switch):
170170
return Switch(
171171
test=stmt.test,
172-
cases=OrderedDict(
173-
(patterns, resolve_statements(stmts))
174-
for patterns, stmts in stmt.cases.items()
175-
),
172+
cases=[
173+
(patterns, resolve_statements(stmts), src_loc)
174+
for patterns, stmts, src_loc in stmt.cases
175+
],
176176
src_loc=stmt.src_loc,
177-
case_src_locs=stmt.case_src_locs,
178177
)
179178
elif isinstance(stmt, (Assign, Property, Print)):
180179
return stmt
@@ -318,9 +317,9 @@ def Switch(self, test):
318317
self._check_context("Switch", context=None)
319318
switch_data = self._set_ctrl("Switch", {
320319
"test": Value.cast(test),
321-
"cases": OrderedDict(),
320+
"cases": [],
322321
"src_loc": tracer.get_src_loc(src_loc_at=1),
323-
"case_src_locs": {},
322+
"got_default": False,
324323
})
325324
try:
326325
self._ctrl_context = "Switch"
@@ -336,7 +335,7 @@ def Case(self, *patterns):
336335
self._check_context("Case", context="Switch")
337336
src_loc = tracer.get_src_loc(src_loc_at=1)
338337
switch_data = self._get_ctrl("Switch")
339-
if () in switch_data["cases"]:
338+
if switch_data["got_default"]:
340339
warnings.warn("A case defined after the default case will never be active",
341340
SyntaxWarning, stacklevel=3)
342341
new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
@@ -345,12 +344,7 @@ def Case(self, *patterns):
345344
self._ctrl_context = None
346345
yield
347346
self._flush_ctrl()
348-
# If none of the provided cases can possibly be true, omit this branch completely.
349-
# Likewise, omit this branch if another branch with this exact set of patterns already
350-
# exists (since otherwise we'd overwrite the previous branch's slot in the dict).
351-
if new_patterns and new_patterns not in switch_data["cases"]:
352-
switch_data["cases"][new_patterns] = self._statements
353-
switch_data["case_src_locs"][new_patterns] = src_loc
347+
switch_data["cases"].append((new_patterns, self._statements, src_loc))
354348
finally:
355349
self._ctrl_context = "Switch"
356350
self._statements = _outer_case
@@ -360,17 +354,16 @@ def Default(self):
360354
self._check_context("Default", context="Switch")
361355
src_loc = tracer.get_src_loc(src_loc_at=1)
362356
switch_data = self._get_ctrl("Switch")
363-
if () in switch_data["cases"]:
357+
if switch_data["got_default"]:
364358
warnings.warn("A case defined after the default case will never be active",
365359
SyntaxWarning, stacklevel=3)
366360
try:
367361
_outer_case, self._statements = self._statements, {}
368362
self._ctrl_context = None
369363
yield
370364
self._flush_ctrl()
371-
if () not in switch_data["cases"]:
372-
switch_data["cases"][()] = self._statements
373-
switch_data["case_src_locs"][()] = src_loc
365+
switch_data["cases"].append((None, self._statements, src_loc))
366+
switch_data["got_default"] = True
374367
finally:
375368
self._ctrl_context = "Switch"
376369
self._statements = _outer_case
@@ -471,8 +464,8 @@ def _pop_ctrl(self):
471464
domains[domain] = None
472465

473466
for domain in domains:
474-
tests, cases = [], OrderedDict()
475-
for if_test, if_case in zip(if_tests + [None], if_bodies):
467+
tests, cases = [], []
468+
for if_test, if_case, if_src_loc in zip(if_tests + [None], if_bodies, if_src_locs):
476469
if if_test is not None:
477470
if len(if_test) != 1:
478471
if_test = if_test.bool()
@@ -482,27 +475,26 @@ def _pop_ctrl(self):
482475
match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-")
483476
else:
484477
match = None
485-
cases[match] = if_case.get(domain, [])
478+
cases.append((match, if_case.get(domain, []), if_src_loc))
486479

487480
self._statements.setdefault(domain, []).append(Switch(Cat(tests), cases,
488-
src_loc=src_loc, case_src_locs=dict(zip(cases, if_src_locs))))
481+
src_loc=src_loc))
489482

490483
if name == "Switch":
491484
switch_test, switch_cases = data["test"], data["cases"]
492-
switch_case_src_locs = data["case_src_locs"]
493485

494486
domains = {}
495-
for stmts in switch_cases.values():
487+
for _patterns, stmts, _src_loc in switch_cases:
496488
for domain in stmts:
497489
domains[domain] = None
498490

499491
for domain in domains:
500-
domain_cases = OrderedDict()
501-
for patterns, stmts in switch_cases.items():
502-
domain_cases[patterns] = stmts.get(domain, [])
492+
domain_cases = []
493+
for patterns, stmts, case_src_loc in switch_cases:
494+
domain_cases.append((patterns, stmts.get(domain, []), case_src_loc))
503495

504496
self._statements.setdefault(domain, []).append(Switch(switch_test, domain_cases,
505-
src_loc=src_loc, case_src_locs=switch_case_src_locs))
497+
src_loc=src_loc))
506498

507499
if name == "FSM":
508500
fsm_name, fsm_init, fsm_encoding, fsm_decoding, fsm_states, fsm_ongoing = \
@@ -536,9 +528,11 @@ def _pop_ctrl(self):
536528
domain_states[state] = stmts.get(domain, [])
537529

538530
self._statements.setdefault(domain, []).append(Switch(fsm_signal,
539-
OrderedDict((fsm_encoding[name], stmts) for name, stmts in domain_states.items()),
540-
src_loc=src_loc, case_src_locs={fsm_encoding[name]: fsm_state_src_locs[name]
541-
for name in fsm_states}))
531+
[
532+
(fsm_encoding[name], stmts, fsm_state_src_locs[name])
533+
for name, stmts in domain_states.items()
534+
],
535+
src_loc=src_loc))
542536

543537
def _add_statement(self, assigns, domain, depth):
544538
while len(self._ctrl_stack) > self.domain._depth:

amaranth/hdl/_ir.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,20 +1110,25 @@ def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str,
11101110
elif isinstance(stmt, _ast.Switch):
11111111
test, _signed = self.emit_rhs(module_idx, stmt.test)
11121112
conds = []
1113-
for patterns in stmt.cases:
1114-
if patterns:
1113+
case_stmts = []
1114+
for patterns, stmts, case_src_loc in stmt.cases:
1115+
if patterns is not None:
1116+
if not patterns:
1117+
# Hack: empty pattern set cannot be supported by RTLIL.
1118+
continue
11151119
for pattern in patterns:
11161120
assert len(pattern) == len(test)
11171121
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
1118-
src_loc=stmt.case_src_locs.get(patterns))
1122+
src_loc=case_src_loc)
11191123
net, = self.netlist.add_value_cell(1, cell)
11201124
conds.append(net)
11211125
else:
11221126
conds.append(_nir.Net.from_const(1))
1127+
case_stmts.append(stmts)
11231128
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=_nir.Value(conds),
11241129
src_loc=stmt.src_loc)
11251130
conds = self.netlist.add_value_cell(len(conds), cell)
1126-
for subcond, substmts in zip(conds, stmt.cases.values()):
1131+
for subcond, substmts in zip(conds, case_stmts):
11271132
for substmt in substmts:
11281133
self.emit_stmt(module_idx, fragment, domain, substmt, subcond)
11291134
else:

amaranth/hdl/_xfrm.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,6 @@ def on_statement(self, stmt):
183183
new_stmt = self.on_unknown_statement(stmt)
184184
if isinstance(new_stmt, Statement) and self.replace_statement_src_loc(stmt, new_stmt):
185185
new_stmt.src_loc = stmt.src_loc
186-
if isinstance(new_stmt, Switch) and isinstance(stmt, Switch):
187-
new_stmt.case_src_locs = stmt.case_src_locs
188186
if isinstance(new_stmt, (Print, Property)):
189187
new_stmt._MustUse__used = True
190188
return new_stmt
@@ -221,7 +219,7 @@ def on_Property(self, stmt):
221219
return Property(stmt.kind, self.on_value(stmt.test), message)
222220

223221
def on_Switch(self, stmt):
224-
cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
222+
cases = [(k, self.on_statement(s), l) for k, s, l in stmt.cases]
225223
return Switch(self.on_value(stmt.test), cases)
226224

227225
def on_statements(self, stmts):
@@ -429,7 +427,7 @@ def on_Property(self, stmt):
429427

430428
def on_Switch(self, stmt):
431429
self.on_value(stmt.test)
432-
for stmts in stmt.cases.values():
430+
for _patterns, stmts, _src_loc in stmt.cases:
433431
self.on_statement(stmts)
434432

435433
def on_statements(self, stmts):
@@ -624,15 +622,15 @@ def __call__(self, value, *, src_loc_at=0):
624622
class ResetInserter(_ControlInserter):
625623
def _insert_control(self, fragment, domain, signals):
626624
stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less]
627-
fragment.add_statements(domain, Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))
625+
fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc))
628626

629627

630628
class EnableInserter(_ControlInserter):
631629
def _insert_control(self, fragment, domain, signals):
632630
if domain in fragment.statements:
633631
fragment.statements[domain] = _StatementList([Switch(
634632
self.controls[domain],
635-
{1: fragment.statements[domain]},
633+
[(1, fragment.statements[domain], None)],
636634
src_loc=self.src_loc,
637635
)])
638636

amaranth/sim/_pyrtl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,10 +396,12 @@ def on_Assign(self, stmt):
396396
def on_Switch(self, stmt):
397397
gen_test_value = self.rhs(stmt.test) # check for oversized value before generating mask
398398
gen_test = self.emitter.def_var("test", f"{(1 << len(stmt.test)) - 1:#x} & {gen_test_value}")
399-
for index, (patterns, stmts) in enumerate(stmt.cases.items()):
399+
for index, (patterns, stmts, _src_loc) in enumerate(stmt.cases):
400400
gen_checks = []
401-
if not patterns:
401+
if patterns is None:
402402
gen_checks.append(f"True")
403+
elif not patterns:
404+
gen_checks.append(f"False")
403405
else:
404406
for pattern in patterns:
405407
if "-" in pattern:

tests/test_hdl_ast.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,38 +1687,38 @@ def test_construct_wrong(self):
16871687

16881688
class SwitchTestCase(FHDLTestCase):
16891689
def test_default_case(self):
1690-
s = Switch(Const(0), {None: []})
1691-
self.assertEqual(s.cases, {(): []})
1690+
s = Switch(Const(0), [(None, [], None)])
1691+
self.assertEqual(s.cases, ((None, [], None),))
16921692

16931693
def test_int_case(self):
1694-
s = Switch(Const(0, 8), {10: []})
1695-
self.assertEqual(s.cases, {("00001010",): []})
1694+
s = Switch(Const(0, 8), [(10, [], None)])
1695+
self.assertEqual(s.cases, ((("00001010",), [], None),))
16961696

16971697
def test_int_neg_case(self):
1698-
s = Switch(Const(0, signed(8)), {-10: []})
1699-
self.assertEqual(s.cases, {("11110110",): []})
1698+
s = Switch(Const(0, signed(8)), [(-10, [], None)])
1699+
self.assertEqual(s.cases, ((("11110110",), [], None),))
17001700

17011701
def test_int_zero_width(self):
1702-
s = Switch(Const(0, 0), {0: []})
1703-
self.assertEqual(s.cases, {("",): []})
1702+
s = Switch(Const(0, 0), [(0, [], None)])
1703+
self.assertEqual(s.cases, ((("",), [], None),))
17041704

17051705
def test_int_zero_width_enum(self):
17061706
class ZeroEnum(Enum):
17071707
A = 0
1708-
s = Switch(Const(0, 0), {ZeroEnum.A: []})
1709-
self.assertEqual(s.cases, {("",): []})
1708+
s = Switch(Const(0, 0), [(ZeroEnum.A, [], None)])
1709+
self.assertEqual(s.cases, ((("",), [], None),))
17101710

17111711
def test_enum_case(self):
1712-
s = Switch(Const(0, UnsignedEnum), {UnsignedEnum.FOO: []})
1713-
self.assertEqual(s.cases, {("01",): []})
1712+
s = Switch(Const(0, UnsignedEnum), [(UnsignedEnum.FOO, [], None)])
1713+
self.assertEqual(s.cases, ((("01",), [], None),))
17141714

17151715
def test_str_case(self):
1716-
s = Switch(Const(0, 8), {"0000 11\t01": []})
1717-
self.assertEqual(s.cases, {("00001101",): []})
1716+
s = Switch(Const(0, 8), [("0000 11\t01", [], None)])
1717+
self.assertEqual(s.cases, ((("00001101",), [], None),))
17181718

17191719
def test_two_cases(self):
1720-
s = Switch(Const(0, 8), {("00001111", 123): []})
1721-
self.assertEqual(s.cases, {("00001111", "01111011"): []})
1720+
s = Switch(Const(0, 8), [(("00001111", 123), [], None)])
1721+
self.assertEqual(s.cases, ((("00001111", "01111011"), [], None),))
17221722

17231723

17241724
class IOValueTestCase(FHDLTestCase):

tests/test_hdl_dsl.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def test_Switch_empty_Case(self):
411411
(
412412
(switch (sig w1)
413413
(case 0011 (eq (sig c1) (const 1'd1)))
414+
(case () (eq (sig c2) (const 1'd1)))
414415
)
415416
)
416417
""")
@@ -500,7 +501,14 @@ class Color(Enum):
500501
r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
501502
with m.Case(Color.RED):
502503
m.d.comb += dummy.eq(0)
503-
self.assertEqual(m._statements, {})
504+
self.assertRepr(m._statements["comb"], """
505+
(
506+
(switch (sig w1)
507+
(case () (eq (sig dummy) (const 1'd0)))
508+
(case () (eq (sig dummy) (const 1'd0)))
509+
)
510+
)
511+
""")
504512

505513
def test_Switch_zero_width(self):
506514
m = Module()

0 commit comments

Comments
 (0)