Skip to content

Commit 161b014

Browse files
wanda-phiwhitequark
authored andcommitted
hdl._ast, hdl._ir: Deduplicate shape unification logic. NFC
1 parent 31a12c0 commit 161b014

File tree

2 files changed

+41
-61
lines changed

2 files changed

+41
-61
lines changed

amaranth/hdl/_ast.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,28 @@ def __eq__(self, other):
159159
return (isinstance(other, Shape) and
160160
self.width == other.width and self.signed == other.signed)
161161

162+
@staticmethod
163+
def _unify(shapes):
164+
"""Returns the minimal shape that contains all shapes from the list.
165+
166+
If no shapes passed in, returns unsigned(0).
167+
"""
168+
unsigned_width = signed_width = 0
169+
has_signed = False
170+
for shape in shapes:
171+
assert isinstance(shape, Shape)
172+
if shape.signed:
173+
has_signed = True
174+
signed_width = max(signed_width, shape.width)
175+
else:
176+
unsigned_width = max(unsigned_width, shape.width)
177+
# If all shapes unsigned, simply take max.
178+
if not has_signed:
179+
return unsigned(unsigned_width)
180+
# Otherwise, result is signed. All unsigned inputs, if any,
181+
# need to be converted to signed by adding a zero bit.
182+
return signed(max(signed_width, unsigned_width + 1))
183+
162184

163185
def unsigned(width):
164186
"""Returns :py:`Shape(width, signed=False)`."""
@@ -1524,20 +1546,6 @@ def operands(self):
15241546
return self._operands
15251547

15261548
def shape(self):
1527-
def _bitwise_binary_shape(a_shape, b_shape):
1528-
if not a_shape.signed and not b_shape.signed:
1529-
# both operands unsigned
1530-
return unsigned(max(a_shape.width, b_shape.width))
1531-
elif a_shape.signed and b_shape.signed:
1532-
# both operands signed
1533-
return signed(max(a_shape.width, b_shape.width))
1534-
elif not a_shape.signed and b_shape.signed:
1535-
# first operand unsigned (add sign bit), second operand signed
1536-
return signed(max(a_shape.width + 1, b_shape.width))
1537-
else:
1538-
# first signed, second operand unsigned (add sign bit)
1539-
return signed(max(a_shape.width, b_shape.width + 1))
1540-
15411549
op_shapes = list(map(lambda x: x.shape(), self.operands))
15421550
if len(op_shapes) == 1:
15431551
a_shape, = op_shapes
@@ -1554,10 +1562,10 @@ def _bitwise_binary_shape(a_shape, b_shape):
15541562
elif len(op_shapes) == 2:
15551563
a_shape, b_shape = op_shapes
15561564
if self.operator == "+":
1557-
o_shape = _bitwise_binary_shape(*op_shapes)
1565+
o_shape = Shape._unify(op_shapes)
15581566
return Shape(o_shape.width + 1, o_shape.signed)
15591567
if self.operator == "-":
1560-
o_shape = _bitwise_binary_shape(*op_shapes)
1568+
o_shape = Shape._unify(op_shapes)
15611569
return Shape(o_shape.width + 1, True)
15621570
if self.operator == "*":
15631571
return Shape(a_shape.width + b_shape.width, a_shape.signed or b_shape.signed)
@@ -1568,7 +1576,7 @@ def _bitwise_binary_shape(a_shape, b_shape):
15681576
if self.operator in ("<", "<=", "==", "!=", ">", ">="):
15691577
return Shape(1, False)
15701578
if self.operator in ("&", "|", "^"):
1571-
return _bitwise_binary_shape(*op_shapes)
1579+
return Shape._unify(op_shapes)
15721580
if self.operator == "<<":
15731581
assert not b_shape.signed
15741582
return Shape(a_shape.width + 2 ** b_shape.width - 1, a_shape.signed)
@@ -1578,7 +1586,7 @@ def _bitwise_binary_shape(a_shape, b_shape):
15781586
elif len(op_shapes) == 3:
15791587
if self.operator == "m":
15801588
s_shape, a_shape, b_shape = op_shapes
1581-
return _bitwise_binary_shape(a_shape, b_shape)
1589+
return Shape._unify((a_shape, b_shape))
15821590
raise NotImplementedError # :nocov:
15831591

