Skip to content

Commit 606ebcd

Browse files
wanda-phiwhitequark
authored andcommitted
hdl._ast: Implement Mux in terms of SwitchValue.
Fixes #1075.
1 parent 466536e commit 606ebcd

File tree

6 files changed

+56
-64
lines changed

6 files changed

+56
-64
lines changed

amaranth/hdl/_ast.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,10 +1681,6 @@ def shape(self):
16811681
if self.operator == ">>":
16821682
assert not b_shape.signed
16831683
return Shape(a_shape.width, a_shape.signed)
1684-
elif len(op_shapes) == 3:
1685-
if self.operator == "m":
1686-
s_shape, a_shape, b_shape = op_shapes
1687-
return Shape._unify((a_shape, b_shape))
16881684
raise NotImplementedError # :nocov:
16891685

16901686
def _lhs_signals(self):
@@ -1715,7 +1711,7 @@ def Mux(sel, val1, val0):
17151711
Value, out
17161712
Output ``Value``. If ``sel`` is asserted, the Mux returns ``val1``, else ``val0``.
17171713
"""
1718-
return Operator("m", [sel, val1, val0])
1714+
return SwitchValue(sel, ((0, val0), (None, val1)), src_loc_at=1)
17191715

17201716

17211717
@final

amaranth/hdl/_ir.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -874,18 +874,6 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool
874874
signed = False
875875
else:
876876
assert False # :nocov:
877-
elif len(value.operands) == 3:
878-
assert value.operator == 'm'
879-
operand_s, signed_s = self.emit_rhs(module_idx, value.operands[0])
880-
operand_a, signed_a = self.emit_rhs(module_idx, value.operands[1])
881-
operand_b, signed_b = self.emit_rhs(module_idx, value.operands[2])
882-
if len(operand_s) != 1:
883-
operand_s = self.emit_operator(module_idx, 'b', operand_s,
884-
src_loc=value.src_loc)
885-
operand_a, operand_b, signed = \
886-
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
887-
result = self.emit_operator(module_idx, 'm', operand_s, operand_a, operand_b,
888-
src_loc=value.src_loc)
889877
else:
890878
assert False # :nocov:
891879
elif isinstance(value, _ast.Slice):
@@ -901,39 +889,51 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool
901889
signed = False
902890
elif isinstance(value, _ast.SwitchValue):
903891
test, _signed = self.emit_rhs(module_idx, value.test)
904-
conds = []
905-
elems = []
906-
for patterns, elem, in value.cases:
907-
if patterns is not None:
908-
if not patterns:
909-
# Hack: empty pattern set cannot be supported by RTLIL.
910-
continue
911-
for pattern in patterns:
912-
assert len(pattern) == len(test)
913-
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
914-
src_loc=value.src_loc)
915-
net, = self.netlist.add_value_cell(1, cell)
916-
conds.append(net)
917-
else:
918-
conds.append(_nir.Net.from_const(1))
919-
elems.append(self.emit_rhs(module_idx, elem))
920-
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
921-
inputs=_nir.Value(conds),
922-
src_loc=value.src_loc)
923-
conds = self.netlist.add_value_cell(len(conds), cell)
924-
shape = _ast.Shape._unify(
925-
_ast.Shape(len(value), signed)
926-
for value, signed in elems
927-
)
928-
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
929-
assignments = [
930-
_nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc)
931-
for subcond, elem in zip(conds, elems)
932-
]
933-
cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width),
934-
assignments=assignments, src_loc=value.src_loc)
935-
result = self.netlist.add_value_cell(shape.width, cell)
936-
signed = shape.signed
892+
if (len(value.cases) == 2 and
893+
value.cases[0][0] == ("0" * len(test),) and
894+
value.cases[1][0] is None):
895+
operand_a, signed_a = self.emit_rhs(module_idx, value.cases[1][1])
896+
operand_b, signed_b = self.emit_rhs(module_idx, value.cases[0][1])
897+
if len(test) != 1:
898+
test = self.emit_operator(module_idx, 'b', test, src_loc=value.src_loc)
899+
operand_a, operand_b, signed = \
900+
self.unify_shapes_bitwise(operand_a, signed_a, operand_b, signed_b)
901+
result = self.emit_operator(module_idx, 'm', test, operand_a, operand_b,
902+
src_loc=value.src_loc)
903+
else:
904+
conds = []
905+
elems = []
906+
for patterns, elem, in value.cases:
907+
if patterns is not None:
908+
if not patterns:
909+
# Hack: empty pattern set cannot be supported by RTLIL.
910+
continue
911+
for pattern in patterns:
912+
assert len(pattern) == len(test)
913+
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
914+
src_loc=value.src_loc)
915+
net, = self.netlist.add_value_cell(1, cell)
916+
conds.append(net)
917+
else:
918+
conds.append(_nir.Net.from_const(1))
919+
elems.append(self.emit_rhs(module_idx, elem))
920+
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
921+
inputs=_nir.Value(conds),
922+
src_loc=value.src_loc)
923+
conds = self.netlist.add_value_cell(len(conds), cell)
924+
shape = _ast.Shape._unify(
925+
_ast.Shape(len(value), signed)
926+
for value, signed in elems
927+
)
928+
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
929+
assignments = [
930+
_nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc)
931+
for subcond, elem in zip(conds, elems)
932+
]
933+
cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width),
934+
assignments=assignments, src_loc=value.src_loc)
935+
result = self.netlist.add_value_cell(shape.width, cell)
936+
signed = shape.signed
937937
elif isinstance(value, _ast.Concat):
938938
nets = []
939939
for val in value.parts:

amaranth/hdl/_nir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,8 @@ class Operator(Cell):
559559
560560
The ternary operators are:
561561
562-
- 'm': like AST, first input needs to have width of 1, second and third operand need to have the same
563-
width as output
562+
- 'm': multiplexer, first input needs to have width of 1, second and third operand need to have
563+
the same width as output; implements arg0 ? arg1 : arg2
564564
565565
Attributes
566566
----------

amaranth/sim/_pyrtl.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,6 @@ def sign(value):
244244
return f"({sign(lhs)} > {sign(rhs)})"
245245
if value.operator == ">=":
246246
return f"({sign(lhs)} >= {sign(rhs)})"
247-
elif len(value.operands) == 3:
248-
if value.operator == "m":
249-
sel, val1, val0 = value.operands
250-
return f"({sign(val1)} if {mask(sel)} else {sign(val0)})"
251247
raise NotImplementedError(f"Operator '{value.operator}' not implemented") # :nocov:
252248

253249
def on_Slice(self, value):
@@ -274,7 +270,7 @@ def on_SwitchValue(self, value):
274270
gen_test = self.emitter.def_var("test", f"{(1 << len(value.test)) - 1:#x} & {self(value.test)}")
275271
gen_value = self.emitter.def_var("rhs_switch", "0")
276272
def case_handler(patterns, elem):
277-
self.emitter.append(f"{gen_value} = {self(elem)}")
273+
self.emitter.append(f"{gen_value} = {self.sign(elem)}")
278274
self._emit_switch(gen_test, value.cases, case_handler)
279275
return gen_value
280276

tests/test_hdl_ast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def test_ne(self):
746746
def test_mux(self):
747747
s = Const(0)
748748
v1 = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6)))
749-
self.assertEqual(repr(v1), "(m (const 1'd0) (const 4'd0) (const 6'd0))")
749+
self.assertEqual(repr(v1), "(switch-value (const 1'd0) (case 0 (const 6'd0)) (default (const 4'd0)))")
750750
self.assertEqual(v1.shape(), unsigned(6))
751751
v2 = Mux(s, Const(0, signed(4)), Const(0, signed(6)))
752752
self.assertEqual(v2.shape(), signed(6))
@@ -758,11 +758,11 @@ def test_mux(self):
758758
def test_mux_wide(self):
759759
s = Const(0b100)
760760
v = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6)))
761-
self.assertEqual(repr(v), "(m (const 3'd4) (const 4'd0) (const 6'd0))")
761+
self.assertEqual(repr(v), "(switch-value (const 3'd4) (case 000 (const 6'd0)) (default (const 4'd0)))")
762762

763763
def test_mux_bool(self):
764764
v = Mux(True, Const(0), Const(0))
765-
self.assertEqual(repr(v), "(m (const 1'd1) (const 1'd0) (const 1'd0))")
765+
self.assertEqual(repr(v), "(switch-value (const 1'd1) (case 0 (const 1'd0)) (default (const 1'd0)))")
766766

767767
def test_any(self):
768768
v = Const(0b101).any()
@@ -842,7 +842,7 @@ def test_abs(self):
842842
""")
843843
s = Signal(signed(4))
844844
self.assertRepr(abs(s), """
845-
(slice (m (>= (sig s) (const 1'd0)) (sig s) (- (sig s))) 0:4)
845+
(slice (switch-value (>= (sig s) (const 1'd0)) (case 0 (- (sig s))) (default (sig s))) 0:4)
846846
""")
847847
self.assertEqual(abs(s).shape(), unsigned(4))
848848

tests/test_hdl_xfrm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,10 @@ def test_enable_write_port(self):
406406
mem.write_port(granularity=2)
407407
f = EnableInserter(self.c1)(mem).elaborate(platform=None)
408408
self.assertRepr(f._write_ports[0]._en, """
409-
(m
409+
(switch-value
410410
(sig c1)
411-
(sig mem_w_en)
412-
(const 4'd0)
411+
(case 0 (const 4'd0))
412+
(default (sig mem_w_en))
413413
)
414414
""")
415415

0 commit comments

Comments
 (0)