Skip to content

feat: Expose BorrowArray in hugr-py #2425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions hugr-py/src/hugr/std/collections/borrow_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Borrow array types and operations."""

from __future__ import annotations

from dataclasses import dataclass
from typing import cast

import hugr.model as model
from hugr import tys, val
from hugr.std import _load_extension
from hugr.utils import comma_sep_str

EXTENSION = _load_extension("collections.borrow_arr")


@dataclass(eq=False)
class BorrowArray(tys.ExtType):
"""Fixed `size` borrow array of `ty` elements."""

def __init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None:
if isinstance(size, int):
size = tys.BoundedNatArg(size)

err_msg = (
f"Borrow array size must be a bounded natural or a nat variable, not {size}"
)
match size:
case tys.BoundedNatArg(_n):
pass
case tys.VariableArg(_idx, param):
if not isinstance(param, tys.BoundedNatParam):
raise ValueError(err_msg) # noqa: TRY004
case _:
raise ValueError(err_msg)

Check warning on line 34 in hugr-py/src/hugr/std/collections/borrow_array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/borrow_array.py#L32-L34

Added lines #L32 - L34 were not covered by tests

ty_arg = tys.TypeTypeArg(ty)

self.type_def = EXTENSION.types["borrow_array"]
self.args = [size, ty_arg]

@property
def ty(self) -> tys.Type:
assert isinstance(
self.args[1], tys.TypeTypeArg
), "Borrow array elements must have a valid type"
return self.args[1].ty

@property
def size(self) -> int | None:
"""If the borrow array has a concrete size, return it.

Otherwise, return None.
"""
if isinstance(self.args[0], tys.BoundedNatArg):
return self.args[0].n
return None

def type_bound(self) -> tys.TypeBound:
return tys.TypeBound.Linear


# Note that only borrow array values with no elements borrowed should be emitted.
Copy link
Contributor

@lmondada lmondada Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is any way this could be asserted in this code or tested below, maybe not? For me, not being familiar with this code, this comment makes me wonder what the restrictions mean concretely in the kind of code that I would be allowed to write. Might just be my ignorance though, so feel free to ignore :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is something that should be handled in the future in some form (#2437)

@dataclass
class BorrowArrayVal(val.ExtensionValue):
"""Constant value for a statically sized borrow array of elements."""

v: list[val.Value]
ty: BorrowArray

def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None:
self.v = v
self.ty = BorrowArray(elem_ty, len(v))

def to_value(self) -> val.Extension:
name = "BorrowArrayValue"

Check warning on line 75 in hugr-py/src/hugr/std/collections/borrow_array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/borrow_array.py#L75

Added line #L75 was not covered by tests
# The value list must be serialized at this point, otherwise the
# `Extension` value would not be serializable.
vs = [v._to_serial_root() for v in self.v]
element_ty = self.ty.ty._to_serial_root()
serial_val = {"values": vs, "typ": element_ty}
return val.Extension(name, typ=self.ty, val=serial_val)

Check warning on line 81 in hugr-py/src/hugr/std/collections/borrow_array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/borrow_array.py#L78-L81

Added lines #L78 - L81 were not covered by tests

def __str__(self) -> str:
return f"borrow_array({comma_sep_str(self.v)})"

Check warning on line 84 in hugr-py/src/hugr/std/collections/borrow_array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/borrow_array.py#L84

Added line #L84 was not covered by tests

def to_model(self) -> model.Term:
return model.Apply(

Check warning on line 87 in hugr-py/src/hugr/std/collections/borrow_array.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/std/collections/borrow_array.py#L87

Added line #L87 was not covered by tests
"collections.borrow_array.const",
[
model.Literal(len(self.v)),
cast(model.Term, self.ty.ty.to_model()),
model.List([value.to_model() for value in self.v]),
],
)
21 changes: 21 additions & 0 deletions hugr-py/tests/test_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from hugr import val
from hugr.std.collections.array import Array, ArrayVal
from hugr.std.collections.borrow_array import BorrowArray, BorrowArrayVal
from hugr.std.collections.list import List, ListVal
from hugr.std.collections.static_array import StaticArray, StaticArrayVal
from hugr.std.collections.value_array import ValueArray, ValueArrayVal
Expand Down Expand Up @@ -136,6 +137,7 @@ def test_args_str(arg: TypeArg, string: str):
(Array(Bool, 3), "array<3, Type(Bool)>"),
(StaticArray(Bool), "static_array<Type(Bool)>"),
(ValueArray(Bool, 3), "value_array<3, Type(Bool)>"),
(BorrowArray(Bool, 3), "borrow_array<3, Type(Bool)>"),
(Variable(2, TypeBound.Linear), "$2"),
(RowVariable(4, TypeBound.Copyable), "$4"),
(USize(), "USize"),
Expand Down Expand Up @@ -210,6 +212,25 @@ def test_value_array():
assert ar_val.ty == ValueArray(Bool, 2)


def test_borrow_array():
ty_var = Variable(0, TypeBound.Copyable)
len_var = VariableArg(1, BoundedNatParam())

ls = BorrowArray(Bool, 3)
assert ls.ty == Bool
assert ls.size == 3
assert ls.type_bound() == TypeBound.Linear

ls = BorrowArray(ty_var, len_var)
assert ls.ty == ty_var
assert ls.size is None
assert ls.type_bound() == TypeBound.Linear

ar_val = BorrowArrayVal([val.TRUE, val.FALSE], Bool)
assert ar_val.v == [val.TRUE, val.FALSE]
assert ar_val.ty == BorrowArray(Bool, 2)


def test_static_array():
ty_var = Variable(0, TypeBound.Copyable)

Expand Down
Loading