Skip to content

Commit 2cf9bbf

Browse files
wanda-phiwhitequark
authored andcommitted
hdl._ast: add SwitchValue, reimplement ArrayProxy with it.
1 parent 2eb62a8 commit 2cf9bbf

File tree

7 files changed

+382
-140
lines changed

7 files changed

+382
-140
lines changed

amaranth/hdl/_ast.py

Lines changed: 130 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
__all__ = [
1919
"SyntaxError", "SyntaxWarning",
2020
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
21-
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
21+
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue",
2222
"Array", "ArrayProxy",
2323
"Signal", "ClockSignal", "ResetSignal",
2424
"ValueCastable", "ValueLike",
@@ -1892,6 +1892,60 @@ def __repr__(self):
18921892
return "(cat {})".format(" ".join(map(repr, self.parts)))
18931893

18941894

1895+
@final
1896+
class SwitchValue(Value):
1897+
def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
1898+
if src_loc is None:
1899+
super().__init__(src_loc_at=src_loc_at)
1900+
else:
1901+
self.src_loc = src_loc
1902+
self._test = Value.cast(test)
1903+
new_cases = []
1904+
for patterns, value in cases:
1905+
if patterns is not None:
1906+
if not isinstance(patterns, tuple):
1907+
patterns = (patterns,)
1908+
new_patterns = ()
1909+
key_mask = (1 << len(self.test)) - 1
1910+
for key in _normalize_patterns(patterns, self._test.shape()):
1911+
if isinstance(key, int):
1912+
key = to_binary(key & key_mask, len(self.test))
1913+
new_patterns = (*new_patterns, key)
1914+
else:
1915+
new_patterns = None
1916+
new_cases.append((new_patterns, Value.cast(value)))
1917+
self._cases = tuple(new_cases)
1918+
1919+
@property
1920+
def test(self):
1921+
return self._test
1922+
1923+
@property
1924+
def cases(self):
1925+
return self._cases
1926+
1927+
def shape(self):
1928+
return Shape._unify(value.shape() for _patterns, value in self._cases)
1929+
1930+
def _lhs_signals(self):
1931+
return union((value._lhs_signals() for _patterns, value in self.cases), start=SignalSet())
1932+
1933+
def _rhs_signals(self):
1934+
signals = union((value._rhs_signals() for _patterns, value in self.cases), start=SignalSet())
1935+
return self.test._rhs_signals() | signals
1936+
1937+
def __repr__(self):
1938+
def case_repr(patterns, value):
1939+
if patterns is None:
1940+
return f"(default {value!r})"
1941+
elif len(patterns) == 1:
1942+
return f"(case {patterns[0]} {value!r})"
1943+
else:
1944+
return "(case ({}) {!r})".format(" ".join(patterns), value)
1945+
case_reprs = (case_repr(patterns, value) for patterns, value in self.cases)
1946+
return "(switch-value {!r} {})".format(self.test, " ".join(case_reprs))
1947+
1948+
18951949
class _SignalMeta(ABCMeta):
18961950
def __call__(cls, shape=None, src_loc_at=0, **kwargs):
18971951
signal = super().__call__(shape, **kwargs, src_loc_at=src_loc_at + 1)
@@ -2356,10 +2410,17 @@ def __repr__(self):
23562410
", ".join(map(repr, self._inner)))
23572411

23582412

2413+
def _proxy_value(name):
2414+
@functools.wraps(getattr(Value, name))
2415+
def inner(self, *args, **kwargs):
2416+
return getattr(Value.cast(self), name)(*args, **kwargs)
2417+
return inner
2418+
2419+
23592420
@final
2360-
class ArrayProxy(Value):
2421+
class ArrayProxy(ValueCastable):
23612422
def __init__(self, elems, index, *, src_loc_at=0):
2362-
super().__init__(src_loc_at=1 + src_loc_at)
2423+
self.src_loc = tracer.get_src_loc(1 + src_loc_at)
23632424
self._elems = elems
23642425
self._index = Value.cast(index)
23652426

