Skip to content

Commit 8796013

Browse files
zypwhitequark
authored andcommitted
ast: allow overriding Value operators.
1 parent 1c3227d commit 8796013

File tree

3 files changed

+59
-22
lines changed

3 files changed

+59
-22
lines changed

amaranth/hdl/ast.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,25 @@ def signed(width):
156156
return Shape(width, signed=True)
157157

158158

159+
def _overridable_by_reflected(method_name):
160+
"""Allow overriding the decorated method.
161+
162+
Allows :class:`ValueCastable` to override the decorated method by implementing
163+
a reflected method named ``method_name``. Intended for operators, but
164+
also usable for other methods that have a reflected counterpart.
165+
"""
166+
def decorator(f):
167+
@functools.wraps(f)
168+
def wrapper(self, other):
169+
if isinstance(other, ValueCastable) and hasattr(other, method_name):
170+
res = getattr(other, method_name)(self)
171+
if res is not NotImplemented:
172+
return res
173+
return f(self, other)
174+
return wrapper
175+
return decorator
176+
177+
159178
class Value(metaclass=ABCMeta):
160179
@staticmethod
161180
def cast(obj):
@@ -195,26 +214,31 @@ def __invert__(self):
195214
def __neg__(self):
196215
return Operator("-", [self])
197216

217+
@_overridable_by_reflected("__radd__")
198218
def __add__(self, other):
199-
return Operator("+", [self, other])
219+
return Operator("+", [self, other], src_loc_at=1)
200220
def __radd__(self, other):
201221
return Operator("+", [other, self])
222+
@_overridable_by_reflected("__rsub__")
202223
def __sub__(self, other):
203-
return Operator("-", [self, other])
224+
return Operator("-", [self, other], src_loc_at=1)
204225
def __rsub__(self, other):
205226
return Operator("-", [other, self])
206227

228+
@_overridable_by_reflected("__rmul__")
207229
def __mul__(self, other):
208-
return Operator("*", [self, other])
230+
return Operator("*", [self, other], src_loc_at=1)
209231
def __rmul__(self, other):
210232
return Operator("*", [other, self])
211233

234+
@_overridable_by_reflected("__rmod__")
212235
def __mod__(self, other):
213-
return Operator("%", [self, other])
236+
return Operator("%", [self, other], src_loc_at=1)
214237
def __rmod__(self, other):
215238
return Operator("%", [other, self])
239+
@_overridable_by_reflected("__rfloordiv__")
216240
def __floordiv__(self, other):
217-
return Operator("//", [self, other])
241+
return Operator("//", [self, other], src_loc_at=1)
218242
def __rfloordiv__(self, other):
219243
return Operator("//", [other, self])
220244

@@ -224,46 +248,57 @@ def __check_shamt(self):
224248
# by a signed value to make sure the shift amount can always be interpreted as
225249
# an unsigned value.
226250
raise TypeError("Shift amount must be unsigned")
251+
@_overridable_by_reflected("__rlshift__")
227252
def __lshift__(self, other):
228253
other = Value.cast(other)
229254
other.__check_shamt()
230-
return Operator("<<", [self, other])
255+
return Operator("<<", [self, other], src_loc_at=1)
231256
def __rlshift__(self, other):
232257
self.__check_shamt()
233258
return Operator("<<", [other, self])
259+
@_overridable_by_reflected("__rrshift__")
234260
def __rshift__(self, other):
235261
other = Value.cast(other)
236262
other.__check_shamt()
237-
return Operator(">>", [self, other])
263+
return Operator(">>", [self, other], src_loc_at=1)
238264
def __rrshift__(self, other):
239265
self.__check_shamt()
240266
return Operator(">>", [other, self])
241267

268+
@_overridable_by_reflected("__rand__")
242269
def __and__(self, other):
243-
return Operator("&", [self, other])
270+
return Operator("&", [self, other], src_loc_at=1)
244271
def __rand__(self, other):
245272
return Operator("&", [other, self])
273+
@_overridable_by_reflected("__rxor__")
246274
def __xor__(self, other):
247-
return Operator("^", [self, other])
275+
return Operator("^", [self, other], src_loc_at=1)
248276
def __rxor__(self, other):
249277
return Operator("^", [other, self])
278+
@_overridable_by_reflected("__ror__")
250279
def __or__(self, other):
251-
return Operator("|", [self, other])
280+
return Operator("|", [self, other], src_loc_at=1)
252281
def __ror__(self, other):
253282
return Operator("|", [other, self])
254283

