@@ -83,8 +83,8 @@ def add(self, other: Any) -> "NadaArray":
83
83
NadaArray: A new NadaArray representing the element-wise addition result.
84
84
"""
85
85
if isinstance (other , NadaArray ):
86
- return NadaArray ( self . inner + other .inner )
87
- return NadaArray (self .inner + other )
86
+ other = other .inner
87
+ return NadaArray (np . array ( self .inner + other ) )
88
88
89
89
def __add__ (self , other : Any ) -> "NadaArray" :
90
90
"""
@@ -121,8 +121,8 @@ def sub(self, other: Any) -> "NadaArray":
121
121
NadaArray: A new NadaArray representing the element-wise subtraction result.
122
122
"""
123
123
if isinstance (other , NadaArray ):
124
- return NadaArray ( self . inner - other .inner )
125
- return NadaArray (self .inner - other )
124
+ other = other .inner
125
+ return NadaArray (np . array ( self .inner - other ) )
126
126
127
127
def __sub__ (self , other : Any ) -> "NadaArray" :
128
128
"""
@@ -159,8 +159,8 @@ def mul(self, other: Any) -> "NadaArray":
159
159
NadaArray: A new NadaArray representing the element-wise multiplication result.
160
160
"""
161
161
if isinstance (other , NadaArray ):
162
- return NadaArray ( self . inner * other .inner )
163
- return NadaArray (self .inner * other )
162
+ other = other .inner
163
+ return NadaArray (np . array ( self .inner * other ) )
164
164
165
165
def __mul__ (self , other : Any ) -> "NadaArray" :
166
166
"""
@@ -220,8 +220,8 @@ def divide(self, other: Any) -> "NadaArray":
220
220
NadaArray: A new NadaArray representing the element-wise division result.
221
221
"""
222
222
if isinstance (other , NadaArray ):
223
- return NadaArray ( self . inner / other .inner )
224
- return NadaArray (self .inner / other )
223
+ other = other .inner
224
+ return NadaArray (np . array ( self .inner / other ) )
225
225
226
226
def __truediv__ (self , other : Any ) -> "NadaArray" :
227
227
"""
@@ -316,7 +316,7 @@ def dot(self, other: "NadaArray") -> "NadaArray":
316
316
if self .is_rational or other .is_rational :
317
317
return self .rational_matmul (other )
318
318
319
- return NadaArray (self .inner .dot (other .inner ))
319
+ return NadaArray (np . array ( self .inner .dot (other .inner ) ))
320
320
321
321
def hstack (self , other : "NadaArray" ) -> "NadaArray" :
322
322
"""
@@ -361,7 +361,7 @@ def apply(self, func: Callable[[Any], Any]) -> "NadaArray":
361
361
Returns:
362
362
NadaArray: A new NadaArray with the function applied to each element.
363
363
"""
364
- return NadaArray (np .frompyfunc (func , 1 , 1 )(self .inner ))
364
+ return NadaArray (np .array ( np . frompyfunc (func , 1 , 1 )(self .inner ) ))
365
365
366
366
@copy_metadata (np .ndarray .mean )
367
367
def mean (self , axis = None , dtype = None , out = None , keepdims = False ) -> Any :
@@ -419,7 +419,10 @@ def output_array(array: np.ndarray, party: Party, prefix: str) -> list:
419
419
elif isinstance (array , (Rational , SecretRational )):
420
420
return [Output (array .value , f"{ prefix } _0" , party )]
421
421
422
- if len (array .shape ) == 1 :
422
+ if len (array .shape ) == 0 :
423
+ return NadaArray .output_array (array .item (), party , prefix )
424
+
425
+ elif len (array .shape ) == 1 :
423
426
return [
424
427
(
425
428
Output (array [i ].value , f"{ prefix } _{ i } " , party )
@@ -598,241 +601,175 @@ def is_rational(self) -> bool:
598
601
@copy_metadata (np .ndarray .compress )
599
602
def compress (self , * args , ** kwargs ):
600
603
result = self .inner .compress (* args , ** kwargs )
601
- if isinstance (result , np .ndarray ):
602
- result = NadaArray (result )
603
- return result
604
+ return NadaArray (np .array (result ))
604
605
605
606
@copy_metadata (np .ndarray .copy )
606
607
def copy (self , * args , ** kwargs ):
607
608
result = self .inner .copy (* args , ** kwargs )
608
- if isinstance (result , np .ndarray ):
609
- result = NadaArray (result )
610
- return result
609
+ return NadaArray (np .array (result ))
611
610
612
611
@copy_metadata (np .ndarray .cumprod )
613
612
def cumprod (self , * args , ** kwargs ):
614
613
result = self .inner .cumprod (* args , ** kwargs )
615
- if isinstance (result , np .ndarray ):
616
- result = NadaArray (result )
617
- return result
614
+ return NadaArray (np .array (result ))
618
615
619
616
@copy_metadata (np .ndarray .cumsum )
620
617
def cumsum (self , * args , ** kwargs ):
621
618
result = self .inner .cumsum (* args , ** kwargs )
622
- if isinstance (result , np .ndarray ):
623
- result = NadaArray (result )
624
- return result
619
+ return NadaArray (np .array (result ))
625
620
626
621
@copy_metadata (np .ndarray .diagonal )
627
622
def diagonal (self , * args , ** kwargs ):
628
623
result = self .inner .diagonal (* args , ** kwargs )
629
- if isinstance (result , np .ndarray ):
630
- result = NadaArray (result )
631
- return result
624
+ return NadaArray (np .array (result ))
632
625
633
626
@copy_metadata (np .ndarray .fill )
634
627
def fill (self , * args , ** kwargs ):
635
628
result = self .inner .fill (* args , ** kwargs )
636
- if isinstance (result , np .ndarray ):
637
- result = NadaArray (result )
638
629
return result
639
630
640
631
@copy_metadata (np .ndarray .flatten )
641
632
def flatten (self , * args , ** kwargs ):
642
633
result = self .inner .flatten (* args , ** kwargs )
643
- if isinstance (result , np .ndarray ):
644
- result = NadaArray (result )
645
- return result
634
+ return NadaArray (np .array (result ))
646
635
647
636
@copy_metadata (np .ndarray .item )
648
637
def item (self , * args , ** kwargs ):
649
638
result = self .inner .item (* args , ** kwargs )
650
- if isinstance (result , np .ndarray ):
651
- result = NadaArray (result )
652
639
return result
653
640
654
641
@copy_metadata (np .ndarray .itemset )
655
642
def itemset (self , * args , ** kwargs ):
656
643
result = self .inner .itemset (* args , ** kwargs )
657
- if isinstance (result , np .ndarray ):
658
- result = NadaArray (result )
659
644
return result
660
645
661
646
@copy_metadata (np .ndarray .prod )
662
647
def prod (self , * args , ** kwargs ):
663
648
result = self .inner .prod (* args , ** kwargs )
664
- if isinstance (result , np .ndarray ):
665
- result = NadaArray (result )
666
- return result
649
+ return NadaArray (np .array (result ))
667
650
668
651
@copy_metadata (np .ndarray .put )
669
652
def put (self , * args , ** kwargs ):
670
653
result = self .inner .put (* args , ** kwargs )
671
- if isinstance (result , np .ndarray ):
672
- result = NadaArray (result )
673
654
return result
674
655
675
656
@copy_metadata (np .ndarray .ravel )
676
657
def ravel (self , * args , ** kwargs ):
677
658
result = self .inner .ravel (* args , ** kwargs )
678
- if isinstance (result , np .ndarray ):
679
- result = NadaArray (result )
680
- return result
659
+ return NadaArray (np .array (result ))
681
660
682
661
@copy_metadata (np .ndarray .repeat )
683
662
def repeat (self , * args , ** kwargs ):
684
663
result = self .inner .repeat (* args , ** kwargs )
685
- if isinstance (result , np .ndarray ):
686
- result = NadaArray (result )
687
- return result
664
+ return NadaArray (np .array (result ))
688
665
689
666
@copy_metadata (np .ndarray .reshape )
690
667
def reshape (self , * args , ** kwargs ):
691
668
result = self .inner .reshape (* args , ** kwargs )
692
- if isinstance (result , np .ndarray ):
693
- result = NadaArray (result )
694
- return result
669
+ return NadaArray (np .array (result ))
695
670
696
671
@copy_metadata (np .ndarray .resize )
697
672
def resize (self , * args , ** kwargs ):
698
673
result = self .inner .resize (* args , ** kwargs )
699
- if isinstance (result , np .ndarray ):
700
- result = NadaArray (result )
701
- return result
674
+ return NadaArray (np .array (result ))
702
675
703
676
@copy_metadata (np .ndarray .squeeze )
704
677
def squeeze (self , * args , ** kwargs ):
705
678
result = self .inner .squeeze (* args , ** kwargs )
706
- if isinstance (result , np .ndarray ):
707
- result = NadaArray (result )
708
- return result
679
+ return NadaArray (np .array (result ))
709
680
710
681
@copy_metadata (np .ndarray .sum )
711
682
def sum (self , * args , ** kwargs ):
712
683
result = self .inner .sum (* args , ** kwargs )
713
- if isinstance (result , np .ndarray ):
714
- result = NadaArray (result )
715
- return result
684
+ return NadaArray (np .array (result ))
716
685
717
686
@copy_metadata (np .ndarray .swapaxes )
718
687
def swapaxes (self , * args , ** kwargs ):
719
688
result = self .inner .swapaxes (* args , ** kwargs )
720
- if isinstance (result , np .ndarray ):
721
- result = NadaArray (result )
722
- return result
689
+ return NadaArray (np .array (result ))
723
690
724
691
@copy_metadata (np .ndarray .take )
725
692
def take (self , * args , ** kwargs ):
726
693
result = self .inner .take (* args , ** kwargs )
727
- if isinstance (result , np .ndarray ):
728
- result = NadaArray (result )
729
- return result
694
+ return NadaArray (np .array (result ))
730
695
731
696
@copy_metadata (np .ndarray .tolist )
732
697
def tolist (self , * args , ** kwargs ):
733
698
result = self .inner .tolist (* args , ** kwargs )
734
- if isinstance (result , np .ndarray ):
735
- result = NadaArray (result )
736
699
return result
737
700
738
701
@copy_metadata (np .ndarray .trace )
739
702
def trace (self , * args , ** kwargs ):
740
703
result = self .inner .trace (* args , ** kwargs )
741
- if isinstance (result , np .ndarray ):
742
- result = NadaArray (result )
743
- return result
704
+ return NadaArray (np .array (result ))
744
705
745
706
@copy_metadata (np .ndarray .transpose )
746
707
def transpose (self , * args , ** kwargs ):
747
708
result = self .inner .transpose (* args , ** kwargs )
748
- if isinstance (result , np .ndarray ):
749
- result = NadaArray (result )
750
- return result
709
+ return NadaArray (np .array (result ))
751
710
752
711
@property
753
712
@copy_metadata (np .ndarray .base )
754
713
def base (self ):
755
714
result = self .inner .base
756
- if isinstance (result , np .ndarray ):
757
- result = NadaArray (result )
758
- return result
715
+ return NadaArray (np .array (result ))
759
716
760
717
@property
761
718
@copy_metadata (np .ndarray .data )
762
719
def data (self ):
763
720
result = self .inner .data
764
- if isinstance (result , np .ndarray ):
765
- result = NadaArray (result )
766
721
return result
767
722
768
723
@property
769
724
@copy_metadata (np .ndarray .flags )
770
725
def flags (self ):
771
726
result = self .inner .flags
772
- if isinstance (result , np .ndarray ):
773
- result = NadaArray (result )
774
727
return result
775
728
776
729
@property
777
730
@copy_metadata (np .ndarray .flat )
778
731
def flat (self ):
779
732
result = self .inner .flat
780
- if isinstance (result , np .ndarray ):
781
- result = NadaArray (result )
782
- return result
733
+ return NadaArray (np .array (result ))
783
734
784
735
@property
785
736
@copy_metadata (np .ndarray .itemsize )
786
737
def itemsize (self ):
787
738
result = self .inner .itemsize
788
- if isinstance (result , np .ndarray ):
789
- result = NadaArray (result )
790
739
return result
791
740
792
741
@property
793
742
@copy_metadata (np .ndarray .nbytes )
794
743
def nbytes (self ):
795
744
result = self .inner .nbytes
796
- if isinstance (result , np .ndarray ):
797
- result = NadaArray (result )
798
745
return result
799
746
800
747
@property
801
748
@copy_metadata (np .ndarray .ndim )
802
749
def ndim (self ):
803
750
result = self .inner .ndim
804
- if isinstance (result , np .ndarray ):
805
- result = NadaArray (result )
806
751
return result
807
752
808
753
@property
809
754
@copy_metadata (np .ndarray .shape )
810
755
def shape (self ):
811
756
result = self .inner .shape
812
- if isinstance (result , np .ndarray ):
813
- result = NadaArray (result )
814
757
return result
815
758
816
759
@property
817
760
@copy_metadata (np .ndarray .size )
818
761
def size (self ):
819
762
result = self .inner .size
820
- if isinstance (result , np .ndarray ):
821
- result = NadaArray (result )
822
763
return result
823
764
824
765
@property
825
766
@copy_metadata (np .ndarray .strides )
826
767
def strides (self ):
827
768
result = self .inner .strides
828
- if isinstance (result , np .ndarray ):
829
- result = NadaArray (result )
830
769
return result
831
770
832
771
@property
833
772
@copy_metadata (np .ndarray .T )
834
773
def T (self ):
835
774
result = self .inner .T
836
- if isinstance (result , np .ndarray ):
837
- result = NadaArray (result )
838
- return result
775
+ return NadaArray (np .array (result ))
0 commit comments