@@ -2385,19 +2446,73 @@ def shape(self):
23852446
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
23862447
return Shape._unify(elem.shape() for elem in self._iter_as_values())
23872448

2388-
def _lhs_signals(self):
2389-
signals = union((elem._lhs_signals() for elem in self._iter_as_values()),
2390-
start=SignalSet())
2391-
return signals
2449+
def as_value(self):
2450+
return SwitchValue(
2451+
self._index,
2452+
(
2453+
(index, value)
2454+
for index, value in enumerate(self._elems)
2455+
if index in range(1 << len(self._index))
2456+
),
2457+
src_loc=self.src_loc,
2458+
)
23922459

2393-
def _rhs_signals(self):
2394-
signals = union((elem._rhs_signals() for elem in self._iter_as_values()),
2395-
start=SignalSet())
2396-
return self.index._rhs_signals() | signals
2460+
def eq(self, value, *, src_loc_at=0):
2461+
return self.as_value().eq(value, src_loc_at=1 + src_loc_at)
23972462

23982463
def __repr__(self):
23992464
return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index)
24002465

2466+
as_signed = _proxy_value("as_signed")
2467+
as_unsigned = _proxy_value("as_unsigned")
2468+
__len__ = _proxy_value("__len__")
2469+
__bool__ = _proxy_value("__bool__")
2470+
bool = _proxy_value("bool")
2471+
__pos__ = _proxy_value("__pos__")
2472+
__neg__ = _proxy_value("__neg__")
2473+
__add__ = _proxy_value("__add__")
2474+
__radd__ = _proxy_value("__radd__")
2475+
__sub__ = _proxy_value("__sub__")
2476+
__rsub__ = _proxy_value("__rsub__")
2477+
__mul__ = _proxy_value("__mul__")
2478+
__rmul__ = _proxy_value("__rmul__")
2479+
__floordiv__ = _proxy_value("__floordiv__")
2480+
__rfloordiv__ = _proxy_value("__rfloordiv__")
2481+
__mod__ = _proxy_value("__mod__")
2482+
__rmod__ = _proxy_value("__rmod__")
2483+
__eq__ = _proxy_value("__eq__")
2484+
__ne__ = _proxy_value("__ne__")
2485+
__lt__ = _proxy_value("__lt__")
2486+
__le__ = _proxy_value("__le__")
2487+
__gt__ = _proxy_value("__gt__")
2488+
__ge__ = _proxy_value("__ge__")
2489+
__abs__ = _proxy_value("__abs__")
2490+
__invert__ = _proxy_value("__invert__")
2491+
__and__ = _proxy_value("__and__")
2492+
__rand__ = _proxy_value("__rand__")
2493+
__or__ = _proxy_value("__or__")
2494+
__ror__ = _proxy_value("__ror__")
2495+
__xor__ = _proxy_value("__xor__")
2496+
__rxor__ = _proxy_value("__rxor__")
2497+
any = _proxy_value("any")
2498+
all = _proxy_value("all")
2499+
xor = _proxy_value("xor")
2500+
implies = _proxy_value("implies")
2501+
__lshift__ = _proxy_value("__lshift__")
2502+
__rlshift__ = _proxy_value("__rlshift__")
2503+
__rshift__ = _proxy_value("__rshift__")
2504+
__rrshift__ = _proxy_value("__rrshift__")
2505+
shift_left = _proxy_value("shift_left")
2506+
shift_right = _proxy_value("shift_right")
2507+
rotate_left = _proxy_value("rotate_left")
2508+
rotate_right = _proxy_value("rotate_right")
2509+
__contains__ = _proxy_value("__contains__")
2510+
bit_select = _proxy_value("bit_select")
2511+
word_select = _proxy_value("word_select")
2512+
replicate = _proxy_value("replicate")
2513+
matches = _proxy_value("matches")
2514+
__format__ = _proxy_value("__format__")
2515+
24012516

24022517
@final
24032518
class Initial(Value):
@@ -2772,7 +2887,7 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
27722887
self.src_loc = src_loc
27732888

