18
18
__all__ = [
19
19
"SyntaxError" , "SyntaxWarning" ,
20
20
"Shape" , "signed" , "unsigned" , "ShapeCastable" , "ShapeLike" ,
21
- "Value" , "Const" , "C" , "AnyConst" , "AnySeq" , "Operator" , "Mux" , "Part" , "Slice" , "Cat" , "Concat" ,
21
+ "Value" , "Const" , "C" , "AnyConst" , "AnySeq" , "Operator" , "Mux" , "Part" , "Slice" , "Cat" , "Concat" , "SwitchValue" ,
22
22
"Array" , "ArrayProxy" ,
23
23
"Signal" , "ClockSignal" , "ResetSignal" ,
24
24
"ValueCastable" , "ValueLike" ,
@@ -1892,6 +1892,60 @@ def __repr__(self):
1892
1892
return "(cat {})" .format (" " .join (map (repr , self .parts )))
1893
1893
1894
1894
1895
+ @final
1896
+ class SwitchValue (Value ):
1897
+ def __init__ (self , test , cases , * , src_loc = None , src_loc_at = 0 ):
1898
+ if src_loc is None :
1899
+ super ().__init__ (src_loc_at = src_loc_at )
1900
+ else :
1901
+ self .src_loc = src_loc
1902
+ self ._test = Value .cast (test )
1903
+ new_cases = []
1904
+ for patterns , value in cases :
1905
+ if patterns is not None :
1906
+ if not isinstance (patterns , tuple ):
1907
+ patterns = (patterns ,)
1908
+ new_patterns = ()
1909
+ key_mask = (1 << len (self .test )) - 1
1910
+ for key in _normalize_patterns (patterns , self ._test .shape ()):
1911
+ if isinstance (key , int ):
1912
+ key = to_binary (key & key_mask , len (self .test ))
1913
+ new_patterns = (* new_patterns , key )
1914
+ else :
1915
+ new_patterns = None
1916
+ new_cases .append ((new_patterns , Value .cast (value )))
1917
+ self ._cases = tuple (new_cases )
1918
+
1919
+ @property
1920
+ def test (self ):
1921
+ return self ._test
1922
+
1923
+ @property
1924
+ def cases (self ):
1925
+ return self ._cases
1926
+
1927
+ def shape (self ):
1928
+ return Shape ._unify (value .shape () for _patterns , value in self ._cases )
1929
+
1930
+ def _lhs_signals (self ):
1931
+ return union ((value ._lhs_signals () for _patterns , value in self .cases ), start = SignalSet ())
1932
+
1933
+ def _rhs_signals (self ):
1934
+ signals = union ((value ._rhs_signals () for _patterns , value in self .cases ), start = SignalSet ())
1935
+ return self .test ._rhs_signals () | signals
1936
+
1937
+ def __repr__ (self ):
1938
+ def case_repr (patterns , value ):
1939
+ if patterns is None :
1940
+ return f"(default { value !r} )"
1941
+ elif len (patterns ) == 1 :
1942
+ return f"(case { patterns [0 ]} { value !r} )"
1943
+ else :
1944
+ return "(case ({}) {!r})" .format (" " .join (patterns ), value )
1945
+ case_reprs = (case_repr (patterns , value ) for patterns , value in self .cases )
1946
+ return "(switch-value {!r} {})" .format (self .test , " " .join (case_reprs ))
1947
+
1948
+
1895
1949
class _SignalMeta (ABCMeta ):
1896
1950
def __call__ (cls , shape = None , src_loc_at = 0 , ** kwargs ):
1897
1951
signal = super ().__call__ (shape , ** kwargs , src_loc_at = src_loc_at + 1 )
@@ -2356,10 +2410,17 @@ def __repr__(self):
2356
2410
", " .join (map (repr , self ._inner )))
2357
2411
2358
2412
2413
+ def _proxy_value (name ):
2414
+ @functools .wraps (getattr (Value , name ))
2415
+ def inner (self , * args , ** kwargs ):
2416
+ return getattr (Value .cast (self ), name )(* args , ** kwargs )
2417
+ return inner
2418
+
2419
+
2359
2420
@final
2360
- class ArrayProxy (Value ):
2421
+ class ArrayProxy (ValueCastable ):
2361
2422
def __init__ (self , elems , index , * , src_loc_at = 0 ):
2362
- super (). __init__ ( src_loc_at = 1 + src_loc_at )
2423
+ self . src_loc = tracer . get_src_loc ( 1 + src_loc_at )
2363
2424
self ._elems = elems
2364
2425
self ._index = Value .cast (index )
2365
2426
@@ -2385,19 +2446,73 @@ def shape(self):
2385
2446
# elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
2386
2447
return Shape ._unify (elem .shape () for elem in self ._iter_as_values ())
2387
2448
2388
- def _lhs_signals (self ):
2389
- signals = union ((elem ._lhs_signals () for elem in self ._iter_as_values ()),
2390
- start = SignalSet ())
2391
- return signals
2449
+ def as_value (self ):
2450
+ return SwitchValue (
2451
+ self ._index ,
2452
+ (
2453
+ (index , value )
2454
+ for index , value in enumerate (self ._elems )
2455
+ if index in range (1 << len (self ._index ))
2456
+ ),
2457
+ src_loc = self .src_loc ,
2458
+ )
2392
2459
2393
- def _rhs_signals (self ):
2394
- signals = union ((elem ._rhs_signals () for elem in self ._iter_as_values ()),
2395
- start = SignalSet ())
2396
- return self .index ._rhs_signals () | signals
2460
+ def eq (self , value , * , src_loc_at = 0 ):
2461
+ return self .as_value ().eq (value , src_loc_at = 1 + src_loc_at )
2397
2462
2398
2463
def __repr__ (self ):
2399
2464
return "(proxy (array [{}]) {!r})" .format (", " .join (map (repr , self .elems )), self .index )
2400
2465
2466
+ as_signed = _proxy_value ("as_signed" )
2467
+ as_unsigned = _proxy_value ("as_unsigned" )
2468
+ __len__ = _proxy_value ("__len__" )
2469
+ __bool__ = _proxy_value ("__bool__" )
2470
+ bool = _proxy_value ("bool" )
2471
+ __pos__ = _proxy_value ("__pos__" )
2472
+ __neg__ = _proxy_value ("__neg__" )
2473
+ __add__ = _proxy_value ("__add__" )
2474
+ __radd__ = _proxy_value ("__radd__" )
2475
+ __sub__ = _proxy_value ("__sub__" )
2476
+ __rsub__ = _proxy_value ("__rsub__" )
2477
+ __mul__ = _proxy_value ("__mul__" )
2478
+ __rmul__ = _proxy_value ("__rmul__" )
2479
+ __floordiv__ = _proxy_value ("__floordiv__" )
2480
+ __rfloordiv__ = _proxy_value ("__rfloordiv__" )
2481
+ __mod__ = _proxy_value ("__mod__" )
2482
+ __rmod__ = _proxy_value ("__rmod__" )
2483
+ __eq__ = _proxy_value ("__eq__" )
2484
+ __ne__ = _proxy_value ("__ne__" )
2485
+ __lt__ = _proxy_value ("__lt__" )
2486
+ __le__ = _proxy_value ("__le__" )
2487
+ __gt__ = _proxy_value ("__gt__" )
2488
+ __ge__ = _proxy_value ("__ge__" )
2489
+ __abs__ = _proxy_value ("__abs__" )
2490
+ __invert__ = _proxy_value ("__invert__" )
2491
+ __and__ = _proxy_value ("__and__" )
2492
+ __rand__ = _proxy_value ("__rand__" )
2493
+ __or__ = _proxy_value ("__or__" )
2494
+ __ror__ = _proxy_value ("__ror__" )
2495
+ __xor__ = _proxy_value ("__xor__" )
2496
+ __rxor__ = _proxy_value ("__rxor__" )
2497
+ any = _proxy_value ("any" )
2498
+ all = _proxy_value ("all" )
2499
+ xor = _proxy_value ("xor" )
2500
+ implies = _proxy_value ("implies" )
2501
+ __lshift__ = _proxy_value ("__lshift__" )
2502
+ __rlshift__ = _proxy_value ("__rlshift__" )
2503
+ __rshift__ = _proxy_value ("__rshift__" )
2504
+ __rrshift__ = _proxy_value ("__rrshift__" )
2505
+ shift_left = _proxy_value ("shift_left" )
2506
+ shift_right = _proxy_value ("shift_right" )
2507
+ rotate_left = _proxy_value ("rotate_left" )
2508
+ rotate_right = _proxy_value ("rotate_right" )
2509
+ __contains__ = _proxy_value ("__contains__" )
2510
+ bit_select = _proxy_value ("bit_select" )
2511
+ word_select = _proxy_value ("word_select" )
2512
+ replicate = _proxy_value ("replicate" )
2513
+ matches = _proxy_value ("matches" )
2514
+ __format__ = _proxy_value ("__format__" )
2515
+
2401
2516
2402
2517
@final
2403
2518
class Initial (Value ):
@@ -2772,7 +2887,7 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
2772
2887
self .src_loc = src_loc
2773
2888
2774
2889
self ._test = Value .cast (test )
2775
- self . _cases = []
2890
+ new_cases = []
2776
2891
for patterns , stmts , case_src_loc in cases :
2777
2892
if patterns is not None :
2778
2893
# Map: key -> (key,); (key...) -> (key...)
@@ -2787,10 +2902,8 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0):
2787
2902
new_patterns = (* new_patterns , key )
2788
2903
else :
2789
2904
new_patterns = None
2790
- if not isinstance (stmts , Iterable ):
2791
- stmts = [stmts ]
2792
- self ._cases .append ((new_patterns , Statement .cast (stmts ), case_src_loc ))
2793
- self ._cases = tuple (self ._cases )
2905
+ new_cases .append ((new_patterns , Statement .cast (stmts ), case_src_loc ))
2906
+ self ._cases = tuple (new_cases )
2794
2907
2795
2908
@property
2796
2909
def test (self ):
@@ -2816,7 +2929,7 @@ def case_repr(patterns, stmts):
2816
2929
return f"(case { patterns [0 ]} { stmts_repr } )"
2817
2930
else :
2818
2931
return "(case ({}) {})" .format (" " .join (patterns ), stmts_repr )
2819
- case_reprs = [ case_repr (patterns , stmts ) for patterns , stmts , _src_loc in self .cases ]
2932
+ case_reprs = ( case_repr (patterns , stmts ) for patterns , stmts , _src_loc in self .cases )
2820
2933
return "(switch {!r} {})" .format (self .test , " " .join (case_reprs ))
2821
2934
2822
2935
0 commit comments