Skip to content

Commit 465a98a

Browse files
Feature/cleanup and guardrails (#34)
Add typing guardrails
1 parent 5f6b77e commit 465a98a

File tree

12 files changed

+306
-114
lines changed

12 files changed

+306
-114
lines changed

nada_algebra/array.py

Lines changed: 150 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55

66
# pylint:disable=too-many-lines
77

8-
from types import NoneType
9-
from typing import Any, Callable, Optional, Sequence, Type, Union
8+
from typing import Any, Callable, Optional, Sequence, Union, get_args, overload
109

1110
import numpy as np
1211
from nada_dsl import (Input, Integer, Output, Party, PublicInteger,
1312
PublicUnsignedInteger, SecretInteger,
1413
SecretUnsignedInteger, UnsignedInteger)
1514

1615
from nada_algebra.context import UnsafeArithmeticSession
16+
from nada_algebra.nada_typing import (NadaBoolean, NadaInteger, NadaRational,
17+
NadaUnsignedInteger)
1718
from nada_algebra.types import (Rational, SecretRational, get_log_scale,
1819
public_rational, rational, secret_rational)
1920
from nada_algebra.utils import copy_metadata
@@ -22,9 +23,6 @@
2223
class NadaArray: # pylint:disable=too-many-public-methods
2324
"""
2425
Represents an array-like object with additional functionality.
25-
26-
Attributes:
27-
inner (np.ndarray): The underlying NumPy array.
2826
"""
2927

3028
def __init__(self, inner: np.ndarray):
@@ -41,6 +39,7 @@ def __init__(self, inner: np.ndarray):
4139
raise ValueError(f"inner must be a numpy array and is: {type(inner)}")
4240
if isinstance(inner, NadaArray):
4341
inner = inner.inner
42+
_check_type_conflicts(inner)
4443
self.inner = inner
4544

4645
def __getitem__(self, item):
@@ -65,7 +64,11 @@ def __setitem__(self, key, value):
6564
Args:
6665
key: The key to set.
6766
value: The value to set.
67+
68+
Raises:
69+
ValueError: Raised when value with incompatible type is passed.
6870
"""
71+
_check_type_compatibility(value, self.dtype)
6972
if isinstance(value, NadaArray):
7073
self.inner[key] = value.inner
7174
else:
@@ -462,7 +465,7 @@ def apply(self, func: Callable[[Any], Any]) -> "NadaArray":
462465
def mean(self, axis=None, dtype=None, out=None) -> Any:
463466
sum_arr = self.inner.sum(axis=axis, dtype=dtype)
464467

465-
if self.dtype in (Rational, SecretRational):
468+
if self.is_rational:
466469
nada_type = rational
467470
else:
468471
nada_type = Integer
@@ -483,7 +486,7 @@ def mean(self, axis=None, dtype=None, out=None) -> Any:
483486
return mean_arr
484487

485488
@staticmethod
486-
def output_array(array: Any, party: Party, prefix: str) -> list:
489+
def _output_array(array: Any, party: Party, prefix: str) -> list:
487490
"""
488491
Generate a list of Output objects for each element in the input array.
489492
@@ -512,7 +515,7 @@ def output_array(array: Any, party: Party, prefix: str) -> list:
512515

513516
if len(array.shape) == 0:
514517
# For compatibility we're leaving this here.
515-
return NadaArray.output_array(array.item(), party, prefix)
518+
return NadaArray._output_array(array.item(), party, prefix)
516519
if len(array.shape) == 1:
517520
return [
518521
(
@@ -525,7 +528,7 @@ def output_array(array: Any, party: Party, prefix: str) -> list:
525528
return [
526529
v
527530
for i in range(array.shape[0])
528-
for v in NadaArray.output_array(array[i], party, f"{prefix}_{i}")
531+
for v in NadaArray._output_array(array[i], party, f"{prefix}_{i}")
529532
]
530533

531534
def output(self, party: Party, prefix: str) -> list:
@@ -539,10 +542,10 @@ def output(self, party: Party, prefix: str) -> list:
539542
Returns:
540543
list: A list of Output objects.
541544
"""
542-
return NadaArray.output_array(self.inner, party, prefix)
545+
return NadaArray._output_array(self.inner, party, prefix)
543546

544547
@staticmethod
545-
def create_list(
548+
def _create_list(
546549
dims: Sequence[int],
547550
party: Optional[Party],
548551
prefix: Optional[str],
@@ -564,7 +567,7 @@ def create_list(
564567
if len(dims) == 1:
565568
return [generator(f"{prefix}_{i}", party) for i in range(dims[0])]
566569
return [
567-
NadaArray.create_list(
570+
NadaArray._create_list(
568571
dims[1:],
569572
party,
570573
f"{prefix}_{i}",
@@ -621,7 +624,7 @@ def array(
621624
raise ValueError(f"Unsupported nada_type: {nada_type}")
622625

623626
return NadaArray(
624-
np.array(NadaArray.create_list(dims, party, prefix, generator))
627+
np.array(NadaArray._create_list(dims, party, prefix, generator))
625628
)
626629

627630
@staticmethod
@@ -658,7 +661,7 @@ def random(
658661
else:
659662
raise ValueError(f"Unsupported nada_type: {nada_type}")
660663

661-
return NadaArray(np.array(NadaArray.create_list(dims, None, None, generator)))
664+
return NadaArray(np.array(NadaArray._create_list(dims, None, None, generator)))
662665

663666
def __len__(self):
664667
"""
@@ -680,27 +683,58 @@ def empty(self) -> bool:
680683
return len(self.inner) == 0
681684

682685
@property
683-
def dtype(self) -> Type:
686+
def dtype(
687+
self,
688+
) -> Optional[Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]]:
684689
"""
685-
Gets inner data type of NadaArray values.
690+
Gets data type of array.
686691
687692
Returns:
688-
Type: Inner data type.
693+
Optional[
694+
Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]
695+
]: Array data type if applicable.
689696
"""
690-
# TODO: account for mixed typed NadaArrays due to e.g. padding
691-
if self.empty:
692-
return NoneType
693-
return type(self.inner.item(0))
697+
return get_dtype(self.inner)
694698

695699
@property
696700
def is_rational(self) -> bool:
697701
"""
698-
Returns whether or not the Array's type is a rational.
702+
Returns whether or not the Array type contains rationals.
703+
704+
Returns:
705+
bool: Boolean output.
706+
"""
707+
return self.dtype == NadaRational
708+
709+
@property
710+
def is_integer(self) -> bool:
711+
"""
712+
Returns whether or not the Array type contains signed integers.
713+
714+
Returns:
715+
bool: Boolean output.
716+
"""
717+
return self.dtype == NadaInteger
718+
719+
@property
720+
def is_unsigned_integer(self) -> bool:
721+
"""
722+
Returns whether or not the Array type contains unsigned integers.
723+
724+
Returns:
725+
bool: Boolean output.
726+
"""
727+
return self.dtype == NadaUnsignedInteger
728+
729+
@property
730+
def is_boolean(self) -> bool:
731+
"""
732+
Returns whether or not the Array type contains signed integers.
699733
700734
Returns:
701735
bool: Boolean output.
702736
"""
703-
return self.dtype in (Rational, SecretRational)
737+
return self.dtype == NadaBoolean
704738

705739
def __str__(self) -> str:
706740
"""
@@ -817,9 +851,23 @@ def item(self, *args, **kwargs):
817851
return NadaArray(result)
818852
return result
819853

854+
@overload
855+
def itemset(self, value: Any): ...
856+
@overload
857+
def itemset(self, item: Any, value: Any): ...
858+
820859
# pylint:disable=missing-function-docstring
821860
@copy_metadata(np.ndarray.itemset)
822861
def itemset(self, *args, **kwargs):
862+
value = None
863+
if len(args) == 1:
864+
value = args[0]
865+
elif len(args) == 2:
866+
value = args[1]
867+
else:
868+
value = kwargs["value"]
869+
870+
_check_type_compatibility(value, self.dtype)
823871
result = self.inner.itemset(*args, **kwargs)
824872
if isinstance(result, np.ndarray):
825873
return NadaArray(result)
@@ -835,9 +883,12 @@ def prod(self, *args, **kwargs):
835883

836884
# pylint:disable=missing-function-docstring
837885
@copy_metadata(np.ndarray.put)
838-
def put(self, *args, **kwargs):
839-
result = self.inner.put(*args, **kwargs)
840-
return result
886+
def put(self, ind: Any, v: Any, mode: Any = None) -> None:
887+
_check_type_compatibility(v, self.dtype)
888+
if isinstance(v, NadaArray):
889+
self.inner.put(ind, v.inner, mode)
890+
else:
891+
self.inner.put(ind, v, mode)
841892

842893
# pylint:disable=missing-function-docstring
843894
@copy_metadata(np.ndarray.ravel)
@@ -1007,3 +1058,76 @@ def T(self): # pylint:disable=invalid-name
10071058
if isinstance(result, np.ndarray):
10081059
return NadaArray(result)
10091060
return result
1061+
1062+
1063+
def _check_type_compatibility(
1064+
value: Any,
1065+
check_type: Optional[
1066+
Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]
1067+
],
1068+
) -> None:
1069+
"""
1070+
Checks type compatibility between a type and a Nada base type.
1071+
1072+
Args:
1073+
value (Any): Value to be type-checked.
1074+
check_type (Optional[
1075+
Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]
1076+
]): Base Nada type to check against.
1077+
1078+
Raises:
1079+
TypeError: Raised when types are not compatible.
1080+
"""
1081+
if isinstance(value, (NadaArray, np.ndarray)):
1082+
if isinstance(value, NadaArray):
1083+
value = value.inner
1084+
dtype = get_dtype(value)
1085+
if dtype is None or check_type is None:
1086+
raise TypeError(f"Type {dtype} is not compatible with {check_type}")
1087+
if dtype == check_type:
1088+
return
1089+
else:
1090+
dtype = type(value)
1091+
if dtype in get_args(check_type):
1092+
return
1093+
raise TypeError(f"Type {dtype} is not compatible with {check_type}")
1094+
1095+
1096+
def _check_type_conflicts(array: np.ndarray) -> None:
1097+
"""
1098+
Checks for type conflicts
1099+
1100+
Args:
1101+
array (np.ndarray): Array to be checked.
1102+
1103+
Raises:
1104+
TypeError: Raised when incompatible dtypes are detected.
1105+
"""
1106+
_ = get_dtype(array)
1107+
1108+
1109+
def get_dtype(
1110+
array: np.ndarray,
1111+
) -> Optional[Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]]:
1112+
"""
1113+
Gets all data types present in array.
1114+
1115+
Args:
1116+
array (np.ndarray): Array to be checked.
1117+
1118+
Raises:
1119+
TypeError: Raised when incompatible dtypes are detected.
1120+
1121+
Returns:
1122+
Optional[Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]: Array dtype].
1123+
"""
1124+
if array.size == 0:
1125+
return None
1126+
1127+
unique_types = set(type(element) for element in array.flat)
1128+
1129+
base_dtypes = [NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]
1130+
for base_dtype in base_dtypes:
1131+
if all(unique_type in get_args(base_dtype) for unique_type in unique_types):
1132+
return base_dtype
1133+
raise TypeError(f"Nada-incompatible dtypes detected in `{unique_types}`.")

nada_algebra/funcs.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from typing import Any, Callable, List, Sequence, Tuple, Union
77

88
import numpy as np
9-
from nada_dsl import (Integer, Party, PublicInteger, PublicUnsignedInteger,
10-
SecretInteger, SecretUnsignedInteger, UnsignedInteger)
9+
from nada_dsl import (Boolean, Integer, Party, PublicInteger,
10+
PublicUnsignedInteger, SecretInteger,
11+
SecretUnsignedInteger, UnsignedInteger)
1112

1213
from nada_algebra.array import NadaArray
1314
from nada_algebra.types import Rational, SecretRational, rational
@@ -229,7 +230,8 @@ def output(arr: NadaArray, party: Party, prefix: str):
229230
Returns:
230231
list: A list of Output objects.
231232
"""
232-
return NadaArray.output_array(arr, party, prefix)
233+
# pylint:disable=protected-access
234+
return NadaArray._output_array(arr, party, prefix)
233235