27742889
self._test = Value.cast(test)
2775-
self._cases = []
2890+
new_cases = []
27762891
for patterns, stmts, case_src_loc in cases:
27772892
if patterns is not None:
27782893
# Map: key -> (key,); (key...) -> (key...)
@@ -2787,10 +2902,8 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
27872902
new_patterns = (*new_patterns, key)
27882903
else:
27892904
new_patterns = None
2790-
if not isinstance(stmts, Iterable):
2791-
stmts = [stmts]
2792-
self._cases.append((new_patterns, Statement.cast(stmts), case_src_loc))
2793-
self._cases = tuple(self._cases)
2905+
new_cases.append((new_patterns, Statement.cast(stmts), case_src_loc))
2906+
self._cases = tuple(new_cases)
27942907

27952908
@property
27962909
def test(self):
@@ -2816,7 +2929,7 @@ def case_repr(patterns, stmts):
28162929
return f"(case {patterns[0]} {stmts_repr})"
28172930
else:
28182931
return "(case ({}) {})".format(" ".join(patterns), stmts_repr)
2819-
case_reprs = [case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases]
2932+
case_reprs = (case_repr(patterns, stmts) for patterns, stmts, _src_loc in self.cases)
28202933
return "(switch {!r} {})".format(self.test, " ".join(case_reprs))
28212934

28222935

amaranth/hdl/_ir.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -885,30 +885,39 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool
885885
stride=value.stride, offset=offset, src_loc=value.src_loc)
886886
result = self.netlist.add_value_cell(value.width, cell)
887887
signed = False
888-
elif isinstance(value, _ast.ArrayProxy):
889-
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
888+
elif isinstance(value, _ast.SwitchValue):
889+
test, _signed = self.emit_rhs(module_idx, value.test)
890+
conds = []
891+
elems = []
892+
for patterns, elem, in value.cases:
893+
if patterns is not None:
894+
if not patterns:
895+
# Hack: empty pattern set cannot be supported by RTLIL.
896+
continue
897+
for pattern in patterns:
898+
assert len(pattern) == len(test)
899+
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
900+
src_loc=value.src_loc)
901+
net, = self.netlist.add_value_cell(1, cell)
902+
conds.append(net)
903+
else:
904+
conds.append(_nir.Net.from_const(1))
905+
elems.append(self.emit_rhs(module_idx, elem))
906+
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
907+
inputs=_nir.Value(conds),
908+
src_loc=value.src_loc)
909+
conds = self.netlist.add_value_cell(len(conds), cell)
890910
shape = _ast.Shape._unify(
891911
_ast.Shape(len(value), signed)
892912
for value, signed in elems
893913
)
894914
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
895-
index, _signed = self.emit_rhs(module_idx, value.index)
896-
conds = []
897-
for case_index in range(len(elems)):
898-
cell = _nir.Matches(module_idx, value=index,
899-
patterns=(to_binary(case_index, len(index)),),
900-
src_loc=value.src_loc)
901-
subcond, = self.netlist.add_value_cell(1, cell)
902-
conds.append(subcond)
903-
conds = _nir.Value(conds)
904-
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1), inputs=conds, src_loc=value.src_loc)
905-
conds = self.netlist.add_value_cell(len(conds), cell)
906915
assignments = [
907-
_nir.Assignment(cond=cond, start=0, value=elem, src_loc=value.src_loc)
908-
for cond, elem in zip(conds, elems)
916+
_nir.Assignment(cond=subcond, start=0, value=elem, src_loc=value.src_loc)
917+
for subcond, elem in zip(conds, elems)
909918
]
910-
cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments,
911-
src_loc=value.src_loc)
919+
cell = _nir.AssignmentList(module_idx, default=_nir.Value.from_const(0, shape.width),
920+
assignments=assignments, src_loc=value.src_loc)
912921
result = self.netlist.add_value_cell(shape.width, cell)
913922
signed = shape.signed
914923
elif isinstance(value, _ast.Concat):
@@ -1017,19 +1026,29 @@ def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.V
10171026
else:
10181027
subrhs = rhs
10191028
self.emit_assign(module_idx, cd, lhs.value, start, subrhs, subcond, src_loc=src_loc)
1020-
elif isinstance(lhs, _ast.ArrayProxy):
1021-
index, _signed = self.emit_rhs(module_idx, lhs.index)
1029+
elif isinstance(lhs, _ast.SwitchValue):
1030+
test, _signed = self.emit_rhs(module_idx, lhs.test)
10221031
conds = []
1023-
for case_index in range(len(lhs.elems)):
1024-
cell = _nir.Matches(module_idx, value=index,
1025-
patterns=(to_binary(case_index, len(index)),),
1026-
src_loc=lhs.src_loc)
1027-
subcond, = self.netlist.add_value_cell(1, cell)
1028-
conds.append(subcond)
1032+
elems = []
1033+
for patterns, elem in lhs.cases:
1034+
if patterns is not None:
1035+
if not patterns:
1036+
# Hack: empty pattern set cannot be supported by RTLIL.
1037+
continue
1038+
for pattern in patterns:
1039+
assert len(pattern) == len(test)
1040+
cell = _nir.Matches(module_idx, value=test, patterns=patterns,
1041+
src_loc=lhs.src_loc)
1042+
net, = self.netlist.add_value_cell(1, cell)
1043+
conds.append(net)
1044+
else:
1045+
conds.append(_nir.Net.from_const(1))
1046+
elems.append(elem)
10291047
conds = _nir.Value(conds)
1030-
cell = _nir.PriorityMatch(module_idx, en=cond, inputs=conds, src_loc=lhs.src_loc)
1048+
cell = _nir.PriorityMatch(module_idx, en=_nir.Net.from_const(1),
1049+
inputs=conds, src_loc=lhs.src_loc)
10311050
conds = self.netlist.add_value_cell(len(conds), cell)
1032-
for subcond, val in zip(conds, lhs.elems):
1051+
for subcond, val in zip(conds, elems):
10331052
self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc)
10341053
elif isinstance(lhs, _ast.Operator):
10351054
assert lhs.operator in ('u', 's')