284+
@_overridable_by_reflected("__eq__")
255285
def __eq__(self, other):
256-
return Operator("==", [self, other])
286+
return Operator("==", [self, other], src_loc_at=1)
287+
@_overridable_by_reflected("__ne__")
257288
def __ne__(self, other):
258-
return Operator("!=", [self, other])
289+
return Operator("!=", [self, other], src_loc_at=1)
290+
@_overridable_by_reflected("__gt__")
259291
def __lt__(self, other):
260-
return Operator("<", [self, other])
292+
return Operator("<", [self, other], src_loc_at=1)
293+
@_overridable_by_reflected("__ge__")
261294
def __le__(self, other):
262-
return Operator("<=", [self, other])
295+
return Operator("<=", [self, other], src_loc_at=1)
296+
@_overridable_by_reflected("__lt__")
263297
def __gt__(self, other):
264-
return Operator(">", [self, other])
298+
return Operator(">", [self, other], src_loc_at=1)
299+
@_overridable_by_reflected("__le__")
265300
def __ge__(self, other):
266-
return Operator(">=", [self, other])
301+
return Operator(">=", [self, other], src_loc_at=1)
267302

268303
def __abs__(self):
269304
if self.shape().signed:

docs/changes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Implemented RFCs
5555
.. _RFC 19: https://amaranth-lang.org/rfcs/0019-remove-scheduler.html
5656
.. _RFC 20: https://amaranth-lang.org/rfcs/0020-deprecate-non-fwft-fifos.html
5757
.. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html
58+
.. _RFC 28: https://amaranth-lang.org/rfcs/0028-override-value-operators.html
5859

5960

6061
* `RFC 1`_: Aggregate data structure library
@@ -71,6 +72,7 @@ Implemented RFCs
7172
* `RFC 15`_: Lifting shape-castable objects
7273
* `RFC 20`_: Deprecate non-FWFT FIFOs
7374
* `RFC 22`_: Define ``ValueCastable.shape()``
75+
* `RFC 28`_: Allow overriding ``Value`` operators
7476

7577

7678
Language changes

tests/test_hdl_rec.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,22 +285,22 @@ def test_operators(self):
285285
# __eq__, __ne__, __lt__, __le__, __gt__, __ge__
286286
self.assertEqual(repr(r1 == 1), "(== (cat (sig r1__a)) (const 1'd1))")
287287
self.assertEqual(repr(r1 == s1), "(== (cat (sig r1__a)) (sig s1))")
288-
self.assertEqual(repr(s1 == r1), "(== (sig s1) (cat (sig r1__a)))")
288+
self.assertEqual(repr(s1 == r1), "(== (cat (sig r1__a)) (sig s1))")
289289
self.assertEqual(repr(r1 != 1), "(!= (cat (sig r1__a)) (const 1'd1))")
290290
self.assertEqual(repr(r1 != s1), "(!= (cat (sig r1__a)) (sig s1))")
291-
self.assertEqual(repr(s1 != r1), "(!= (sig s1) (cat (sig r1__a)))")
291+
self.assertEqual(repr(s1 != r1), "(!= (cat (sig r1__a)) (sig s1))")
292292
self.assertEqual(repr(r1 < 1), "(< (cat (sig r1__a)) (const 1'd1))")
293293
self.assertEqual(repr(r1 < s1), "(< (cat (sig r1__a)) (sig s1))")
294-
self.assertEqual(repr(s1 < r1), "(< (sig s1) (cat (sig r1__a)))")
294+
self.assertEqual(repr(s1 < r1), "(> (cat (sig r1__a)) (sig s1))")
295295
self.assertEqual(repr(r1 <= 1), "(<= (cat (sig r1__a)) (const 1'd1))")
296296
self.assertEqual(repr(r1 <= s1), "(<= (cat (sig r1__a)) (sig s1))")
297-
self.assertEqual(repr(s1 <= r1), "(<= (sig s1) (cat (sig r1__a)))")
297+
self.assertEqual(repr(s1 <= r1), "(>= (cat (sig r1__a)) (sig s1))")
298298
self.assertEqual(repr(r1 > 1), "(> (cat (sig r1__a)) (const 1'd1))")
299299
self.assertEqual(repr(r1 > s1), "(> (cat (sig r1__a)) (sig s1))")
300-
self.assertEqual(repr(s1 > r1), "(> (sig s1) (cat (sig r1__a)))")
300+
self.assertEqual(repr(s1 > r1), "(< (cat (sig r1__a)) (sig s1))")
301301
self.assertEqual(repr(r1 >= 1), "(>= (cat (sig r1__a)) (const 1'd1))")
302302
self.assertEqual(repr(r1 >= s1), "(>= (cat (sig r1__a)) (sig s1))")
303-
self.assertEqual(repr(s1 >= r1), "(>= (sig s1) (cat (sig r1__a)))")
303+
self.assertEqual(repr(s1 >= r1), "(<= (cat (sig r1__a)) (sig s1))")
304304

305305
# __abs__, __len__
306306
self.assertEqual(repr(abs(r1)), "(cat (sig r1__a))")

0 commit comments

Comments
 (0)