15841592
def _lhs_signals(self):
@@ -2254,27 +2262,9 @@ def _iter_as_values(self):
22542262
return (Value.cast(elem) for elem in self.elems)
22552263

22562264
def shape(self):
2257-
unsigned_width = signed_width = 0
2258-
has_unsigned = has_signed = False
2259-
for elem_shape in (elem.shape() for elem in self._iter_as_values()):
2260-
if elem_shape.signed:
2261-
has_signed = True
2262-
signed_width = max(signed_width, elem_shape.width)
2263-
else:
2264-
has_unsigned = True
2265-
unsigned_width = max(unsigned_width, elem_shape.width)
22662265
# The shape of the proxy must be such that it preserves the mathematical value of the array
22672266
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
2268-
# To ensure this holds, if the array contains both signed and unsigned values, make sure
2269-
# that every unsigned value is zero-extended by at least one bit.
2270-
if has_signed and has_unsigned and unsigned_width >= signed_width:
2271-
# Array contains both signed and unsigned values, and at least one of the unsigned
2272-
# values won't be zero-extended otherwise.
2273-
return signed(unsigned_width + 1)
2274-
else:
2275-
# Array contains values of the same signedness, or else all of the unsigned values
2276-
# are zero-extended.
2277-
return Shape(max(unsigned_width, signed_width), has_signed)
2267+
return Shape._unify(elem.shape() for elem in self._iter_as_values())
22782268

22792269
def _lhs_signals(self):
22802270
signals = union((elem._lhs_signals() for elem in self._iter_as_values()),

amaranth/hdl/_ir.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -677,16 +677,13 @@ def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src
677677

678678
def unify_shapes_bitwise(self,
679679
operand_a: _nir.Value, signed_a: bool, operand_b: _nir.Value, signed_b: bool):
680-
if signed_a == signed_b:
681-
width = max(len(operand_a), len(operand_b))
682-
elif signed_a:
683-
width = max(len(operand_a), len(operand_b) + 1)
684-
else: # signed_b
685-
width = max(len(operand_a) + 1, len(operand_b))
686-
operand_a = self.extend(operand_a, signed_a, width)
687-
operand_b = self.extend(operand_b, signed_b, width)
688-
signed = signed_a or signed_b
689-
return (operand_a, operand_b, signed)
680+
shape = _ast.Shape._unify((
681+
_ast.Shape(len(operand_a), signed_a),
682+
_ast.Shape(len(operand_b), signed_b),
683+
))
684+
operand_a = self.extend(operand_a, signed_a, shape.width)
685+
operand_b = self.extend(operand_b, signed_b, shape.width)
686+
return (operand_a, operand_b, shape.signed)
690687

691688
def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool]:
692689
"""Emits a RHS value, returns a tuple of (value, is_signed)"""
@@ -825,19 +822,11 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool
825822
signed = False
826823
elif isinstance(value, _ast.ArrayProxy):
827824
elems = [self.emit_rhs(module_idx, elem) for elem in value.elems]
828-
width = 0
829-
signed = False
830-
for elem, elem_signed in elems:
831-
if elem_signed:
832-
if not signed:
833-
width += 1
834-
signed = True
835-
width = max(width, len(elem))
836-
elif signed:
837-
width = max(width, len(elem) + 1)
838-
else:
839-
width = max(width, len(elem))
840-
elems = tuple(self.extend(elem, elem_signed, width) for elem, elem_signed in elems)
825+
shape = _ast.Shape._unify(
826+
_ast.Shape(len(value), signed)
827+
for value, signed in elems
828+
)
829+
elems = tuple(self.extend(elem, elem_signed, shape.width) for elem, elem_signed in elems)
841830
index, _signed = self.emit_rhs(module_idx, value.index)
842831
conds = []
843832
for case_index in range(len(elems)):
@@ -855,7 +844,8 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> Tuple[_nir.Value, bool
855844
]
856845
cell = _nir.AssignmentList(module_idx, default=elems[0], assignments=assignments,
857846
src_loc=value.src_loc)
858-
result = self.netlist.add_value_cell(width, cell)
847+
result = self.netlist.add_value_cell(shape.width, cell)
848+
signed = shape.signed
859849
elif isinstance(value, _ast.Cat):
860850
nets = []
861851
for val in value.parts:

0 commit comments

Comments
 (0)