Skip to content

Commit ea94c9c

Browse files
authored
hdl.rec: proxy operators correctly.
Commit abbebf8 used __getattr__ to proxy Value methods called on Record. However, that did not proxy operators like __add__ because Python looks up the special operator methods directly on the class and does not run __getattr__ if they are missing. Instead of using __getattr__, explicitly enumerate and wrap every Value method that should be proxied. This also ensures backwards compatibility if more methods are added to Value later. Fixes #533.
1 parent ebbdac9 commit ea94c9c

File tree

2 files changed

+138
-11
lines changed

2 files changed

+138
-11
lines changed

nmigen/hdl/rec.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,17 +143,7 @@ def concat(a, b):
143143
src_loc_at=1 + src_loc_at)
144144

145145
def __getattr__(self, name):
146-
# must check `getattr` before `self` - we need to hit Value methods before fields
147-
try:
148-
value_attr = getattr(Value, name)
149-
if callable(value_attr):
150-
@wraps(value_attr)
151-
def _wrapper(*args, **kwargs):
152-
return value_attr(self, *args, **kwargs)
153-
return _wrapper
154-
return value_attr
155-
except AttributeError:
156-
return self[name]
146+
return self[name]
157147

158148
def __getitem__(self, item):
159149
if isinstance(item, str):
@@ -257,3 +247,29 @@ def rec_name(record):
257247
stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))]
258248

259249
return stmts
250+
251+
def _valueproxy(name):
252+
value_func = getattr(Value, name)
253+
@wraps(value_func)
254+
def _wrapper(self, *args, **kwargs):
255+
return value_func(Value.cast(self), *args, **kwargs)
256+
return _wrapper
257+
258+
for name in [
259+
"__bool__",
260+
"__invert__", "__neg__",
261+
"__add__", "__radd__", "__sub__", "__rsub__",
262+
"__mul__", "__rmul__",
263+
"__mod__", "__rmod__", "__floordiv__", "__rfloordiv__",
264+
"__lshift__", "__rlshift__", "__rshift__", "__rrshift__",
265+
"__and__", "__rand__", "__xor__", "__rxor__", "__or__", "__ror__",
266+
"__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__",
267+
"__abs__", "__len__",
268+
"as_unsigned", "as_signed", "bool", "any", "all", "xor", "implies",
269+
"bit_select", "word_select", "matches",
270+
"shift_left", "shift_right", "rotate_left", "rotate_right", "eq"
271+
]:
272+
setattr(Record, name, _valueproxy(name))
273+
274+
del _valueproxy
275+
del name

tests/test_hdl_rec.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,117 @@ def test_enum_decoder(self):
211211
r1 = Record([("a", UnsignedEnum)])
212212
self.assertEqual(r1.a.decoder(UnsignedEnum.FOO), "FOO/1")
213213

