diff --git a/hugr-py/src/hugr/std/collections/borrow_array.py b/hugr-py/src/hugr/std/collections/borrow_array.py new file mode 100644 index 000000000..9d01f86e6 --- /dev/null +++ b/hugr-py/src/hugr/std/collections/borrow_array.py @@ -0,0 +1,94 @@ +"""Borrow array types and operations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import cast + +import hugr.model as model +from hugr import tys, val +from hugr.std import _load_extension +from hugr.utils import comma_sep_str + +EXTENSION = _load_extension("collections.borrow_arr") + + +@dataclass(eq=False) +class BorrowArray(tys.ExtType): + """Fixed `size` borrow array of `ty` elements.""" + + def __init__(self, ty: tys.Type, size: int | tys.TypeArg) -> None: + if isinstance(size, int): + size = tys.BoundedNatArg(size) + + err_msg = ( + f"Borrow array size must be a bounded natural or a nat variable, not {size}" + ) + match size: + case tys.BoundedNatArg(_n): + pass + case tys.VariableArg(_idx, param): + if not isinstance(param, tys.BoundedNatParam): + raise ValueError(err_msg) # noqa: TRY004 + case _: + raise ValueError(err_msg) + + ty_arg = tys.TypeTypeArg(ty) + + self.type_def = EXTENSION.types["borrow_array"] + self.args = [size, ty_arg] + + @property + def ty(self) -> tys.Type: + assert isinstance( + self.args[1], tys.TypeTypeArg + ), "Borrow array elements must have a valid type" + return self.args[1].ty + + @property + def size(self) -> int | None: + """If the borrow array has a concrete size, return it. + + Otherwise, return None. + """ + if isinstance(self.args[0], tys.BoundedNatArg): + return self.args[0].n + return None + + def type_bound(self) -> tys.TypeBound: + return tys.TypeBound.Linear + + +# Note that only borrow array values with no elements borrowed should be emitted. +@dataclass +class BorrowArrayVal(val.ExtensionValue): + """Constant value for a statically sized borrow array of elements.""" + + v: list[val.Value] + ty: BorrowArray + + def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: + self.v = v + self.ty = BorrowArray(elem_ty, len(v)) + + def to_value(self) -> val.Extension: + name = "BorrowArrayValue" + # The value list must be serialized at this point, otherwise the + # `Extension` value would not be serializable. + vs = [v._to_serial_root() for v in self.v] + element_ty = self.ty.ty._to_serial_root() + serial_val = {"values": vs, "typ": element_ty} + return val.Extension(name, typ=self.ty, val=serial_val) + + def __str__(self) -> str: + return f"borrow_array({comma_sep_str(self.v)})" + + def to_model(self) -> model.Term: + return model.Apply( + "collections.borrow_array.const", + [ + model.Literal(len(self.v)), + cast(model.Term, self.ty.ty.to_model()), + model.List([value.to_model() for value in self.v]), + ], + ) diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py index 5766d0a84..955a26ec6 100644 --- a/hugr-py/tests/test_tys.py +++ b/hugr-py/tests/test_tys.py @@ -4,6 +4,7 @@ from hugr import val from hugr.std.collections.array import Array, ArrayVal +from hugr.std.collections.borrow_array import BorrowArray, BorrowArrayVal from hugr.std.collections.list import List, ListVal from hugr.std.collections.static_array import StaticArray, StaticArrayVal from hugr.std.collections.value_array import ValueArray, ValueArrayVal @@ -136,6 +137,7 @@ def test_args_str(arg: TypeArg, string: str): (Array(Bool, 3), "array<3, Type(Bool)>"), (StaticArray(Bool), "static_array"), (ValueArray(Bool, 3), "value_array<3, Type(Bool)>"), + (BorrowArray(Bool, 3), "borrow_array<3, Type(Bool)>"), (Variable(2, TypeBound.Linear), "$2"), (RowVariable(4, TypeBound.Copyable), "$4"), (USize(), "USize"), @@ -210,6 +212,25 @@ def test_value_array(): assert ar_val.ty == ValueArray(Bool, 2) +def test_borrow_array(): + ty_var = Variable(0, TypeBound.Copyable) + len_var = VariableArg(1, BoundedNatParam()) + + ls = BorrowArray(Bool, 3) + assert ls.ty == Bool + assert ls.size == 3 + assert ls.type_bound() == TypeBound.Linear + + ls = BorrowArray(ty_var, len_var) + assert ls.ty == ty_var + assert ls.size is None + assert ls.type_bound() == TypeBound.Linear + + ar_val = BorrowArrayVal([val.TRUE, val.FALSE], Bool) + assert ar_val.v == [val.TRUE, val.FALSE] + assert ar_val.ty == BorrowArray(Bool, 2) + + def test_static_array(): ty_var = Variable(0, TypeBound.Copyable)