Skip to content

Commit cd6cbd7

Browse files
wanda-phiwhitequark
authored andcommitted
hdl.{_ast,_dsl}: factor out the pattern normalization logic.
1 parent 0e4c2de commit cd6cbd7

File tree

7 files changed

+69
-97
lines changed

7 files changed

+69
-97
lines changed

amaranth/hdl/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from ._ast import SyntaxError, SyntaxWarning
12
from ._ast import Shape, unsigned, signed, ShapeCastable, ShapeLike
23
from ._ast import Value, ValueCastable, ValueLike
34
from ._ast import Const, C, Mux, Cat, Array, Signal, ClockSignal, ResetSignal
45
from ._ast import Format, Print, Assert, Assume, Cover
56
from ._ast import IOValue, IOPort
6-
from ._dsl import SyntaxError, SyntaxWarning, Module
7+
from ._dsl import Module
78
from ._cd import DomainError, ClockDomain
89
from ._ir import UnusedElaboratable, Elaboratable, DriverConflict, Fragment
910
from ._ir import Instance, IOBufferInstance
@@ -14,13 +15,14 @@
1415

1516
__all__ = [
1617
# _ast
18+
"SyntaxError", "SyntaxWarning",
1719
"Shape", "unsigned", "signed", "ShapeCastable", "ShapeLike",
1820
"Value", "ValueCastable", "ValueLike",
1921
"Const", "C", "Mux", "Cat", "Array", "Signal", "ClockSignal", "ResetSignal",
2022
"Format", "Print", "Assert", "Assume", "Cover",
2123
"IOValue", "IOPort",
2224
# _dsl
23-
"SyntaxError", "SyntaxWarning", "Module",
25+
"Module",
2426
# _cd
2527
"DomainError", "ClockDomain",
2628
# _ir

amaranth/hdl/_ast.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
__all__ = [
19+
"SyntaxError", "SyntaxWarning",
1920
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
2021
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat",
2122
"Array", "ArrayProxy",
@@ -30,6 +31,14 @@
3031
]
3132

3233

34+
class SyntaxError(Exception):
35+
pass
36+
37+
38+
class SyntaxWarning(Warning):
39+
pass
40+
41+
3342
class DUID:
3443
"""Deterministic Unique IDentifier."""
3544
__next_uid = 0
@@ -426,6 +435,37 @@ def __new__(cls, *args, **kwargs):
426435
raise TypeError("ShapeLike is an abstract class and cannot be instantiated")
427436

428437

438+
def _normalize_patterns(patterns, shape, *, src_loc_at=1):
439+
new_patterns = []
440+
for pattern in patterns:
441+
orig_pattern = pattern
442+
if isinstance(pattern, str):
443+
if any(bit not in "01- \t" for bit in pattern):
444+
raise SyntaxError(f"Pattern '{pattern}' must consist of 0, 1, and - (don't "
445+
f"care) bits, and may include whitespace")
446+
pattern = "".join(pattern.split()) # remove whitespace
447+
if len(pattern) != shape.width:
448+
raise SyntaxError(f"Pattern '{orig_pattern}' must have the same width as "
449+
f"match value (which is {shape.width})")
450+
else:
451+
try:
452+
pattern = Const.cast(pattern)
453+
except TypeError as e:
454+
raise SyntaxError(f"Pattern must be a string or a constant-castable "
455+
f"expression, not {pattern!r}") from e
456+
cast_pattern = Const(pattern.value, shape)
457+
if cast_pattern.value != pattern.value:
458+
warnings.warn(f"Pattern '{orig_pattern!r}' "
459+
f"({pattern.shape().width}'{pattern.value:b}) is not "
460+
f"representable in match value shape "
461+
f"({shape!r}); comparison will never be true",
462+
SyntaxWarning, stacklevel=2 + src_loc_at)
463+
continue
464+
pattern = pattern.value
465+
new_patterns.append(pattern)
466+
return tuple(new_patterns)
467+
468+
429469
def _overridable_by_reflected(method_name):
430470
"""Allow overriding the decorated method.
431471
@@ -1248,36 +1288,12 @@ def matches(self, *patterns):
12481288
If a pattern has invalid syntax.
12491289
"""
12501290
matches = []
1251-
# This code should accept exactly the same patterns as `with m.Case(...):`.
1252-
for pattern in patterns:
1253-
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
1254-
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
1255-
"bits, and may include whitespace"
1256-
.format(pattern))
1257-
if (isinstance(pattern, str) and
1258-
len("".join(pattern.split())) != len(self)):
1259-
raise SyntaxError("Match pattern '{}' must have the same width as match value "
1260-
"(which is {})"
1261-
.format(pattern, len(self)))
1291+
for pattern in _normalize_patterns(patterns, self.shape()):
12621292
if isinstance(pattern, str):
1263-
pattern = "".join(pattern.split()) # remove whitespace
1264-
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
1265-
pattern = int(pattern.replace("-", "0"), 2)
1293+
mask = int("0" + pattern.replace("0", "1").replace("-", "0"), 2)
1294+
pattern = int("0" + pattern.replace("-", "0"), 2)
12661295
matches.append((self & mask) == pattern)
12671296
else:
1268-
try:
1269-
orig_pattern, pattern = pattern, Const.cast(pattern)
1270-
except TypeError as e:
1271-
raise SyntaxError("Match pattern must be a string or a constant-castable "
1272-
"expression, not {!r}"
1273-
.format(pattern)) from e
1274-
pattern_len = bits_for(pattern.value)
1275-
if pattern_len > len(self):
1276-
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
1277-
"(which has width {}); comparison will never be true"
1278-
.format(orig_pattern, pattern_len, pattern.value, len(self)),
1279-
SyntaxWarning, stacklevel=2)
1280-
continue
12811297
matches.append(self == pattern)
12821298
if not matches:
12831299
return Const(0)
@@ -2770,17 +2786,9 @@ def __init__(self, test, cases, *, src_loc=None, src_loc_at=0, case_src_locs={})
27702786
# Map: 2 -> "0010"; "0010" -> "0010"
27712787
new_keys = ()
27722788
key_mask = (1 << len(self.test)) - 1
2773-
for key in keys:
2774-
if isinstance(key, str):
2775-
key = "".join(key.split()) # remove whitespace
2776-
elif isinstance(key, int):
2789+
for key in _normalize_patterns(keys, self._test.shape()):
2790+
if isinstance(key, int):
27772791
key = to_binary(key & key_mask, len(self.test))
2778-
elif isinstance(key, Enum):
2779-
key = to_binary(key.value & key_mask, len(self.test))
2780-
else:
2781-
raise TypeError("Object {!r} cannot be used as a switch key"
2782-
.format(key))
2783-
assert len(key) == len(self.test)
27842792
new_keys = (*new_keys, key)
27852793
if not isinstance(stmts, Iterable):
27862794
stmts = [stmts]