214+
def test_operators(self):
215+
r1 = Record([("a", 1)])
216+
s1 = Signal()
217+
218+
# __bool__
219+
with self.assertRaisesRegex(TypeError,
220+
r"^Attempted to convert nMigen value to Python boolean$"):
221+
not r1
222+
223+
# __invert__, __neg__
224+
self.assertEqual(repr(~r1), "(~ (cat (sig r1__a)))")
225+
self.assertEqual(repr(-r1), "(- (cat (sig r1__a)))")
226+
227+
# __add__, __radd__, __sub__, __rsub__
228+
self.assertEqual(repr(r1 + 1), "(+ (cat (sig r1__a)) (const 1'd1))")
229+
self.assertEqual(repr(r1 + s1), "(+ (cat (sig r1__a)) (sig s1))")
230+
self.assertEqual(repr(1 + r1), "(+ (const 1'd1) (cat (sig r1__a)))")
231+
self.assertEqual(repr(s1 + r1), "(+ (sig s1) (cat (sig r1__a)))")
232+
self.assertEqual(repr(r1 - 1), "(- (cat (sig r1__a)) (const 1'd1))")
233+
self.assertEqual(repr(r1 - s1), "(- (cat (sig r1__a)) (sig s1))")
234+
self.assertEqual(repr(1 - r1), "(- (const 1'd1) (cat (sig r1__a)))")
235+
self.assertEqual(repr(s1 - r1), "(- (sig s1) (cat (sig r1__a)))")
236+
237+
# __mul__, __rmul__
238+
self.assertEqual(repr(r1 * 1), "(* (cat (sig r1__a)) (const 1'd1))")
239+
self.assertEqual(repr(r1 * s1), "(* (cat (sig r1__a)) (sig s1))")
240+
self.assertEqual(repr(1 * r1), "(* (const 1'd1) (cat (sig r1__a)))")
241+
self.assertEqual(repr(s1 * r1), "(* (sig s1) (cat (sig r1__a)))")
242+
243+
# __mod__, __rmod__, __floordiv__, __rfloordiv__
244+
self.assertEqual(repr(r1 % 1), "(% (cat (sig r1__a)) (const 1'd1))")
245+
self.assertEqual(repr(r1 % s1), "(% (cat (sig r1__a)) (sig s1))")
246+
self.assertEqual(repr(1 % r1), "(% (const 1'd1) (cat (sig r1__a)))")
247+
self.assertEqual(repr(s1 % r1), "(% (sig s1) (cat (sig r1__a)))")
248+
self.assertEqual(repr(r1 // 1), "(// (cat (sig r1__a)) (const 1'd1))")
249+
self.assertEqual(repr(r1 // s1), "(// (cat (sig r1__a)) (sig s1))")
250+
self.assertEqual(repr(1 // r1), "(// (const 1'd1) (cat (sig r1__a)))")
251+
self.assertEqual(repr(s1 // r1), "(// (sig s1) (cat (sig r1__a)))")
252+
253+
# __lshift__, __rlshift__, __rshift__, __rrshift__
254+
self.assertEqual(repr(r1 >> 1), "(>> (cat (sig r1__a)) (const 1'd1))")
255+
self.assertEqual(repr(r1 >> s1), "(>> (cat (sig r1__a)) (sig s1))")
256+
self.assertEqual(repr(1 >> r1), "(>> (const 1'd1) (cat (sig r1__a)))")
257+
self.assertEqual(repr(s1 >> r1), "(>> (sig s1) (cat (sig r1__a)))")
258+
self.assertEqual(repr(r1 << 1), "(<< (cat (sig r1__a)) (const 1'd1))")
259+
self.assertEqual(repr(r1 << s1), "(<< (cat (sig r1__a)) (sig s1))")
260+
self.assertEqual(repr(1 << r1), "(<< (const 1'd1) (cat (sig r1__a)))")
261+
self.assertEqual(repr(s1 << r1), "(<< (sig s1) (cat (sig r1__a)))")
262+
263+
# __and__, __rand__, __xor__, __rxor__, __or__, __ror__
264+
self.assertEqual(repr(r1 & 1), "(& (cat (sig r1__a)) (const 1'd1))")
265+
self.assertEqual(repr(r1 & s1), "(& (cat (sig r1__a)) (sig s1))")
266+
self.assertEqual(repr(1 & r1), "(& (const 1'd1) (cat (sig r1__a)))")
267+
self.assertEqual(repr(s1 & r1), "(& (sig s1) (cat (sig r1__a)))")
268+
self.assertEqual(repr(r1 ^ 1), "(^ (cat (sig r1__a)) (const 1'd1))")
269+
self.assertEqual(repr(r1 ^ s1), "(^ (cat (sig r1__a)) (sig s1))")
270+
self.assertEqual(repr(1 ^ r1), "(^ (const 1'd1) (cat (sig r1__a)))")
271+
self.assertEqual(repr(s1 ^ r1), "(^ (sig s1) (cat (sig r1__a)))")
272+
self.assertEqual(repr(r1 | 1), "(| (cat (sig r1__a)) (const 1'd1))")
273+
self.assertEqual(repr(r1 | s1), "(| (cat (sig r1__a)) (sig s1))")
274+
self.assertEqual(repr(1 | r1), "(| (const 1'd1) (cat (sig r1__a)))")
275+
self.assertEqual(repr(s1 | r1), "(| (sig s1) (cat (sig r1__a)))")
276+
277+
# __eq__, __ne__, __lt__, __le__, __gt__, __ge__
278+
self.assertEqual(repr(r1 == 1), "(== (cat (sig r1__a)) (const 1'd1))")
279+
self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))")
280+
self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))")
281+
self.assertEqual(repr(r1 != 1), "(!= (cat (sig r1__a)) (const 1'd1))")
282+
self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))")
283+
self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))")
284+
self.assertEqual(repr(r1 < 1), "(< (cat (sig r1__a)) (const 1'd1))")
285+
self.assertEqual(repr(r1 < s1), "(< (cat (sig r1__a)) (sig s1))")
286+
self.assertEqual(repr(s1 < r1), "(< (sig s1) (cat (sig r1__a)))")
287+
self.assertEqual(repr(r1 <= 1), "(<= (cat (sig r1__a)) (const 1'd1))")
288+
self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))")
289+
self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))")
290+
self.assertEqual(repr(r1 > 1), "(> (cat (sig r1__a)) (const 1'd1))")
291+
self.assertEqual(repr(r1 > s1), "(> (cat (sig r1__a)) (sig s1))")
292+
self.assertEqual(repr(s1 > r1), "(> (sig s1) (cat (sig r1__a)))")
293+
self.assertEqual(repr(r1 >= 1), "(>= (cat (sig r1__a)) (const 1'd1))")
294+
self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))")
295+
self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))")
296+
297+
# __abs__, __len__
298+
self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))")
299+
self.assertEqual(len(r1), 1)
300+
301+
# as_unsigned, as_signed, bool, any, all, xor, implies
302+
self.assertEqual(repr(r1.as_unsigned()), "(u (cat (sig r1__a)))")
303+
self.assertEqual(repr(r1.as_signed()), "(s (cat (sig r1__a)))")
304+
self.assertEqual(repr(r1.bool()), "(b (cat (sig r1__a)))")
305+
self.assertEqual(repr(r1.any()), "(r| (cat (sig r1__a)))")
306+
self.assertEqual(repr(r1.all()), "(r& (cat (sig r1__a)))")
307+
self.assertEqual(repr(r1.xor()), "(r^ (cat (sig r1__a)))")
308+
self.assertEqual(repr(r1.implies(1)), "(| (~ (cat (sig r1__a))) (const 1'd1))")
309+
self.assertEqual(repr(r1.implies(s1)), "(| (~ (cat (sig r1__a))) (sig s1))")
310+
311+
# bit_select, word_select, matches,
312+
self.assertEqual(repr(r1.bit_select(0, 1)), "(slice (cat (sig r1__a)) 0:1)")
313+
self.assertEqual(repr(r1.word_select(0, 1)), "(slice (cat (sig r1__a)) 0:1)")
314+
self.assertEqual(repr(r1.matches("1")),
315+
"(== (& (cat (sig r1__a)) (const 1'd1)) (const 1'd1))")
316+
317+
# shift_left, shift_right, rotate_left, rotate_right, eq
318+
self.assertEqual(repr(r1.shift_left(1)), "(cat (const 1'd0) (cat (sig r1__a)))")
319+
self.assertEqual(repr(r1.shift_right(1)), "(slice (cat (sig r1__a)) 1:1)")
320+
self.assertEqual(repr(r1.rotate_left(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))")
321+
self.assertEqual(repr(r1.rotate_right(1)), "(cat (slice (cat (sig r1__a)) 0:1) (slice (cat (sig r1__a)) 0:0))")
322+
self.assertEqual(repr(r1.eq(1)), "(eq (cat (sig r1__a)) (const 1'd1))")
323+
self.assertEqual(repr(r1.eq(s1)), "(eq (cat (sig r1__a)) (sig s1))")
324+
214325

215326
class ConnectTestCase(FHDLTestCase):
216327
def setUp_flat(self):

0 commit comments

Comments
 (0)