Skip to content

feat: support storing tuples in state #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/_algopy_testing/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,12 +742,18 @@ def __repr__(self) -> str:


class _DynamicArrayTypeInfo(_TypeInfo):
def __init__(self, item_type: _TypeInfo):
_subclass_type: Callable[[], type] | None

def __init__(self, item_type: _TypeInfo, subclass_type: Callable[[], type] | None = None):
self._subclass_type = subclass_type
self.item_type = item_type

@property
def typ(self) -> type:
return _parameterize_type(DynamicArray, self.item_type.typ)
if self._subclass_type is not None:
return self._subclass_type()
else:
return _parameterize_type(DynamicArray, self.item_type.typ)

@property
def arc4_name(self) -> str:
Expand Down Expand Up @@ -891,6 +897,10 @@ def __repr__(self) -> str:
class DynamicBytes(DynamicArray[Byte]):
"""A variable sized array of bytes."""

_type_info: _DynamicArrayTypeInfo = _DynamicArrayTypeInfo(
Byte._type_info, lambda: DynamicBytes
)

@typing.overload
def __init__(self, *values: Byte | UInt8 | int): ...

Expand Down Expand Up @@ -996,6 +1006,12 @@ def __init__(self, _items: tuple[typing.Unpack[_TTuple]] = (), /): # type: igno
)
self._value = _encode(items)

def __bool__(self) -> bool:
try:
return bool(self.native)
except ValueError:
return False

def __len__(self) -> int:
return len(self.native)

Expand Down Expand Up @@ -1103,6 +1119,8 @@ def _update_backing_value(self) -> None:
def from_bytes(cls, value: algopy.Bytes | bytes, /) -> typing.Self:
tuple_type = _tuple_type_from_struct(cls)
tuple_value = tuple_type.from_bytes(value)
if not tuple_value:
return typing.cast(typing.Self, tuple_value)
return cls(*tuple_value.native)

@property
Expand Down
5 changes: 3 additions & 2 deletions src/_algopy_testing/models/contract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import inspect
import typing
from dataclasses import dataclass

Expand Down Expand Up @@ -201,12 +202,12 @@ def _get_state_totals(contract: Contract, cls_state_totals: StateTotals) -> _Sta

global_bytes = global_uints = local_bytes = local_uints = 0
for type_ in get_global_states(contract).values():
if issubclass(type_, UInt64 | UInt64Backed | bool):
if inspect.isclass(type_) and issubclass(type_, UInt64 | UInt64Backed | bool):
global_uints += 1
else:
global_bytes += 1
for type_ in get_local_states(contract).values():
if issubclass(type_, UInt64 | UInt64Backed | bool):
if inspect.isclass(type_) and issubclass(type_, UInt64 | UInt64Backed | bool):
local_uints += 1
else:
local_bytes += 1
Expand Down
11 changes: 9 additions & 2 deletions src/_algopy_testing/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def native_to_arc4(value: object) -> "_ABIEncoded":
return arc4_value


def compare_type(value_type: type, typ: type) -> bool:
if typing.NamedTuple in getattr(typ, "__orig_bases__", []):
tuple_fields: Sequence[type] = list(inspect.get_annotations(typ).values())
typ = tuple[*tuple_fields] # type: ignore[valid-type]
return value_type == typ


def deserialize_from_bytes(typ: type[_T], bites: bytes) -> _T:
serializer = get_native_to_arc4_serializer(typ)
arc4_value = serializer.arc4_type.from_bytes(bites)
native_value = serializer.arc4_to_native(arc4_value)
assert isinstance(native_value, typ)
return native_value
assert compare_type(type_of(native_value), typ) or isinstance(native_value, typ)
return native_value # type: ignore[no-any-return]
6 changes: 3 additions & 3 deletions src/_algopy_testing/state/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ def get(self, key: _TKey, *, default: _TValue) -> _TValue:
def maybe(self, key: _TKey) -> tuple[_TValue, bool]:
key_bytes = self._full_key(key)
box_exists = lazy_context.ledger.box_exists(self.app_id, key_bytes)
if not box_exists:
return self._value_type(), False
box_content_bytes = lazy_context.ledger.get_box(self.app_id, key_bytes)
box_content_bytes = (
b"" if not box_exists else lazy_context.ledger.get_box(self.app_id, key_bytes)
)
box_content = cast_from_bytes(self._value_type, box_content_bytes)
return box_content, box_exists

