Skip to content

Commit ec207e7

Browse files
authored
feat: Standarize the string formating of sum types and values (#2432)
The Sum value and type helpers get coerced back into `val.Sum` and `tys.Sum` after a serialization roundtrip. This used to change the rendering of the values, falling back to the verbose `Sum(tag=#, typ=[..., ...], vals=[...,...])`. This PR centralizes the str/repr definition, so equivalent values and types are always rendered in the same way. (I needed this for roundtrip checks).
1 parent fdb675f commit ec207e7

File tree

4 files changed

+135
-74
lines changed

4 files changed

+135
-74
lines changed

hugr-py/src/hugr/tys.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import hugr._serialization.tys as stys
1010
import hugr.model as model
11-
from hugr.utils import comma_sep_repr, comma_sep_str, ser_it
11+
from hugr.utils import comma_sep_repr, comma_sep_str, comma_sep_str_paren, ser_it
1212

1313
if TYPE_CHECKING:
1414
from collections.abc import Iterable, Sequence
@@ -430,7 +430,38 @@ def as_tuple(self) -> Tuple:
430430
return Tuple(*self.variant_rows[0])
431431

432432
def __repr__(self) -> str:
433-
return f"Sum({self.variant_rows})"
433+
if self == Bool:
434+
return "Bool"
435+
elif self == Unit:
436+
return "Unit"
437+
elif all(len(row) == 0 for row in self.variant_rows):
438+
return f"UnitSum({len(self.variant_rows)})"
439+
elif len(self.variant_rows) == 1:
440+
return f"Tuple{tuple(self.variant_rows[0])}"
441+
elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0:
442+
return f"Option({comma_sep_repr(self.variant_rows[1])})"
443+
elif len(self.variant_rows) == 2:
444+
left, right = self.variant_rows
445+
return f"Either(left={left}, right={right})"
446+
else:
447+
return f"Sum({self.variant_rows})"
448+
449+
def __str__(self) -> str:
450+
if self == Bool:
451+
return "Bool"
452+
elif self == Unit:
453+
return "Unit"
454+
elif all(len(row) == 0 for row in self.variant_rows):
455+
return f"UnitSum({len(self.variant_rows)})"
456+
elif len(self.variant_rows) == 1:
457+
return f"Tuple{tuple(self.variant_rows[0])}"
458+
elif len(self.variant_rows) == 2 and len(self.variant_rows[0]) == 0:
459+
return f"Option({comma_sep_str(self.variant_rows[1])})"
460+
elif len(self.variant_rows) == 2:
461+
left, right = self.variant_rows
462+
return f"Either({comma_sep_str_paren(left)}, {comma_sep_str_paren(right)})"
463+
else:
464+
return f"Sum({self.variant_rows})"
434465

435466
def __eq__(self, other: object) -> bool:
436467
return isinstance(other, Sum) and self.variant_rows == other.variant_rows
@@ -449,7 +480,7 @@ def to_model(self) -> model.Term:
449480
return model.Apply("core.adt", [variants])
450481

451482

452-
@dataclass(eq=False)
483+
@dataclass(eq=False, repr=False)
453484
class UnitSum(Sum):
454485
"""Simple :class:`Sum` type with `size` variants of empty rows."""
455486

@@ -462,18 +493,14 @@ def __init__(self, size: int):
462493
def _to_serial(self) -> stys.UnitSum: # type: ignore[override]
463494
return stys.UnitSum(size=self.size)
464495

465-
def __repr__(self) -> str:
466-
if self == Bool:
467-
return "Bool"
468-
elif self == Unit:
469-
return "Unit"
470-
return f"UnitSum({self.size})"
471-
472496
def resolve(self, registry: ext.ExtensionRegistry) -> UnitSum:
473497
return self
474498

499+
def __str__(self) -> str:
500+
return self.__repr__()
501+
475502

476-
@dataclass(eq=False)
503+
@dataclass(eq=False, repr=False)
477504
class Tuple(Sum):
478505
"""Product type with `tys` elements. Instances of this type correspond to
479506
:class:`Sum` with a single variant.
@@ -482,11 +509,8 @@ class Tuple(Sum):
482509
def __init__(self, *tys: Type):
483510
self.variant_rows = [list(tys)]
484511

485-
def __repr__(self) -> str:
486-
return f"Tuple{tuple(self.variant_rows[0])}"
487-
488512

489-
@dataclass(eq=False)
513+
@dataclass(eq=False, repr=False)
490514
class Option(Sum):
491515
"""Optional tuple of elements.
492516
@@ -497,11 +521,8 @@ class Option(Sum):
497521
def __init__(self, *tys: Type):
498522
self.variant_rows = [[], list(tys)]
499523

500-
def __repr__(self) -> str:
501-
return f"Option({comma_sep_repr(self.variant_rows[1])})"
502-
503524

504-
@dataclass(eq=False)
525+
@dataclass(eq=False, repr=False)
505526
class Either(Sum):
506527
"""Two-variant tuple of elements.
507528
@@ -514,16 +535,6 @@ class Either(Sum):
514535
def __init__(self, left: Iterable[Type], right: Iterable[Type]):
515536
self.variant_rows = [list(left), list(right)]
516537

517-
def __repr__(self) -> str: # pragma: no cover
518-
left, right = self.variant_rows
519-
return f"Either(left={left}, right={right})"
520-
521-
def __str__(self) -> str:
522-
left, right = self.variant_rows
523-
left_str = left[0] if len(left) == 1 else tuple(left)
524-
right_str = right[0] if len(right) == 1 else tuple(right)
525-
return f"Either({left_str}, {right_str})"
526-
527538

528539
@dataclass(frozen=True)
529540
class Variable(Type):

hugr-py/src/hugr/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,27 @@ def comma_sep_str(items: Iterable[T]) -> str:
215215
def comma_sep_repr(items: Iterable[T]) -> str:
216216
"""Join items with commas and repr."""
217217
return ", ".join(map(repr, items))
218+
219+
220+
def comma_sep_str_paren(items: Iterable[T]) -> str:
221+
"""Join items with commas and str, wrapping them in parentheses if more than one."""
222+
items = list(items)
223+
if len(items) == 0:
224+
return "()"
225+
elif len(items) == 1:
226+
return f"{items[0]}"
227+
else:
228+
return f"({comma_sep_str(items)})"
229+
230+
231+
def comma_sep_repr_paren(items: Iterable[T]) -> str:
232+
"""Join items with commas and repr, wrapping them in parentheses if more
233+
than one.
234+
"""
235+
items = list(items)
236+
if len(items) == 0:
237+
return "()"
238+
elif len(items) == 1:
239+
return f"{items[0]}"
240+
else:
241+
return f"({comma_sep_repr(items)})"

hugr-py/src/hugr/val.py

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class Sum(Value):
4545
"""Sum-of-product value.
4646
4747
Example:
48-
>>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit]]), [TRUE])
49-
Sum(tag=0, typ=Sum([[Bool], [Unit]]), vals=[TRUE])
48+
>>> Sum(0, tys.Sum([[tys.Bool], [tys.Unit], [tys.Bool]]), [TRUE])
49+
Sum(tag=0, typ=Sum([[Bool], [Unit], [Bool]]), vals=[TRUE])
5050
"""
5151

5252
#: Tag identifying the variant.
@@ -70,6 +70,59 @@ def _to_serial(self) -> sops.SumValue:
7070
vs=ser_it(self.vals),
7171
)
7272

73+
def __repr__(self) -> str:
74+
if self == TRUE:
75+
return "TRUE"
76+
elif self == FALSE:
77+
return "FALSE"
78+
elif self == Unit:
79+
return "Unit"
80+
elif all(len(row) == 0 for row in self.typ.variant_rows):
81+
return f"UnitSum({self.tag}, {self.n_variants})"
82+
elif len(self.typ.variant_rows) == 1:
83+
return f"Tuple({comma_sep_repr(self.vals)})"
84+
elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0:
85+
# Option
86+
if self.tag == 0:
87+
return f"None({comma_sep_str(self.typ.variant_rows[1])})"
88+
else:
89+
return f"Some({comma_sep_repr(self.vals)})"
90+
elif len(self.typ.variant_rows) == 2:
91+
# Either
92+
left_typ, right_typ = self.typ.variant_rows
93+
if self.tag == 0:
94+
return f"Left(vals={self.vals}, right_typ={list(right_typ)})"
95+
else:
96+
return f"Right(left_typ={list(left_typ)}, vals={self.vals})"
97+
else:
98+
return f"Sum(tag={self.tag}, typ={self.typ}, vals={self.vals})"
99+
100+
def __str__(self) -> str:
101+
if self == TRUE:
102+
return "TRUE"
103+
elif self == FALSE:
104+
return "FALSE"
105+
elif self == Unit:
106+
return "Unit"
107+
elif all(len(row) == 0 for row in self.typ.variant_rows):
108+
return f"UnitSum({self.tag}, {self.n_variants})"
109+
elif len(self.typ.variant_rows) == 1:
110+
return f"Tuple({comma_sep_str(self.vals)})"
111+
elif len(self.typ.variant_rows) == 2 and len(self.typ.variant_rows[0]) == 0:
112+
# Option
113+
if self.tag == 0:
114+
return "None"
115+
else:
116+
return f"Some({comma_sep_str(self.vals)})"
117+
elif len(self.typ.variant_rows) == 2:
118+
# Either
119+
if self.tag == 0:
120+
return f"Left({comma_sep_str(self.vals)})"
121+
else:
122+
return f"Right({comma_sep_str(self.vals)})"
123+
else:
124+
return f"Sum({self.tag}, {self.typ}, {self.vals})"
125+
73126
def __eq__(self, other: object) -> bool:
74127
return (
75128
isinstance(other, Sum)
@@ -100,6 +153,7 @@ def to_model(self) -> model.Term:
100153
)
101154

102155

156+
@dataclass(eq=False, repr=False)
103157
class UnitSum(Sum):
104158
"""Simple :class:`Sum` with each variant being an empty row.
105159
@@ -119,15 +173,6 @@ def __init__(self, tag: int, size: int):
119173
vals=[],
120174
)
121175

