From f2fb7f470ef7926c7f70c268b43fe29647337fbd Mon Sep 17 00:00:00 2001 From: Catherine Date: Thu, 27 Jun 2024 13:09:13 +0000 Subject: [PATCH 1/2] lib.data: raise `ValueError` if initializer refers to nonexistent key. Previously, `KeyError` was raised, which should not be escaping the implementation (`Layout.const` is not a retrieval method). --- amaranth/lib/data.py | 10 +++++++--- tests/test_lib_data.py | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index 8c77cc2cf..6f157e922 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -219,12 +219,16 @@ def const(self, init): elif isinstance(init, Sequence): iterator = enumerate(init) else: - raise TypeError("Layout constant initializer must be a mapping or a sequence, not {!r}" - .format(init)) + raise TypeError(f"Layout constant initializer must be a mapping or a sequence, not " + f"{init!r}") int_value = 0 for key, key_value in iterator: - field = self[key] + try: + field = self[key] + except KeyError: + raise ValueError(f"Layout constant initializer refers to key {key!r}, which is not " + f"a part of the layout") cast_field_shape = Shape.cast(field.shape) if isinstance(field.shape, ShapeCastable): key_value = hdl.Const.cast(hdl.Const(key_value, field.shape)) diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index ceafd638b..4ec80b352 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -477,6 +477,10 @@ def test_const_wrong(self): r"^Layout constant initializer must be a mapping or a sequence, not " r"<.+?object.+?>$"): sl.const(object()) + with self.assertRaisesRegex(ValueError, + r"^Layout constant initializer refers to key 'g', which is not a part " + r"of the layout$"): + sl.const({"g": 1}) sl2 = data.StructLayout({"f": unsigned(2)}) with self.assertRaisesRegex(ValueError, r"^Const layout StructLayout.* differs from shape layout StructLayout.*$"): From 0f9f398702d13fce9505831cf14acaf351981bbf Mon Sep 17 00:00:00 2001 From: Catherine Date: Thu, 27 Jun 2024 13:12:32 +0000 Subject: [PATCH 2/2] lib.data: make constants comparable with compatible initializers. Before this commit, `data.Const` was only comparable to `data.Const` or `data.View`. After this commit, in addition, it is also comparable to `dict` or `list` provided that such a value is accepted by `layout.const()` where `layout` is the layout of the `data.Const` object. This change greatly reduces boilerplate in tests by enabling e.g.: assert (await stream_get(ctx, stream)) == {"value": 1} instead of: assert (await stream_get(ctx, stream)) == Const({"value": 1}, stream.p.shape()) Note that, unlike `Layout.const`, which accepts arbitrary `Mapping` or `Sequence` objects, only `dict` and `list` are accepted in comparisons. Also, `data.View` continues to be comparable only to `data.View` and `data.Const`. This is to minimize the scope of the change and reduce likelihood of undesirable side effects when backported to the 0.5.x branch. Fixes #1414. --- amaranth/lib/data.py | 26 ++++++++++++++--- tests/test_lib_data.py | 63 ++++++++++++++++++++++++++++++++---------- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/amaranth/lib/data.py b/amaranth/lib/data.py index 6f157e922..69ca86cf5 100644 --- a/amaranth/lib/data.py +++ b/amaranth/lib/data.py @@ -1083,9 +1083,18 @@ def __eq__(self, other): elif isinstance(other, Const) and self.__layout == other.__layout: return self.__target == other.__target else: + cause = None + if isinstance(other, (dict, list)): + try: + other_as_const = self.__layout.const(other) + except (TypeError, ValueError) as exc: + cause = exc + else: + return self == other_as_const raise TypeError( - f"Constant with layout {self.__layout!r} can only be compared to another view or " - f"constant with the same layout, not {other!r}") + f"Constant with layout {self.__layout!r} can only be compared to another view, " + f"a constant with the same layout, or a dictionary or a list that can be converted " + f"to a constant with the same layout, not {other!r}") from cause def __ne__(self, other): if isinstance(other, View) and self.__layout == other._View__layout: @@ -1093,9 +1102,18 @@ def __ne__(self, other): elif isinstance(other, Const) and self.__layout == other.__layout: return self.__target != other.__target else: + cause = None + if isinstance(other, (dict, list)): + try: + other_as_const = self.__layout.const(other) + except (TypeError, ValueError) as exc: + cause = exc + else: + return self != other_as_const raise TypeError( - f"Constant with layout {self.__layout!r} can only be compared to another view or " - f"constant with the same layout, not {other!r}") + f"Constant with layout {self.__layout!r} can only be compared to another view, " + f"a constant with the same layout, or a dictionary or a list that can be converted " + f"to a constant with the same layout, not {other!r}") from cause def __add__(self, other): raise TypeError("Cannot perform arithmetic operations on a lib.data.Const") diff --git a/tests/test_lib_data.py b/tests/test_lib_data.py index 4ec80b352..86a75c1da 100644 --- a/tests/test_lib_data.py +++ b/tests/test_lib_data.py @@ -744,7 +744,7 @@ def test_bug_837_array_layout_getattr(self): r"^View with an array layout does not have fields$"): Signal(data.ArrayLayout(unsigned(1), 1), init=[0]).init - def test_eq(self): + def test_compare(self): s1 = Signal(data.StructLayout({"a": unsigned(2)})) s2 = Signal(data.StructLayout({"a": unsigned(2)})) s3 = Signal(data.StructLayout({"a": unsigned(1), "b": unsigned(1)})) @@ -973,11 +973,12 @@ def test_bug_837_array_layout_getattr(self): r"^Constant with an array layout does not have fields$"): data.Const(data.ArrayLayout(unsigned(1), 1), 0).init - def test_eq(self): + def test_compare(self): c1 = data.Const(data.StructLayout({"a": unsigned(2)}), 1) c2 = data.Const(data.StructLayout({"a": unsigned(2)}), 1) c3 = data.Const(data.StructLayout({"a": unsigned(2)}), 2) c4 = data.Const(data.StructLayout({"a": unsigned(1), "b": unsigned(1)}), 2) + c5 = data.Const(data.ArrayLayout(2, 4), 0b11100100) s1 = Signal(data.StructLayout({"a": unsigned(2)})) self.assertTrue(c1 == c2) self.assertFalse(c1 != c2) @@ -987,13 +988,23 @@ def test_eq(self): self.assertRepr(c1 != s1, "(!= (const 2'd1) (sig s1))") self.assertRepr(s1 == c1, "(== (sig s1) (const 2'd1))") self.assertRepr(s1 != c1, "(!= (sig s1) (const 2'd1))") + self.assertTrue(c1 == {"a": 1}) + self.assertFalse(c1 == {"a": 2}) + self.assertFalse(c1 != {"a": 1}) + self.assertTrue(c1 != {"a": 2}) + self.assertTrue(c5 == [0,1,2,3]) + self.assertFalse(c5 == [0,1,3,3]) + self.assertFalse(c5 != [0,1,2,3]) + self.assertTrue(c5 != [0,1,3,3]) with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 == c4 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 != c4 with self.assertRaisesRegex(TypeError, r"^View with layout .* can only be compared to another view or constant with " @@ -1004,21 +1015,45 @@ def test_eq(self): r"the same layout, not .*$"): s1 != c4 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c4 == s1 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c4 != s1 with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 == Const(0, 2) with self.assertRaisesRegex(TypeError, - r"^Constant with layout .* can only be compared to another view or constant with " - r"the same layout, not .*$"): + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): c1 != Const(0, 2) + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c1 == {"b": 1} + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c1 != {"b": 1} + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c5 == [0,1,2,3,4] + with self.assertRaisesRegex(TypeError, + r"^Constant with layout .* can only be compared to another view, a constant " + r"with the same layout, or a dictionary or a list that can be converted to " + r"a constant with the same layout, not .*$"): + c5 != [0,1,2,3,4] def test_operator(self): s1 = data.Const(data.StructLayout({"a": unsigned(2)}), 2)