@@ -159,6 +159,28 @@ def __eq__(self, other):
159
159
return (isinstance (other , Shape ) and
160
160
self .width == other .width and self .signed == other .signed )
161
161
162
+ @staticmethod
163
+ def _unify (shapes ):
164
+ """Returns the minimal shape that contains all shapes from the list.
165
+
166
+ If no shapes passed in, returns unsigned(0).
167
+ """
168
+ unsigned_width = signed_width = 0
169
+ has_signed = False
170
+ for shape in shapes :
171
+ assert isinstance (shape , Shape )
172
+ if shape .signed :
173
+ has_signed = True
174
+ signed_width = max (signed_width , shape .width )
175
+ else :
176
+ unsigned_width = max (unsigned_width , shape .width )
177
+ # If all shapes unsigned, simply take max.
178
+ if not has_signed :
179
+ return unsigned (unsigned_width )
180
+ # Otherwise, result is signed. All unsigned inputs, if any,
181
+ # need to be converted to signed by adding a zero bit.
182
+ return signed (max (signed_width , unsigned_width + 1 ))
183
+
162
184
163
185
def unsigned (width ):
164
186
"""Returns :py:`Shape(width, signed=False)`."""
@@ -1524,20 +1546,6 @@ def operands(self):
1524
1546
return self ._operands
1525
1547
1526
1548
def shape (self ):
1527
- def _bitwise_binary_shape (a_shape , b_shape ):
1528
- if not a_shape .signed and not b_shape .signed :
1529
- # both operands unsigned
1530
- return unsigned (max (a_shape .width , b_shape .width ))
1531
- elif a_shape .signed and b_shape .signed :
1532
- # both operands signed
1533
- return signed (max (a_shape .width , b_shape .width ))
1534
- elif not a_shape .signed and b_shape .signed :
1535
- # first operand unsigned (add sign bit), second operand signed
1536
- return signed (max (a_shape .width + 1 , b_shape .width ))
1537
- else :
1538
- # first signed, second operand unsigned (add sign bit)
1539
- return signed (max (a_shape .width , b_shape .width + 1 ))
1540
-
1541
1549
op_shapes = list (map (lambda x : x .shape (), self .operands ))
1542
1550
if len (op_shapes ) == 1 :
1543
1551
a_shape , = op_shapes
@@ -1554,10 +1562,10 @@ def _bitwise_binary_shape(a_shape, b_shape):
1554
1562
elif len (op_shapes ) == 2 :
1555
1563
a_shape , b_shape = op_shapes
1556
1564
if self .operator == "+" :
1557
- o_shape = _bitwise_binary_shape ( * op_shapes )
1565
+ o_shape = Shape . _unify ( op_shapes )
1558
1566
return Shape (o_shape .width + 1 , o_shape .signed )
1559
1567
if self .operator == "-" :
1560
- o_shape = _bitwise_binary_shape ( * op_shapes )
1568
+ o_shape = Shape . _unify ( op_shapes )
1561
1569
return Shape (o_shape .width + 1 , True )
1562
1570
if self .operator == "*" :
1563
1571
return Shape (a_shape .width + b_shape .width , a_shape .signed or b_shape .signed )
@@ -1568,7 +1576,7 @@ def _bitwise_binary_shape(a_shape, b_shape):
1568
1576
if self .operator in ("<" , "<=" , "==" , "!=" , ">" , ">=" ):
1569
1577
return Shape (1 , False )
1570
1578
if self .operator in ("&" , "|" , "^" ):
1571
- return _bitwise_binary_shape ( * op_shapes )
1579
+ return Shape . _unify ( op_shapes )
1572
1580
if self .operator == "<<" :
1573
1581
assert not b_shape .signed
1574
1582
return Shape (a_shape .width + 2 ** b_shape .width - 1 , a_shape .signed )
@@ -1578,7 +1586,7 @@ def _bitwise_binary_shape(a_shape, b_shape):
1578
1586
elif len (op_shapes ) == 3 :
1579
1587
if self .operator == "m" :
1580
1588
s_shape , a_shape , b_shape = op_shapes
1581
- return _bitwise_binary_shape ( a_shape , b_shape )
1589
+ return Shape . _unify (( a_shape , b_shape ) )
1582
1590
raise NotImplementedError # :nocov:
1583
1591
1584
1592
def _lhs_signals (self ):
@@ -2254,27 +2262,9 @@ def _iter_as_values(self):
2254
2262
return (Value .cast (elem ) for elem in self .elems )
2255
2263
2256
2264
def shape (self ):
2257
- unsigned_width = signed_width = 0
2258
- has_unsigned = has_signed = False
2259
- for elem_shape in (elem .shape () for elem in self ._iter_as_values ()):
2260
- if elem_shape .signed :
2261
- has_signed = True
2262
- signed_width = max (signed_width , elem_shape .width )
2263
- else :
2264
- has_unsigned = True
2265
- unsigned_width = max (unsigned_width , elem_shape .width )
2266
2265
# The shape of the proxy must be such that it preserves the mathematical value of the array
2267
2266
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
2268
- # To ensure this holds, if the array contains both signed and unsigned values, make sure
2269
- # that every unsigned value is zero-extended by at least one bit.
2270
- if has_signed and has_unsigned and unsigned_width >= signed_width :
2271
- # Array contains both signed and unsigned values, and at least one of the unsigned
2272
- # values won't be zero-extended otherwise.
2273
- return signed (unsigned_width + 1 )
2274
- else :
2275
- # Array contains values of the same signedness, or else all of the unsigned values
2276
- # are zero-extended.
2277
- return Shape (max (unsigned_width , signed_width ), has_signed )
2267
+ return Shape ._unify (elem .shape () for elem in self ._iter_as_values ())
2278
2268
2279
2269
def _lhs_signals (self ):
2280
2270
signals = union ((elem ._lhs_signals () for elem in self ._iter_as_values ()),
0 commit comments