234236

235237
def vstack(arr_list: list) -> NadaArray:
@@ -344,19 +346,18 @@ def pad(
344346

345347
# Override python defaults by NadaType defaults
346348
overriden_kwargs = {}
347-
if mode == "constant":
348-
dtype = arr.dtype
349-
if dtype in (Rational, SecretRational):
350-
nada_type = rational
351-
elif dtype in (PublicInteger, SecretInteger):
352-
nada_type = Integer
353-
elif dtype == (PublicUnsignedInteger, SecretUnsignedInteger):
354-
nada_type = UnsignedInteger
349+
if mode == "constant" and "constant_values" not in kwargs:
350+
if arr.is_rational:
351+
default = rational(0)
352+
elif arr.is_integer:
353+
default = Integer(0)
354+
elif arr.is_unsigned_integer:
355+
default = UnsignedInteger(0)
355356
else:
356-
nada_type = dtype
357+
default = Boolean(False)
357358

358359
overriden_kwargs["constant_values"] = kwargs.get(
359-
"constant_values", nada_type(0)
360+
"constant_values", default
360361
)
361362

362363
padded_inner = np.pad( # type: ignore

nada_algebra/nada_typing.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Contains custom typing traits"""
2+
3+
from typing import Union
4+
5+
import nada_dsl as dsl
6+
7+
from nada_algebra.types import (PublicBoolean, Rational, SecretBoolean,
8+
SecretRational)
9+
10+
NadaRational = Union[
11+
Rational,
12+
SecretRational,
13+
]
14+
15+
NadaInteger = Union[
16+
dsl.Integer,
17+
dsl.PublicInteger,
18+
dsl.SecretInteger,
19+
]
20+
21+
NadaUnsignedInteger = Union[
22+
dsl.UnsignedInteger,
23+
dsl.PublicUnsignedInteger,
24+
dsl.SecretUnsignedInteger,
25+
]
26+
27+
NadaBoolean = Union[
28+
dsl.Boolean,
29+
dsl.PublicBoolean,
30+
dsl.SecretBoolean,
31+
PublicBoolean,
32+
SecretBoolean,
33+
]

0 commit comments

Comments
 (0)