122-
def __repr__(self) -> str:
123-
if self == TRUE:
124-
return "TRUE"
125-
if self == FALSE:
126-
return "FALSE"
127-
if self == Unit:
128-
return "Unit"
129-
return f"UnitSum({self.tag}, {self.n_variants})"
130-
131176

132177
def bool_value(b: bool) -> UnitSum:
133178
"""Convert a python bool to a HUGR boolean value.
@@ -149,7 +194,7 @@ def bool_value(b: bool) -> UnitSum:
149194
FALSE = bool_value(False)
150195

151196

152-
@dataclass(eq=False)
197+
@dataclass(eq=False, repr=False)
153198
class Tuple(Sum):
154199
"""Tuple or product value, defined by a list of values.
155200
Internally a :class:`Sum` with a single variant row.
@@ -177,10 +222,10 @@ def _to_serial(self) -> sops.TupleValue: # type: ignore[override]
177222
)
178223

179224
def __repr__(self) -> str:
180-
return f"Tuple({comma_sep_repr(self.vals)})"
225+
return super().__repr__()
181226

182227

183-
@dataclass(eq=False)
228+
@dataclass(eq=False, repr=False)
184229
class Some(Sum):
185230
"""Optional tuple of value, containing a list of values.
186231
@@ -199,11 +244,8 @@ def __init__(self, *vals: Value):
199244
tag=1, typ=tys.Option(*(v.type_() for v in val_list)), vals=val_list
200245
)
201246

202-
def __repr__(self) -> str:
203-
return f"Some({comma_sep_repr(self.vals)})"
204-
205247

206-
@dataclass(eq=False)
248+
@dataclass(eq=False, repr=False)
207249
class None_(Sum):
208250
"""Optional tuple of value, containing no values.
209251
@@ -219,14 +261,8 @@ class None_(Sum):
219261
def __init__(self, *types: tys.Type):
220262
super().__init__(tag=0, typ=tys.Option(*types), vals=[])
221263

