Skip to content

Commit 5e913b4

Browse files
authored
fix: ensure return types for array class is always a nada array (#29)
* fix: ensure return types for array class is always a nada array * chore: added more tests to check return types * fix: function not being taken into consideration
1 parent 9791ee6 commit 5e913b4

10 files changed

+261
-123
lines changed

nada_algebra/array.py

Lines changed: 34 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def add(self, other: Any) -> "NadaArray":
8383
NadaArray: A new NadaArray representing the element-wise addition result.
8484
"""
8585
if isinstance(other, NadaArray):
86-
return NadaArray(self.inner + other.inner)
87-
return NadaArray(self.inner + other)
86+
other = other.inner
87+
return NadaArray(np.array(self.inner + other))
8888

8989
def __add__(self, other: Any) -> "NadaArray":
9090
"""
@@ -121,8 +121,8 @@ def sub(self, other: Any) -> "NadaArray":
121121
NadaArray: A new NadaArray representing the element-wise subtraction result.
122122
"""
123123
if isinstance(other, NadaArray):
124-
return NadaArray(self.inner - other.inner)
125-
return NadaArray(self.inner - other)
124+
other = other.inner
125+
return NadaArray(np.array(self.inner - other))
126126

127127
def __sub__(self, other: Any) -> "NadaArray":
128128
"""
@@ -159,8 +159,8 @@ def mul(self, other: Any) -> "NadaArray":
159159
NadaArray: A new NadaArray representing the element-wise multiplication result.
160160
"""
161161
if isinstance(other, NadaArray):
162-
return NadaArray(self.inner * other.inner)
163-
return NadaArray(self.inner * other)
162+
other = other.inner
163+
return NadaArray(np.array(self.inner * other))
164164

165165
def __mul__(self, other: Any) -> "NadaArray":
166166
"""
@@ -220,8 +220,8 @@ def divide(self, other: Any) -> "NadaArray":
220220
NadaArray: A new NadaArray representing the element-wise division result.
221221
"""
222222
if isinstance(other, NadaArray):
223-
return NadaArray(self.inner / other.inner)
224-
return NadaArray(self.inner / other)
223+
other = other.inner
224+
return NadaArray(np.array(self.inner / other))
225225

226226
def __truediv__(self, other: Any) -> "NadaArray":
227227
"""
@@ -316,7 +316,7 @@ def dot(self, other: "NadaArray") -> "NadaArray":
316316
if self.is_rational or other.is_rational:
317317
return self.rational_matmul(other)
318318

319-
return NadaArray(self.inner.dot(other.inner))
319+
return NadaArray(np.array(self.inner.dot(other.inner)))
320320

321321
def hstack(self, other: "NadaArray") -> "NadaArray":
322322
"""
@@ -361,7 +361,7 @@ def apply(self, func: Callable[[Any], Any]) -> "NadaArray":
361361
Returns:
362362
NadaArray: A new NadaArray with the function applied to each element.
363363
"""
364-
return NadaArray(np.frompyfunc(func, 1, 1)(self.inner))
364+
return NadaArray(np.array(np.frompyfunc(func, 1, 1)(self.inner)))
365365

366366
@copy_metadata(np.ndarray.mean)
367367
def mean(self, axis=None, dtype=None, out=None, keepdims=False) -> Any:
@@ -419,7 +419,10 @@ def output_array(array: np.ndarray, party: Party, prefix: str) -> list:
419419
elif isinstance(array, (Rational, SecretRational)):
420420
return [Output(array.value, f"{prefix}_0", party)]
421421

422-
if len(array.shape) == 1:
422+
if len(array.shape) == 0:
423+
return NadaArray.output_array(array.item(), party, prefix)
424+
425+
elif len(array.shape) == 1:
423426
return [
424427
(
425428
Output(array[i].value, f"{prefix}_{i}", party)
@@ -598,241 +601,175 @@ def is_rational(self) -> bool:
598601
@copy_metadata(np.ndarray.compress)
599602
def compress(self, *args, **kwargs):
600603
result = self.inner.compress(*args, **kwargs)
601-
if isinstance(result, np.ndarray):
602-
result = NadaArray(result)
603-
return result
604+
return NadaArray(np.array(result))
604605

605606
@copy_metadata(np.ndarray.copy)
606607
def copy(self, *args, **kwargs):
607608
result = self.inner.copy(*args, **kwargs)
608-
if isinstance(result, np.ndarray):
609-
result = NadaArray(result)
610-
return result
609+
return NadaArray(np.array(result))
611610

612611
@copy_metadata(np.ndarray.cumprod)
613612
def cumprod(self, *args, **kwargs):
614613
result = self.inner.cumprod(*args, **kwargs)
615-
if isinstance(result, np.ndarray):
616-
result = NadaArray(result)
617-
return result
614+
return NadaArray(np.array(result))
618615

619616
@copy_metadata(np.ndarray.cumsum)
620617
def cumsum(self, *args, **kwargs):
621618
result = self.inner.cumsum(*args, **kwargs)
622-
if isinstance(result, np.ndarray):
623-
result = NadaArray(result)
624-
return result
619+
return NadaArray(np.array(result))
625620

626621
@copy_metadata(np.ndarray.diagonal)
627622
def diagonal(self, *args, **kwargs):
628623
result = self.inner.diagonal(*args, **kwargs)
629-
if isinstance(result, np.ndarray):
630-
result = NadaArray(result)
631-
return result
624+
return NadaArray(np.array(result))
632625

633626
@copy_metadata(np.ndarray.fill)
634627
def fill(self, *args, **kwargs):
635628
result = self.inner.fill(*args, **kwargs)
636-
if isinstance(result, np.ndarray):
637-
result = NadaArray(result)
638629
return result
639630

640631
@copy_metadata(np.ndarray.flatten)
641632
def flatten(self, *args, **kwargs):
642633
result = self.inner.flatten(*args, **kwargs)
643-
if isinstance(result, np.ndarray):
644-
result = NadaArray(result)
645-
return result
634+
return NadaArray(np.array(result))
646635

647636
@copy_metadata(np.ndarray.item)
648637
def item(self, *args, **kwargs):
649638
result = self.inner.item(*args, **kwargs)
650-
if isinstance(result, np.ndarray):
651-
result = NadaArray(result)
652639
return result
653640

654641
@copy_metadata(np.ndarray.itemset)
655642
def itemset(self, *args, **kwargs):
656643
result = self.inner.itemset(*args, **kwargs)
657-
if isinstance(result, np.ndarray):
658-
result = NadaArray(result)
659644
return result
660645

661646
@copy_metadata(np.ndarray.prod)
662647
def prod(self, *args, **kwargs):
663648
result = self.inner.prod(*args, **kwargs)
664-
if isinstance(result, np.ndarray):
665-
result = NadaArray(result)
666-
return result
649+
return NadaArray(np.array(result))
667650

668651
@copy_metadata(np.ndarray.put)
669652
def put(self, *args, **kwargs):
670653
result = self.inner.put(*args, **kwargs)
671-
if isinstance(result, np.ndarray):
672-
result = NadaArray(result)
673654
return result
674655

675656
@copy_metadata(np.ndarray.ravel)
676657
def ravel(self, *args, **kwargs):
677658
result = self.inner.ravel(*args, **kwargs)
678-
if isinstance(result, np.ndarray):
679-
result = NadaArray(result)
680-
return result
659+
return NadaArray(np.array(result))
681660

682661
@copy_metadata(np.ndarray.repeat)
683662
def repeat(self, *args, **kwargs):
684663
result = self.inner.repeat(*args, **kwargs)
685-
if isinstance(result, np.ndarray):
686-
result = NadaArray(result)
687-
return result
664+
return NadaArray(np.array(result))
688665

689666
@copy_metadata(np.ndarray.reshape)
690667
def reshape(self, *args, **kwargs):
691668
result = self.inner.reshape(*args, **kwargs)
692-
if isinstance(result, np.ndarray):
693-
result = NadaArray(result)
694-
return result
669+
return NadaArray(np.array(result))
695670

696671
@copy_metadata(np.ndarray.resize)
697672
def resize(self, *args, **kwargs):
698673
result = self.inner.resize(*args, **kwargs)
699-
if isinstance(result, np.ndarray):
700-
result = NadaArray(result)
701-
return result
674+
return NadaArray(np.array(result))
702675

703676
@copy_metadata(np.ndarray.squeeze)
704677
def squeeze(self, *args, **kwargs):
705678
result = self.inner.squeeze(*args, **kwargs)
706-
if isinstance(result, np.ndarray):
707-
result = NadaArray(result)
708-
return result
679+
return NadaArray(np.array(result))
709680

710681
@copy_metadata(np.ndarray.sum)
711682
def sum(self, *args, **kwargs):
712683
result = self.inner.sum(*args, **kwargs)
713-
if isinstance(result, np.ndarray):
714-
result = NadaArray(result)
715-
return result
684+
return NadaArray(np.array(result))
716685

717686
@copy_metadata(np.ndarray.swapaxes)
718687
def swapaxes(self, *args, **kwargs):
719688
result = self.inner.swapaxes(*args, **kwargs)
720-
if isinstance(result, np.ndarray):
721-
result = NadaArray(result)
722-
return result
689+
return NadaArray(np.array(result))
723690

724691
@copy_metadata(np.ndarray.take)
725692
def take(self, *args, **kwargs):
726693
result = self.inner.take(*args, **kwargs)
727-
if isinstance(result, np.ndarray):
728-
result = NadaArray(result)
729-
return result
694+
return NadaArray(np.array(result))
730695

731696
@copy_metadata(np.ndarray.tolist)
732697
def tolist(self, *args, **kwargs):
733698
result = self.inner.tolist(*args, **kwargs)
734-
if isinstance(result, np.ndarray):
735-
result = NadaArray(result)
736699
return result
737700

738701
@copy_metadata(np.ndarray.trace)
739702
def trace(self, *args, **kwargs):
740703
result = self.inner.trace(*args, **kwargs)
741-
if isinstance(result, np.ndarray):
742-
result = NadaArray(result)
743-
return result
704+
return NadaArray(np.array(result))
744705

745706
@copy_metadata(np.ndarray.transpose)
746707
def transpose(self, *args, **kwargs):
747708
result = self.inner.transpose(*args, **kwargs)
748-
if isinstance(result, np.ndarray):
749-
result = NadaArray(result)
750-
return result
709+
return NadaArray(np.array(result))
751710

752711
@property
753712
@copy_metadata(np.ndarray.base)
754713
def base(self):
755714
result = self.inner.base
756-
if isinstance(result, np.ndarray):
757-
result = NadaArray(result)
758-
return result
715+
return NadaArray(np.array(result))
759716

760717
@property
761718
@copy_metadata(np.ndarray.data)
762719
def data(self):
763720
result = self.inner.data
764-
if isinstance(result, np.ndarray):
765-
result = NadaArray(result)
766721
return result
767722

768723
@property
769724
@copy_metadata(np.ndarray.flags)
770725
def flags(self):
771726
result = self.inner.flags
772-
if isinstance(result, np.ndarray):
773-
result = NadaArray(result)
774727
return result
775728

776729
@property
777730
@copy_metadata(np.ndarray.flat)
778731
def flat(self):
779732
result = self.inner.flat
780-
if isinstance(result, np.ndarray):
781-
result = NadaArray(result)
782-
return result
733+
return NadaArray(np.array(result))
783734

784735
@property
785736
@copy_metadata(np.ndarray.itemsize)
786737
def itemsize(self):
787738
result = self.inner.itemsize
788-
if isinstance(result, np.ndarray):
789-
result = NadaArray(result)
790739
return result
791740

792741
@property
793742
@copy_metadata(np.ndarray.nbytes)
794743
def nbytes(self):
795744
result = self.inner.nbytes
796-
if isinstance(result, np.ndarray):
797-
result = NadaArray(result)
798745
return result
799746

800747
@property
801748
@copy_metadata(np.ndarray.ndim)
802749
def ndim(self):
803750
result = self.inner.ndim
804-
if isinstance(result, np.ndarray):
805-
result = NadaArray(result)
806751
return result
807752

808753
@property
809754
@copy_metadata(np.ndarray.shape)
810755
def shape(self):
811756
result = self.inner.shape
812-
if isinstance(result, np.ndarray):
813-
result = NadaArray(result)
814757
return result
815758

816759
@property
817760
@copy_metadata(np.ndarray.size)
818761
def size(self):
819762
result = self.inner.size
820-
if isinstance(result, np.ndarray):
821-
result = NadaArray(result)
822763
return result
823764

824765
@property
825766
@copy_metadata(np.ndarray.strides)
826767
def strides(self):
827768
result = self.inner.strides
828-
if isinstance(result, np.ndarray):
829-
result = NadaArray(result)
830769
return result
831770

832771
@property
833772
@copy_metadata(np.ndarray.T)
834773
def T(self):
835774
result = self.inner.T
836-
if isinstance(result, np.ndarray):
837-
result = NadaArray(result)
838-
return result
775+
return NadaArray(np.array(result))

tests/nada-tests/nada-project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,8 @@ prime_size = 128
156156

157157
[[programs]]
158158
path = "src/dot_product_rational.py"
159+
prime_size = 128
160+
161+
[[programs]]
162+
path = "src/supported_operations_return_types.py"
159163
prime_size = 128

tests/nada-tests/src/array_statistics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ def nada_main():
2020
a_mean_arr = a.mean(axis=0)
2121
b_mean_arr = b.mean(axis=0)
2222

23-
output_1 = [
24-
Output(a_sum, "a_sum", parties[1]),
25-
Output(a_mean, "a_mean", parties[1]),
26-
Output(b_sum.value, "b_sum", parties[1]),
27-
Output(b_mean.value, "b_mean", parties[1]),
28-
]
23+
output_1 = (
24+
na.output(a_sum, parties[1], "a_sum")
25+
+ na.output(a_mean, parties[1], "a_mean")
26+
+ na.output(b_sum, parties[1], "b_sum")
27+
+ na.output(b_mean, parties[1], "b_mean")
28+
)
2929
output_2 = (
3030
a_sum_arr.output(parties[1], "a_sum_arr")
3131
+ b_sum_arr.output(parties[1], "b_sum_arr")

tests/nada-tests/src/sum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ def nada_main():
99

1010
result = a.sum()
1111

12-
return [Output(result, "my_output_0", parties[1])]
12+
return result.output(parties[1], "my_output")

tests/nada-tests/src/supported_operations.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def nada_main():
7-
parties = na.parties(1)
7+
parties = na.parties(2)
88

99
a = na.array([3, 3], parties[0], "A", SecretInteger)
1010

@@ -57,11 +57,11 @@ def nada_main():
5757
f.itemset(0, f.item(0) + Integer(2))
5858
f = f.tolist()[0]
5959

60-
return [
61-
Output(a, "out_0", parties[0]),
62-
Output(b, "out_1", parties[0]),
63-
Output(c, "out_2", parties[0]),
64-
Output(d, "out_3", parties[0]),
65-
Output(e, "out_4", parties[0]),
66-
Output(f, "out_5", parties[0]),
67-
]
60+
return (
61+
na.output(a, parties[1], "out_0")
62+
+ na.output(b, parties[1], "out_1")
63+
+ na.output(c, parties[1], "out_2")
64+
+ na.output(d, parties[1], "out_3")
65+
+ na.output(e, parties[1], "out_4")
66+
+ na.output(f, parties[1], "out_5")
67+
)

0 commit comments

Comments
 (0)