5
5
6
6
# pylint:disable=too-many-lines
7
7
8
- from types import NoneType
9
- from typing import Any , Callable , Optional , Sequence , Type , Union
8
+ from typing import Any , Callable , Optional , Sequence , Union , get_args , overload
10
9
11
10
import numpy as np
12
11
from nada_dsl import (Input , Integer , Output , Party , PublicInteger ,
13
12
PublicUnsignedInteger , SecretInteger ,
14
13
SecretUnsignedInteger , UnsignedInteger )
15
14
16
15
from nada_algebra .context import UnsafeArithmeticSession
16
+ from nada_algebra .nada_typing import (NadaBoolean , NadaInteger , NadaRational ,
17
+ NadaUnsignedInteger )
17
18
from nada_algebra .types import (Rational , SecretRational , get_log_scale ,
18
19
public_rational , rational , secret_rational )
19
20
from nada_algebra .utils import copy_metadata
22
23
class NadaArray : # pylint:disable=too-many-public-methods
23
24
"""
24
25
Represents an array-like object with additional functionality.
25
-
26
- Attributes:
27
- inner (np.ndarray): The underlying NumPy array.
28
26
"""
29
27
30
28
def __init__ (self , inner : np .ndarray ):
@@ -41,6 +39,7 @@ def __init__(self, inner: np.ndarray):
41
39
raise ValueError (f"inner must be a numpy array and is: { type (inner )} " )
42
40
if isinstance (inner , NadaArray ):
43
41
inner = inner .inner
42
+ _check_type_conflicts (inner )
44
43
self .inner = inner
45
44
46
45
def __getitem__ (self , item ):
@@ -65,7 +64,11 @@ def __setitem__(self, key, value):
65
64
Args:
66
65
key: The key to set.
67
66
value: The value to set.
67
+
68
+ Raises:
69
+ ValueError: Raised when value with incompatible type is passed.
68
70
"""
71
+ _check_type_compatibility (value , self .dtype )
69
72
if isinstance (value , NadaArray ):
70
73
self .inner [key ] = value .inner
71
74
else :
@@ -462,7 +465,7 @@ def apply(self, func: Callable[[Any], Any]) -> "NadaArray":
462
465
def mean (self , axis = None , dtype = None , out = None ) -> Any :
463
466
sum_arr = self .inner .sum (axis = axis , dtype = dtype )
464
467
465
- if self .dtype in ( Rational , SecretRational ) :
468
+ if self .is_rational :
466
469
nada_type = rational
467
470
else :
468
471
nada_type = Integer
@@ -483,7 +486,7 @@ def mean(self, axis=None, dtype=None, out=None) -> Any:
483
486
return mean_arr
484
487
485
488
@staticmethod
486
- def output_array (array : Any , party : Party , prefix : str ) -> list :
489
+ def _output_array (array : Any , party : Party , prefix : str ) -> list :
487
490
"""
488
491
Generate a list of Output objects for each element in the input array.
489
492
@@ -512,7 +515,7 @@ def output_array(array: Any, party: Party, prefix: str) -> list:
512
515
513
516
if len (array .shape ) == 0 :
514
517
# For compatibility we're leaving this here.
515
- return NadaArray .output_array (array .item (), party , prefix )
518
+ return NadaArray ._output_array (array .item (), party , prefix )
516
519
if len (array .shape ) == 1 :
517
520
return [
518
521
(
@@ -525,7 +528,7 @@ def output_array(array: Any, party: Party, prefix: str) -> list:
525
528
return [
526
529
v
527
530
for i in range (array .shape [0 ])
528
- for v in NadaArray .output_array (array [i ], party , f"{ prefix } _{ i } " )
531
+ for v in NadaArray ._output_array (array [i ], party , f"{ prefix } _{ i } " )
529
532
]
530
533
531
534
def output (self , party : Party , prefix : str ) -> list :
@@ -539,10 +542,10 @@ def output(self, party: Party, prefix: str) -> list:
539
542
Returns:
540
543
list: A list of Output objects.
541
544
"""
542
- return NadaArray .output_array (self .inner , party , prefix )
545
+ return NadaArray ._output_array (self .inner , party , prefix )
543
546
544
547
@staticmethod
545
- def create_list (
548
+ def _create_list (
546
549
dims : Sequence [int ],
547
550
party : Optional [Party ],
548
551
prefix : Optional [str ],
@@ -564,7 +567,7 @@ def create_list(
564
567
if len (dims ) == 1 :
565
568
return [generator (f"{ prefix } _{ i } " , party ) for i in range (dims [0 ])]
566
569
return [
567
- NadaArray .create_list (
570
+ NadaArray ._create_list (
568
571
dims [1 :],
569
572
party ,
570
573
f"{ prefix } _{ i } " ,
@@ -621,7 +624,7 @@ def array(
621
624
raise ValueError (f"Unsupported nada_type: { nada_type } " )
622
625
623
626
return NadaArray (
624
- np .array (NadaArray .create_list (dims , party , prefix , generator ))
627
+ np .array (NadaArray ._create_list (dims , party , prefix , generator ))
625
628
)
626
629
627
630
@staticmethod
@@ -658,7 +661,7 @@ def random(
658
661
else :
659
662
raise ValueError (f"Unsupported nada_type: { nada_type } " )
660
663
661
- return NadaArray (np .array (NadaArray .create_list (dims , None , None , generator )))
664
+ return NadaArray (np .array (NadaArray ._create_list (dims , None , None , generator )))
662
665
663
666
def __len__ (self ):
664
667
"""
@@ -680,27 +683,58 @@ def empty(self) -> bool:
680
683
return len (self .inner ) == 0
681
684
682
685
@property
683
- def dtype (self ) -> Type :
686
+ def dtype (
687
+ self ,
688
+ ) -> Optional [Union [NadaRational , NadaInteger , NadaUnsignedInteger , NadaBoolean ]]:
684
689
"""
685
- Gets inner data type of NadaArray values .
690
+ Gets data type of array .
686
691
687
692
Returns:
688
- Type: Inner data type.
693
+ Optional[
694
+ Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]
695
+ ]: Array data type if applicable.
689
696
"""
690
- # TODO: account for mixed typed NadaArrays due to e.g. padding
691
- if self .empty :
692
- return NoneType
693
- return type (self .inner .item (0 ))
697
+ return get_dtype (self .inner )
694
698
695
699
@property
696
700
def is_rational (self ) -> bool :
697
701
"""
698
- Returns whether or not the Array's type is a rational.
702
+ Returns whether or not the Array type contains rationals.
703
+
704
+ Returns:
705
+ bool: Boolean output.
706
+ """
707
+ return self .dtype == NadaRational
708
+
709
+ @property
710
+ def is_integer (self ) -> bool :
711
+ """
712
+ Returns whether or not the Array type contains signed integers.
713
+
714
+ Returns:
715
+ bool: Boolean output.
716
+ """
717
+ return self .dtype == NadaInteger
718
+
719
+ @property
720
+ def is_unsigned_integer (self ) -> bool :
721
+ """
722
+ Returns whether or not the Array type contains unsigned integers.
723
+
724
+ Returns:
725
+ bool: Boolean output.
726
+ """
727
+ return self .dtype == NadaUnsignedInteger
728
+
729
+ @property
730
+ def is_boolean (self ) -> bool :
731
+ """
732
+ Returns whether or not the Array type contains signed integers.
699
733
700
734
Returns:
701
735
bool: Boolean output.
702
736
"""
703
- return self .dtype in ( Rational , SecretRational )
737
+ return self .dtype == NadaBoolean
704
738
705
739
def __str__ (self ) -> str :
706
740
"""
@@ -817,9 +851,23 @@ def item(self, *args, **kwargs):
817
851
return NadaArray (result )
818
852
return result
819
853
854
+ @overload
855
+ def itemset (self , value : Any ): ...
856
+ @overload
857
+ def itemset (self , item : Any , value : Any ): ...
858
+
820
859
# pylint:disable=missing-function-docstring
821
860
@copy_metadata (np .ndarray .itemset )
822
861
def itemset (self , * args , ** kwargs ):
862
+ value = None
863
+ if len (args ) == 1 :
864
+ value = args [0 ]
865
+ elif len (args ) == 2 :
866
+ value = args [1 ]
867
+ else :
868
+ value = kwargs ["value" ]
869
+
870
+ _check_type_compatibility (value , self .dtype )
823
871
result = self .inner .itemset (* args , ** kwargs )
824
872
if isinstance (result , np .ndarray ):
825
873
return NadaArray (result )
@@ -835,9 +883,12 @@ def prod(self, *args, **kwargs):
835
883
836
884
# pylint:disable=missing-function-docstring
837
885
@copy_metadata (np .ndarray .put )
838
- def put (self , * args , ** kwargs ):
839
- result = self .inner .put (* args , ** kwargs )
840
- return result
886
+ def put (self , ind : Any , v : Any , mode : Any = None ) -> None :
887
+ _check_type_compatibility (v , self .dtype )
888
+ if isinstance (v , NadaArray ):
889
+ self .inner .put (ind , v .inner , mode )
890
+ else :
891
+ self .inner .put (ind , v , mode )
841
892
842
893
# pylint:disable=missing-function-docstring
843
894
@copy_metadata (np .ndarray .ravel )
@@ -1007,3 +1058,76 @@ def T(self): # pylint:disable=invalid-name
1007
1058
if isinstance (result , np .ndarray ):
1008
1059
return NadaArray (result )
1009
1060
return result
1061
+
1062
+
1063
+ def _check_type_compatibility (
1064
+ value : Any ,
1065
+ check_type : Optional [
1066
+ Union [NadaRational , NadaInteger , NadaUnsignedInteger , NadaBoolean ]
1067
+ ],
1068
+ ) -> None :
1069
+ """
1070
+ Checks type compatibility between a type and a Nada base type.
1071
+
1072
+ Args:
1073
+ value (Any): Value to be type-checked.
1074
+ check_type (Optional[
1075
+ Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]
1076
+ ]): Base Nada type to check against.
1077
+
1078
+ Raises:
1079
+ TypeError: Raised when types are not compatible.
1080
+ """
1081
+ if isinstance (value , (NadaArray , np .ndarray )):
1082
+ if isinstance (value , NadaArray ):
1083
+ value = value .inner
1084
+ dtype = get_dtype (value )
1085
+ if dtype is None or check_type is None :
1086
+ raise TypeError (f"Type { dtype } is not compatible with { check_type } " )
1087
+ if dtype == check_type :
1088
+ return
1089
+ else :
1090
+ dtype = type (value )
1091
+ if dtype in get_args (check_type ):
1092
+ return
1093
+ raise TypeError (f"Type { dtype } is not compatible with { check_type } " )
1094
+
1095
+
1096
+ def _check_type_conflicts (array : np .ndarray ) -> None :
1097
+ """
1098
+ Checks for type conflicts
1099
+
1100
+ Args:
1101
+ array (np.ndarray): Array to be checked.
1102
+
1103
+ Raises:
1104
+ TypeError: Raised when incompatible dtypes are detected.
1105
+ """
1106
+ _ = get_dtype (array )
1107
+
1108
+
1109
+ def get_dtype (
1110
+ array : np .ndarray ,
1111
+ ) -> Optional [Union [NadaRational , NadaInteger , NadaUnsignedInteger , NadaBoolean ]]:
1112
+ """
1113
+ Gets all data types present in array.
1114
+
1115
+ Args:
1116
+ array (np.ndarray): Array to be checked.
1117
+
1118
+ Raises:
1119
+ TypeError: Raised when incompatible dtypes are detected.
1120
+
1121
+ Returns:
1122
+ Optional[Union[NadaRational, NadaInteger, NadaUnsignedInteger, NadaBoolean]: Array dtype].
1123
+ """
1124
+ if array .size == 0 :
1125
+ return None
1126
+
1127
+ unique_types = set (type (element ) for element in array .flat )
1128
+
1129
+ base_dtypes = [NadaRational , NadaInteger , NadaUnsignedInteger , NadaBoolean ]
1130
+ for base_dtype in base_dtypes :
1131
+ if all (unique_type in get_args (base_dtype ) for unique_type in unique_types ):
1132
+ return base_dtype
1133
+ raise TypeError (f"Nada-incompatible dtypes detected in `{ unique_types } `." )
0 commit comments