222-
def __repr__(self) -> str:
223-
return f"None({comma_sep_str(self.typ.variant_rows[1])})"
224-
225-
def __str__(self) -> str:
226-
return "None"
227-
228264

229-
@dataclass(eq=False)
265+
@dataclass(eq=False, repr=False)
230266
class Left(Sum):
231267
"""Left variant of a :class:`tys.Either` type, containing a list of values.
232268
@@ -248,15 +284,8 @@ def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]):
248284
vals=val_list,
249285
)
250286

251-
def __repr__(self) -> str:
252-
_, right_typ = self.typ.variant_rows
253-
return f"Left(vals={self.vals}, right_typ={list(right_typ)})"
254-
255-
def __str__(self) -> str:
256-
return f"Left({comma_sep_str(self.vals)})"
257-
258287

259-
@dataclass(eq=False)
288+
@dataclass(eq=False, repr=False)
260289
class Right(Sum):
261290
"""Right variant of a :class:`tys.Either` type, containing a list of values.
262291
@@ -280,13 +309,6 @@ def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]):
280309
vals=val_list,
281310
)
282311

283-
def __repr__(self) -> str:
284-
left_typ, _ = self.typ.variant_rows
285-
return f"Right(left_typ={list(left_typ)}, vals={self.vals})"
286-
287-
def __str__(self) -> str:
288-
return f"Right({comma_sep_str(self.vals)})"
289-
290312

