Skip to content

Commit fdb675f

Browse files
authored
feat: Expose BorrowArray in hugr-py (#2425)
Closes #2406
1 parent 71c0e89 commit fdb675f

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Borrow array types and operations."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import cast
7+
8+
import hugr.model as model
9+
from hugr import tys, val
10+
from hugr.std import _load_extension
11+
from hugr.utils import comma_sep_str
12+
13+
EXTENSION = _load_extension("collections.borrow_arr")
14+
15+
16+
@dataclass(eq=False)
17+
class BorrowArray(tys.ExtType):
18+
"""Fixed `size` borrow array of `ty` elements."""
19+
20+
def __init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None:
21+
if isinstance(size, int):
22+
size = tys.BoundedNatArg(size)
23+
24+
err_msg = (
25+
f"Borrow array size must be a bounded natural or a nat variable, not {size}"
26+
)
27+
match size:
28+
case tys.BoundedNatArg(_n):
29+
pass
30+
case tys.VariableArg(_idx, param):
31+
if not isinstance(param, tys.BoundedNatParam):
32+
raise ValueError(err_msg) # noqa: TRY004
33+
case _:
34+
raise ValueError(err_msg)
35+
36+
ty_arg = tys.TypeTypeArg(ty)
37+
38+
self.type_def = EXTENSION.types["borrow_array"]
39+
self.args = [size, ty_arg]
40+
41+
@property
42+
def ty(self) -> tys.Type:
43+
assert isinstance(
44+
self.args[1], tys.TypeTypeArg
45+
), "Borrow array elements must have a valid type"
46+
return self.args[1].ty
47+
48+
@property
49+
def size(self) -> int | None:
50+
"""If the borrow array has a concrete size, return it.
51+
52+
Otherwise, return None.
53+
"""
54+
if isinstance(self.args[0], tys.BoundedNatArg):
55+
return self.args[0].n
56+
return None
57+
58+
def type_bound(self) -> tys.TypeBound:
59+
return tys.TypeBound.Linear
60+
61+
62+
# Note that only borrow array values with no elements borrowed should be emitted.
63+
@dataclass
64+
class BorrowArrayVal(val.ExtensionValue):
65+
"""Constant value for a statically sized borrow array of elements."""
66+
67+
v: list[val.Value]
68+
ty: BorrowArray
69+
70+
def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None:
71+
self.v = v
72+
self.ty = BorrowArray(elem_ty, len(v))
73+
74+
def to_value(self) -> val.Extension:
75+
name = "BorrowArrayValue"
76+
# The value list must be serialized at this point, otherwise the
77+
# `Extension` value would not be serializable.
78+
vs = [v._to_serial_root() for v in self.v]
79+
element_ty = self.ty.ty._to_serial_root()
80+
serial_val = {"values": vs, "typ": element_ty}
81+
return val.Extension(name, typ=self.ty, val=serial_val)
82+
83+
def __str__(self) -> str:
84+
return f"borrow_array({comma_sep_str(self.v)})"
85+
86+
def to_model(self) -> model.Term:
87+
return model.Apply(
88+
"collections.borrow_array.const",
89+
[
90+
model.Literal(len(self.v)),
91+
cast(model.Term, self.ty.ty.to_model()),
92+
model.List([value.to_model() for value in self.v]),
93+
],
94+
)

hugr-py/tests/test_tys.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from hugr import val
66
from hugr.std.collections.array import Array, ArrayVal
7+
from hugr.std.collections.borrow_array import BorrowArray, BorrowArrayVal
78
from hugr.std.collections.list import List, ListVal
89
from hugr.std.collections.static_array import StaticArray, StaticArrayVal
910
from hugr.std.collections.value_array import ValueArray, ValueArrayVal
@@ -136,6 +137,7 @@ def test_args_str(arg: TypeArg, string: str):
136137
(Array(Bool, 3), "array<3, Type(Bool)>"),
137138
(StaticArray(Bool), "static_array<Type(Bool)>"),
138139
(ValueArray(Bool, 3), "value_array<3, Type(Bool)>"),
140+
(BorrowArray(Bool, 3), "borrow_array<3, Type(Bool)>"),
139141
(Variable(2, TypeBound.Linear), "$2"),
140142
(RowVariable(4, TypeBound.Copyable), "$4"),
141143
(USize(), "USize"),
@@ -210,6 +212,25 @@ def test_value_array():
210212
assert ar_val.ty == ValueArray(Bool, 2)
211213

212214

215+
def test_borrow_array():
216+
ty_var = Variable(0, TypeBound.Copyable)
217+
len_var = VariableArg(1, BoundedNatParam())
218+
219+
ls = BorrowArray(Bool, 3)
220+
assert ls.ty == Bool
221+
assert ls.size == 3
222+
assert ls.type_bound() == TypeBound.Linear
223+
224+
ls = BorrowArray(ty_var, len_var)
225+
assert ls.ty == ty_var
226+
assert ls.size is None
227+
assert ls.type_bound() == TypeBound.Linear
228+
229+
ar_val = BorrowArrayVal([val.TRUE, val.FALSE], Bool)
230+
assert ar_val.v == [val.TRUE, val.FALSE]
231+
assert ar_val.ty == BorrowArray(Bool, 2)
232+
233+
213234
def test_static_array():
214235
ty_var = Variable(0, TypeBound.Copyable)
215236

0 commit comments

Comments
 (0)