amaranth/hdl/_xfrm.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def on_Concat(self, value):
5858
pass # :nocov:
5959

6060
@abstractmethod
61-
def on_ArrayProxy(self, value):
61+
def on_SwitchValue(self, value):
6262
pass # :nocov:
6363

6464
@abstractmethod
@@ -90,8 +90,8 @@ def on_value(self, value):
9090
new_value = self.on_Part(value)
9191
elif type(value) is Concat:
9292
new_value = self.on_Concat(value)
93-
elif type(value) is ArrayProxy:
94-
new_value = self.on_ArrayProxy(value)
93+
elif type(value) is SwitchValue:
94+
new_value = self.on_SwitchValue(value)
9595
elif type(value) is Initial:
9696
new_value = self.on_Initial(value)
9797
else:
@@ -133,9 +133,8 @@ def on_Part(self, value):
133133
def on_Concat(self, value):
134134
return Concat(self.on_value(o) for o in value.parts)
135135

136-
def on_ArrayProxy(self, value):
137-
return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
138-
self.on_value(value.index))
136+
def on_SwitchValue(self, value):
137+
return SwitchValue(self.on_value(value.test), [(patterns, self.on_value(val)) for patterns, val in value.cases])
139138

140139
def on_Initial(self, value):
141140
return value
@@ -399,10 +398,10 @@ def on_Concat(self, value):
399398
for o in value.parts:
400399
self.on_value(o)
401400

402-
def on_ArrayProxy(self, value):
403-
for elem in value._iter_as_values():
404-
self.on_value(elem)
405-
self.on_value(value.index)
401+
def on_SwitchValue(self, value):
402+
self.on_value(value.test)
403+
for patterns, val in value.cases:
404+
self.on_value(val)
406405

407406
def on_Initial(self, value):
408407
pass

0 commit comments

Comments
 (0)