291313
@dataclass
292314
class Function(Value):

hugr-py/tests/test_val.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Sum,
1515
Tuple,
1616
UnitSum,
17-
Value,
1817
bool_value,
1918
)
2019

@@ -44,9 +43,9 @@ def test_sums():
4443
("value", "string", "repr_str"),
4544
[
4645
(
47-
Sum(0, tys.Sum([[tys.Bool], [tys.Qubit]]), [TRUE, FALSE]),
48-
"Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])",
49-
"Sum(tag=0, typ=Sum([[Bool], [Qubit]]), vals=[TRUE, FALSE])",
46+
Sum(0, tys.Sum([[tys.Bool], [tys.Qubit], [tys.Bool]]), [TRUE]),
47+
"Sum(0, Sum([[Bool], [Qubit], [Bool]]), [TRUE])",
48+
"Sum(tag=0, typ=Sum([[Bool], [Qubit], [Bool]]), vals=[TRUE])",
5049
),
5150
(UnitSum(0, size=1), "Unit", "Unit"),
5251
(UnitSum(0, size=2), "FALSE", "FALSE"),
@@ -67,10 +66,15 @@ def test_sums():
6766
),
6867
],
6968
)
70-
def test_val_sum_str(value: Value, string: str, repr_str: str):
69+
def test_val_sum_str(value: Sum, string: str, repr_str: str):
7170
assert str(value) == string
7271
assert repr(value) == repr_str
7372

73+
# Make sure the corresponding `Sum` also renders the same
74+
sum_val = Sum(value.tag, value.typ, value.vals)
75+
assert str(sum_val) == string
76+
assert repr(sum_val) == repr_str
77+
7478

7579
def test_val_static_array():
7680
from hugr.std.collections.static_array import StaticArrayVal

0 commit comments

Comments
 (0)