Expand Down
11 changes: 5 additions & 6 deletions src/_algopy_testing/state/global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from _algopy_testing.context_helpers import lazy_context
from _algopy_testing.mutable import set_attr_on_mutate
from _algopy_testing.primitives import Bytes, String
from _algopy_testing.serialize import type_of
from _algopy_testing.state.utils import deserialize, serialize

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -49,10 +50,10 @@ def __init__(
self._key: Bytes | None = None
self._pending_value: _T | None = None

if isinstance(type_or_value, type):
self.type_: type[_T] = type_or_value
if isinstance(type_or_value, type) or isinstance(typing.get_origin(type_or_value), type):
self.type_: type[_T] = typing.cast(type[_T], type_or_value)
else:
self.type_ = type(type_or_value)
self.type_ = type_of(type_or_value)
self._pending_value = type_or_value

self.set_key(key)
Expand Down Expand Up @@ -123,9 +124,7 @@ def get(self, default: _T | None = None) -> _T:
try:
return self.value
except ValueError:
if default is not None:
return default
return self.type_()
return typing.cast(_T, default)

def maybe(self) -> tuple[_T | None, bool]:
try:
Expand Down
4 changes: 2 additions & 2 deletions src/_algopy_testing/state/local_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def get(self, key: algopy.Account | algopy.UInt64 | int, default: _T | None = No
try:
return self[account]
except KeyError:
return default if default is not None else self.type_()
return typing.cast(_T, default)

def maybe(self, key: algopy.Account | algopy.UInt64 | int) -> tuple[_T, bool]:
account = _get_account(key)
try:
return self[account], True
except KeyError:
return self.type_(), False
return typing.cast(_T, None), False


# TODO: make a util function along with one used by ops
Expand Down
13 changes: 11 additions & 2 deletions src/_algopy_testing/state/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

import inspect
import typing

from _algopy_testing.primitives.bytes import Bytes
from _algopy_testing.primitives.uint64 import UInt64
from _algopy_testing.protocols import BytesBacked, Serializable, UInt64Backed
from _algopy_testing.serialize import (
deserialize_from_bytes,
serialize_to_bytes,
)

_TValue = typing.TypeVar("_TValue")
SerializableValue = int | bytes
Expand All @@ -21,12 +26,16 @@ def serialize(value: _TValue) -> SerializableValue:
return value.bytes.value
elif isinstance(value, Serializable):
return value.serialize()
elif isinstance(value, tuple):
return serialize_to_bytes(value)
else:
raise TypeError(f"Unsupported type: {type(value)}")


def deserialize(typ: type[_TValue], value: SerializableValue) -> _TValue:
if issubclass(typ, bool):
if (typing.get_origin(typ) is tuple or issubclass(typ, tuple)) and isinstance(value, bytes):
return () if not value else deserialize_from_bytes(typ, value) # type: ignore[return-value]
elif issubclass(typ, bool):
return value != 0 # type: ignore[return-value]
elif issubclass(typ, UInt64 | Bytes):
return typ(value) # type: ignore[arg-type, return-value]
Expand Down Expand Up @@ -55,7 +64,7 @@ def cast_from_bytes(typ: type[_TValue], value: bytes) -> _TValue:
"""
from _algopy_testing.utils import as_int64

if issubclass(typ, bool | UInt64Backed | UInt64):
if inspect.isclass(typ) and issubclass(typ, bool | UInt64Backed | UInt64):
if len(value) > 8:
raise ValueError("uint64 value too big")
serialized: SerializableValue = int.from_bytes(value)
Expand Down
Loading