amaranth/hdl/_dsl.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..utils import bits_for
1010
from .. import tracer
1111
from ._ast import *
12-
from ._ast import _StatementList, _LateBoundStatement, Property, Print
12+
from ._ast import _StatementList, _LateBoundStatement, _normalize_patterns
1313
from ._ir import *
1414
from ._cd import *
1515
from ._xfrm import *
@@ -18,14 +18,6 @@
1818
__all__ = ["SyntaxError", "SyntaxWarning", "Module"]
1919

2020

21-
class SyntaxError(Exception):
22-
pass
23-
24-
25-
class SyntaxWarning(Warning):
26-
pass
27-
28-
2921
class _ModuleBuilderProxy:
3022
def __init__(self, builder, depth):
3123
object.__setattr__(self, "_builder", builder)
@@ -344,41 +336,10 @@ def Case(self, *patterns):
344336
self._check_context("Case", context="Switch")
345337
src_loc = tracer.get_src_loc(src_loc_at=1)
346338
switch_data = self._get_ctrl("Switch")
347-
new_patterns = ()
348339
if () in switch_data["cases"]:
349340
warnings.warn("A case defined after the default case will never be active",
350341
SyntaxWarning, stacklevel=3)
351-
# This code should accept exactly the same patterns as `v.matches(...)`.
352-
for pattern in patterns:
353-
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
354-
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
355-
"bits, and may include whitespace"
356-
.format(pattern))
357-
if (isinstance(pattern, str) and
358-
len("".join(pattern.split())) != len(switch_data["test"])):
359-
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
360-
"(which is {})"
361-
.format(pattern, len(switch_data["test"])))
362-
if isinstance(pattern, str):
363-
new_patterns = (*new_patterns, pattern)
364-
else:
365-
try:
366-
orig_pattern, pattern = pattern, Const.cast(pattern)
367-
except TypeError as e:
368-
raise SyntaxError("Case pattern must be a string or a constant-castable "
369-
"expression, not {!r}"
370-
.format(pattern)) from e
371-
pattern_len = bits_for(pattern.value)
372-
if pattern.value == 0:
373-
pattern_len = 0
374-
if pattern_len > len(switch_data["test"]):
375-
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
376-
"(which has width {}); comparison will never be true"
377-
.format(orig_pattern, pattern_len, pattern.value,
378-
len(switch_data["test"])),
379-
SyntaxWarning, stacklevel=3)
380-
continue
381-
new_patterns = (*new_patterns, pattern.value)
342+
new_patterns = _normalize_patterns(patterns, switch_data["test"].shape())
382343
try:
383344
_outer_case, self._statements = self._statements, {}
384345
self._ctrl_context = None

amaranth/lib/enum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
import operator
44

5-
from ..hdl._ast import Value, ValueCastable, Shape, ShapeCastable, Const
5+
from ..hdl import Value, ValueCastable, Shape, ShapeCastable, Const, SyntaxWarning
66
from ..hdl._repr import *
77

88

