Skip to content

Commit 9567774

Browse files
Feature/numpy functions (#23)
Add support for common NumPy functions
1 parent 0d6831e commit 9567774

File tree

4 files changed

+233
-34
lines changed

4 files changed

+233
-34
lines changed

nada_algebra/array.py

Lines changed: 133 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,7 @@ def dtype(self) -> Type:
515515
Returns:
516516
Type: Inner data type.
517517
"""
518+
# TODO: account for mixed typed NadaArrays due to e.g. padding
518519
if self.empty:
519520
return NoneType
520521
return type(self.inner.item(0))
@@ -531,143 +532,242 @@ def is_rational(self) -> bool:
531532

532533
@copy_metadata(np.ndarray.compress)
533534
def compress(self, *args, **kwargs):
534-
return self.inner.compress(*args, **kwargs)
535+
result = self.inner.compress(*args, **kwargs)
536+
if isinstance(result, np.ndarray):
537+
result = NadaArray(result)
538+
return result
535539

536540
@copy_metadata(np.ndarray.copy)
537541
def copy(self, *args, **kwargs):
538-
return self.inner.copy(*args, **kwargs)
542+
result = self.inner.copy(*args, **kwargs)
543+
if isinstance(result, np.ndarray):
544+
result = NadaArray(result)
545+
return result
539546

540547
@copy_metadata(np.ndarray.cumprod)
541548
def cumprod(self, *args, **kwargs):
542-
return self.inner.cumprod(*args, **kwargs)
549+
result = self.inner.cumprod(*args, **kwargs)
550+
if isinstance(result, np.ndarray):
551+
result = NadaArray(result)
552+
return result
543553

544554
@copy_metadata(np.ndarray.cumsum)
545555
def cumsum(self, *args, **kwargs):
546-
return self.inner.cumsum(*args, **kwargs)
556+
result = self.inner.cumsum(*args, **kwargs)
557+
if isinstance(result, np.ndarray):
558+
result = NadaArray(result)
559+
return result
547560

548561
@copy_metadata(np.ndarray.diagonal)
549562
def diagonal(self, *args, **kwargs):
550-
return self.inner.diagonal(*args, **kwargs)
563+
result = self.inner.diagonal(*args, **kwargs)
564+
if isinstance(result, np.ndarray):
565+
result = NadaArray(result)
566+
return result
551567

552568
@copy_metadata(np.ndarray.fill)
553569
def fill(self, *args, **kwargs):
554-
return self.inner.fill(*args, **kwargs)
570+
result = self.inner.fill(*args, **kwargs)
571+
if isinstance(result, np.ndarray):
572+
result = NadaArray(result)
573+
return result
555574

556575
@copy_metadata(np.ndarray.flatten)
557576
def flatten(self, *args, **kwargs):
558-
return self.inner.flatten(*args, **kwargs)
577+
result = self.inner.flatten(*args, **kwargs)
578+
if isinstance(result, np.ndarray):
579+
result = NadaArray(result)
580+
return result
559581

560582
@copy_metadata(np.ndarray.item)
561583
def item(self, *args, **kwargs):
562-
return self.inner.item(*args, **kwargs)
584+
result = self.inner.item(*args, **kwargs)
585+
if isinstance(result, np.ndarray):
586+
result = NadaArray(result)
587+
return result
563588

564589
@copy_metadata(np.ndarray.itemset)
565590
def itemset(self, *args, **kwargs):
566-
return self.inner.itemset(*args, **kwargs)
591+
result = self.inner.itemset(*args, **kwargs)
592+
if isinstance(result, np.ndarray):
593+
result = NadaArray(result)
594+
return result
567595

568596
@copy_metadata(np.ndarray.prod)
569597
def prod(self, *args, **kwargs):
570-
return self.inner.prod(*args, **kwargs)
598+
result = self.inner.prod(*args, **kwargs)
599+
if isinstance(result, np.ndarray):
600+
result = NadaArray(result)
601+
return result
571602

572603
@copy_metadata(np.ndarray.put)
573604
def put(self, *args, **kwargs):
574-
return self.inner.put(*args, **kwargs)
605+
result = self.inner.put(*args, **kwargs)
606+
if isinstance(result, np.ndarray):
607+
result = NadaArray(result)
608+
return result
575609

576610
@copy_metadata(np.ndarray.ravel)
577611
def ravel(self, *args, **kwargs):
578-
return self.inner.ravel(*args, **kwargs)
612+
result = self.inner.ravel(*args, **kwargs)
613+
if isinstance(result, np.ndarray):
614+
result = NadaArray(result)
615+
return result
579616

580617
@copy_metadata(np.ndarray.repeat)
581618
def repeat(self, *args, **kwargs):
582-
return self.inner.repeat(*args, **kwargs)
619+
result = self.inner.repeat(*args, **kwargs)
620+
if isinstance(result, np.ndarray):
621+
result = NadaArray(result)
622+
return result
583623

584624
@copy_metadata(np.ndarray.reshape)
585625
def reshape(self, *args, **kwargs):
586-
return self.inner.reshape(*args, **kwargs)
626+
result = self.inner.reshape(*args, **kwargs)
627+
if isinstance(result, np.ndarray):
628+
result = NadaArray(result)
629+
return result
587630

588631
@copy_metadata(np.ndarray.resize)
589632
def resize(self, *args, **kwargs):
590-
return self.inner.resize(*args, **kwargs)
633+
result = self.inner.resize(*args, **kwargs)
634+
if isinstance(result, np.ndarray):
635+
result = NadaArray(result)
636+
return result
591637

592638
@copy_metadata(np.ndarray.squeeze)
593639
def squeeze(self, *args, **kwargs):
594-
return self.inner.squeeze(*args, **kwargs)
640+
result = self.inner.squeeze(*args, **kwargs)
641+
if isinstance(result, np.ndarray):
642+
result = NadaArray(result)
643+
return result
595644

596645
@copy_metadata(np.ndarray.sum)
597646
def sum(self, *args, **kwargs):
598-
return self.inner.sum(*args, **kwargs)
647+
result = self.inner.sum(*args, **kwargs)
648+
if isinstance(result, np.ndarray):
649+
result = NadaArray(result)
650+
return result
599651

600652
@copy_metadata(np.ndarray.swapaxes)
601653
def swapaxes(self, *args, **kwargs):
602-
return self.inner.swapaxes(*args, **kwargs)
654+
result = self.inner.swapaxes(*args, **kwargs)
655+
if isinstance(result, np.ndarray):
656+
result = NadaArray(result)
657+
return result
603658

604659
@copy_metadata(np.ndarray.take)
605660
def take(self, *args, **kwargs):
606-
return self.inner.take(*args, **kwargs)
661+
result = self.inner.take(*args, **kwargs)
662+
if isinstance(result, np.ndarray):
663+
result = NadaArray(result)
664+
return result
607665

608666
@copy_metadata(np.ndarray.tolist)
609667
def tolist(self, *args, **kwargs):
610-
return self.inner.tolist(*args, **kwargs)
668+
result = self.inner.tolist(*args, **kwargs)
669+
if isinstance(result, np.ndarray):
670+
result = NadaArray(result)
671+
return result
611672

612673
@copy_metadata(np.ndarray.trace)
613674
def trace(self, *args, **kwargs):
614-
return self.inner.trace(*args, **kwargs)
675+
result = self.inner.trace(*args, **kwargs)
676+
if isinstance(result, np.ndarray):
677+
result = NadaArray(result)
678+
return result
615679

616680
@copy_metadata(np.ndarray.transpose)
617681
def transpose(self, *args, **kwargs):
618-
return self.inner.transpose(*args, **kwargs)
682+
result = self.inner.transpose(*args, **kwargs)
683+
if isinstance(result, np.ndarray):
684+
result = NadaArray(result)
685+
return result
619686

620687
@property
621688
@copy_metadata(np.ndarray.base)
622689
def base(self):
623-
return self.inner.base
690+
result = self.inner.base
691+
if isinstance(result, np.ndarray):
692+
result = NadaArray(result)
693+
return result
624694

625695
@property
626696
@copy_metadata(np.ndarray.data)
627697
def data(self):
628-
return self.inner.data
698+
result = self.inner.data
699+
if isinstance(result, np.ndarray):
700+
result = NadaArray(result)
701+
return result
629702

630703
@property
631704
@copy_metadata(np.ndarray.flags)
632705
def flags(self):
633-
return self.inner.flags
706+
result = self.inner.flags
707+
if isinstance(result, np.ndarray):
708+
result = NadaArray(result)
709+
return result
634710

635711
@property
636712
@copy_metadata(np.ndarray.flat)
637713
def flat(self):
638-
return self.inner.flat
714+
result = self.inner.flat
715+
if isinstance(result, np.ndarray):
716+
result = NadaArray(result)
717+
return result
639718

640719
@property
641720
@copy_metadata(np.ndarray.itemsize)
642721
def itemsize(self):
643-
return self.inner.itemsize
722+
result = self.inner.itemsize
723+
if isinstance(result, np.ndarray):
724+
result = NadaArray(result)
725+
return result
644726

645727
@property
646728
@copy_metadata(np.ndarray.nbytes)
647729
def nbytes(self):
648-
return self.inner.nbytes
730+
result = self.inner.nbytes
731+
if isinstance(result, np.ndarray):
732+
result = NadaArray(result)
733+
return result
649734

650735
@property
651736
@copy_metadata(np.ndarray.ndim)
652737
def ndim(self):
653-
return self.inner.ndim
738+
result = self.inner.ndim
739+
if isinstance(result, np.ndarray):
740+
result = NadaArray(result)
741+
return result
654742

655743
@property
656744
@copy_metadata(np.ndarray.shape)
657745
def shape(self):
658-
return self.inner.shape
746+
result = self.inner.shape
747+
if isinstance(result, np.ndarray):
748+
result = NadaArray(result)
749+
return result
659750

660751
@property
661752
@copy_metadata(np.ndarray.size)
662753
def size(self):
663-
return self.inner.size
754+
result = self.inner.size
755+
if isinstance(result, np.ndarray):
756+
result = NadaArray(result)
757+
return result
664758

665759
@property
666760
@copy_metadata(np.ndarray.strides)
667761
def strides(self):
668-
return self.inner.strides
762+
result = self.inner.strides
763+
if isinstance(result, np.ndarray):
764+
result = NadaArray(result)
765+
return result
669766

670767
@property
671768
@copy_metadata(np.ndarray.T)
672769
def T(self):
673-
return self.inner.T
770+
result = self.inner.T
771+
if isinstance(result, np.ndarray):
772+
result = NadaArray(result)
773+
return result

nada_algebra/funcs.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,85 @@ def size(arr: NadaArray) -> int:
315315
return arr.size
316316

317317

318+
def to_nada(arr: np.ndarray, nada_type: _NadaCleartextType) -> NadaArray:
319+
"""
320+
Converts a plain-text NumPy array to the equivalent NadaArray with
321+
a specified compatible NadaType.
322+
323+
Args:
324+
arr (np.ndarray): Input Numpy array.
325+
nada_type (_NadaCleartextType): Desired clear-text NadaType.
326+
327+
Returns:
328+
NadaArray: Output NadaArray.
329+
"""
330+
if nada_type == Rational:
331+
nada_type = rational
332+
else:
333+
arr = arr.astype(int)
334+
return NadaArray(np.frompyfunc(nada_type, 1, 1)(arr))
335+
336+
337+
@copy_metadata(np.pad)
338+
def pad(
339+
arr: NadaArray,
340+
pad_width: Union[Iterable[int], int],
341+
mode: str = "constant",
342+
**kwargs,
343+
) -> NadaArray:
344+
if mode not in {"constant", "edge", "reflect", "symmetric", "wrap"}:
345+
raise NotImplementedError(
346+
"Not currently possible to pad NadaArray in mode `%s`" % mode
347+
)
348+
349+
# Override python defaults by NadaType defaults
350+
overriden_kwargs = {}
351+
if mode == "constant":
352+
dtype = arr.dtype
353+
if dtype in (Rational, SecretRational):
354+
nada_type = rational
355+
elif dtype in (PublicInteger, SecretInteger):
356+
nada_type = Integer
357+
elif dtype == (PublicUnsignedInteger, SecretUnsignedInteger):
358+
nada_type = UnsignedInteger
359+
else:
360+
nada_type = dtype
361+
362+
overriden_kwargs["constant_values"] = kwargs.get(
363+
"constant_values", nada_type(0)
364+
)
365+
366+
padded_inner = np.pad(
367+
arr,
368+
pad_width,
369+
mode,
370+
**overriden_kwargs,
371+
**kwargs,
372+
)
373+
374+
return NadaArray(padded_inner)
375+
376+
377+
@copy_metadata(np.eye)
378+
def eye(*args, nada_type: _NadaCleartextType, **kwargs) -> NadaArray:
379+
return to_nada(np.eye(*args, **kwargs), nada_type)
380+
381+
382+
@copy_metadata(np.arange)
383+
def arange(*args, nada_type: _NadaCleartextType, **kwargs) -> NadaArray:
384+
return to_nada(np.arange(*args, **kwargs), nada_type)
385+
386+
387+
@copy_metadata(np.linspace)
388+
def linspace(*args, nada_type: _NadaCleartextType, **kwargs) -> NadaArray:
389+
return to_nada(np.linspace(*args, **kwargs), nada_type)
390+
391+
392+
@copy_metadata(np.split)
393+
def split(a: NadaArray, *args, **kwargs) -> NadaArray:
394+
return NadaArray(np.split(a.inner, *args, **kwargs))
395+
396+
318397
@copy_metadata(np.compress)
319398
def compress(a: NadaArray, *args, **kwargs):
320399
return a.compress(*args, **kwargs)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "nada-algebra"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
description = "Nada-Algebra is a Python library designed for algebraic operations on NumPy-like array objects on top of Nada DSL and Nillion Network."
55
authors = ["José Cabrero-Holgueras <jose.cabrero@nillion.com>"]
66
readme = "README.md"

0 commit comments

Comments
 (0)