tests/test_hdl_ast.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def test_matches(self):
795795
def test_matches_enum(self):
796796
s = Signal(SignedEnum)
797797
self.assertRepr(s.matches(SignedEnum.FOO), """
798-
(== (sig s) (const 2'sd-1))
798+
(== (sig s) (const 1'sd-1))
799799
""")
800800

801801
def test_matches_const_castable(self):
@@ -807,28 +807,28 @@ def test_matches_const_castable(self):
807807
def test_matches_width_wrong(self):
808808
s = Signal(4)
809809
with self.assertRaisesRegex(SyntaxError,
810-
r"^Match pattern '--' must have the same width as match value \(which is 4\)$"):
810+
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
811811
s.matches("--")
812812
with self.assertWarnsRegex(SyntaxWarning,
813-
r"^Match pattern '22' \(5'10110\) is wider than match value \(which has "
814-
r"width 4\); comparison will never be true$"):
813+
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
814+
r"\(unsigned\(4\)\); comparison will never be true$"):
815815
s.matches(0b10110)
816816
with self.assertWarnsRegex(SyntaxWarning,
817-
r"^Match pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is wider "
818-
r"than match value \(which has width 4\); comparison will never be true$"):
817+
r"^Pattern '\(cat \(const 1'd0\) \(const 4'd11\)\)' \(5'10110\) is not "
818+
r"representable in match value shape \(unsigned\(4\)\); comparison will never be true$"):
819819
s.matches(Cat(0, C(0b1011, 4)))
820820

821821
def test_matches_bits_wrong(self):
822822
s = Signal(4)
823823
with self.assertRaisesRegex(SyntaxError,
824-
r"^Match pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
824+
r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
825825
r"and may include whitespace$"):
826826
s.matches("abc")
827827

828828
def test_matches_pattern_wrong(self):
829829
s = Signal(4)
830830
with self.assertRaisesRegex(SyntaxError,
831-
r"^Match pattern must be a string or a constant-castable expression, not 1\.0$"):
831+
r"^Pattern must be a string or a constant-castable expression, not 1\.0$"):
832832
s.matches(1.0)
833833

834834
def test_hash(self):
@@ -1695,7 +1695,7 @@ def test_int_case(self):
16951695
self.assertEqual(s.cases, {("00001010",): []})
16961696

16971697
def test_int_neg_case(self):
1698-
s = Switch(Const(0, 8), {-10: []})
1698+
s = Switch(Const(0, signed(8)), {-10: []})
16991699
self.assertEqual(s.cases, {("11110110",): []})
17001700

17011701
def test_int_zero_width(self):

tests/test_hdl_dsl.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,17 +487,17 @@ class Color(Enum):
487487
dummy = Signal()
488488
with m.Switch(self.w1):
489489
with self.assertRaisesRegex(SyntaxError,
490-
r"^Case pattern '--' must have the same width as switch value \(which is 4\)$"):
490+
r"^Pattern '--' must have the same width as match value \(which is 4\)$"):
491491
with m.Case("--"):
492492
m.d.comb += dummy.eq(0)
493493
with self.assertWarnsRegex(SyntaxWarning,
494-
r"^Case pattern '22' \(5'10110\) is wider than switch value \(which has "
495-
r"width 4\); comparison will never be true$"):
494+
r"^Pattern '22' \(5'10110\) is not representable in match value shape "
495+
r"\(unsigned\(4\)\); comparison will never be true$"):
496496
with m.Case(0b10110):
497497
m.d.comb += dummy.eq(0)
498498
with self.assertWarnsRegex(SyntaxWarning,
499-
r"^Case pattern '<Color.RED: 170>' \(8'10101010\) is wider than switch value "
500-
r"\(which has width 4\); comparison will never be true$"):
499+
r"^Pattern '<Color.RED: 170>' \(8'10101010\) is not representable in "
500+
r"match value shape \(unsigned\(4\)\); comparison will never be true$"):
501501
with m.Case(Color.RED):
502502
m.d.comb += dummy.eq(0)
503503
self.assertEqual(m._statements, {})
@@ -521,7 +521,7 @@ def test_Case_bits_wrong(self):
521521
m = Module()
522522
with m.Switch(self.w1):
523523
with self.assertRaisesRegex(SyntaxError,
524-
(r"^Case pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
524+
(r"^Pattern 'abc' must consist of 0, 1, and - \(don't care\) bits, "
525525
r"and may include whitespace$")):
526526
with m.Case("abc"):
527527
pass
@@ -530,7 +530,7 @@ def test_Case_pattern_wrong(self):
530530
m = Module()
531531
with m.Switch(self.w1):
532532
with self.assertRaisesRegex(SyntaxError,
533-
r"^Case pattern must be a string or a constant-castable expression, "
533+
r"^Pattern must be a string or a constant-castable expression, "
534534
r"not 1\.0$"):
535535
with m.Case(1.0):
536536
pass

tests/test_lib_enum.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import unittest
55

66
from amaranth import *
7+
from amaranth.hdl import *
78
from amaranth.lib.enum import Enum, EnumMeta, Flag, IntEnum, EnumView, FlagView
89

910
from .utils import *

0 commit comments

Comments
 (0)