Skip to content

Commit 0d6831e

Browse files
Feature/numpy funcs (#22)
Add NadaArray methods as functional calls
1 parent 7910167 commit 0d6831e

File tree

7 files changed

+408
-98
lines changed

7 files changed

+408
-98
lines changed

nada_algebra/array.py

Lines changed: 144 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
secret_rational,
2727
get_log_scale,
2828
)
29+
from nada_algebra.utils import copy_metadata
2930

3031

3132
@dataclass
@@ -39,42 +40,6 @@ class NadaArray:
3940

4041
inner: np.ndarray
4142

42-
SUPPORTED_OPERATIONS = {
43-
"base",
44-
"compress",
45-
"copy",
46-
"cumprod",
47-
"cumsum",
48-
"data",
49-
"diagonal",
50-
"fill",
51-
"flags",
52-
"flat",
53-
"flatten",
54-
"item",
55-
"itemset",
56-
"itemsize",
57-
"nbytes",
58-
"ndim",
59-
"prod",
60-
"put",
61-
"ravel",
62-
"repeat",
63-
"reshape",
64-
"resize",
65-
"shape",
66-
"size",
67-
"squeeze",
68-
"strides",
69-
"sum",
70-
"swapaxes",
71-
"T",
72-
"take",
73-
"tolist",
74-
"trace",
75-
"transpose",
76-
}
77-
7843
def __getitem__(self, item):
7944
"""
8045
Get an item from the array.
@@ -523,57 +488,6 @@ def random(
523488

524489
return NadaArray(np.array(NadaArray.create_list(dims, None, None, generator)))
525490

526-
def __getattr__(self, name: str) -> Any:
527-
"""
528-
Routes other attributes to the inner NumPy array.
529-
530-
Args:
531-
name (str): Attribute name.
532-
533-
Raises:
534-
AttributeError: Raised if attribute not supported.
535-
536-
Returns:
537-
Any: Result of attribute.
538-
"""
539-
if name not in self.SUPPORTED_OPERATIONS:
540-
raise AttributeError(
541-
"NumPy method `%s` is not (currently) supported by NadaArrays." % name
542-
)
543-
544-
attr = getattr(self.inner, name)
545-
546-
if callable(attr):
547-
548-
def wrapper(*args, **kwargs):
549-
result = attr(*args, **kwargs)
550-
if isinstance(result, np.ndarray):
551-
return NadaArray(result)
552-
return result
553-
554-
return wrapper
555-
556-
if isinstance(attr, np.ndarray):
557-
attr = NadaArray(attr)
558-
559-
return attr
560-
561-
def __setattr__(self, name: str, value: Any):
562-
"""
563-
Overrides the default behavior of setting attributes.
564-
565-
If the attribute name is "inner", it sets the attribute value directly.
566-
Otherwise, it sets the attribute value on the inner object.
567-
568-
Args:
569-
name (str): The name of the attribute.
570-
value: The value to set for the attribute.
571-
"""
572-
if name == "inner":
573-
super().__setattr__(name, value)
574-
else:
575-
setattr(self.inner, name, value)
576-
577491
def __len__(self):
578492
"""
579493
Overrides the default behavior of returning the length of the object.
@@ -583,16 +497,6 @@ def __len__(self):
583497
"""
584498
return len(self.inner)
585499

586-
@property
587-
def ndim(self) -> int:
588-
"""
589-
Number of dimensions that the NadaArray holds.
590-
591-
Returns:
592-
int: Number of dimenions.
593-
"""
594-
return len(self.shape)
595-
596500
@property
597501
def empty(self) -> bool:
598502
"""
@@ -624,3 +528,146 @@ def is_rational(self) -> bool:
624528
bool: Boolean output.
625529
"""
626530
return self.dtype in (Rational, SecretRational)
531+
532+
@copy_metadata(np.ndarray.compress)
533+
def compress(self, *args, **kwargs):
534+
return self.inner.compress(*args, **kwargs)
535+
536+
@copy_metadata(np.ndarray.copy)
537+
def copy(self, *args, **kwargs):
538+
return self.inner.copy(*args, **kwargs)
539+
540+
@copy_metadata(np.ndarray.cumprod)
541+
def cumprod(self, *args, **kwargs):
542+
return self.inner.cumprod(*args, **kwargs)
543+
544+
@copy_metadata(np.ndarray.cumsum)
545+
def cumsum(self, *args, **kwargs):
546+
return self.inner.cumsum(*args, **kwargs)
547+
548+
@copy_metadata(np.ndarray.diagonal)
549+
def diagonal(self, *args, **kwargs):
550+
return self.inner.diagonal(*args, **kwargs)
551+
552+
@copy_metadata(np.ndarray.fill)
553+
def fill(self, *args, **kwargs):
554+
return self.inner.fill(*args, **kwargs)
555+
556+
@copy_metadata(np.ndarray.flatten)
557+
def flatten(self, *args, **kwargs):
558+
return self.inner.flatten(*args, **kwargs)
559+
560+
@copy_metadata(np.ndarray.item)
561+
def item(self, *args, **kwargs):
562+
return self.inner.item(*args, **kwargs)
563+
564+
@copy_metadata(np.ndarray.itemset)
565+
def itemset(self, *args, **kwargs):
566+
return self.inner.itemset(*args, **kwargs)
567+
568+
@copy_metadata(np.ndarray.prod)
569+
def prod(self, *args, **kwargs):
570+
return self.inner.prod(*args, **kwargs)
571+
572+
@copy_metadata(np.ndarray.put)
573+
def put(self, *args, **kwargs):
574+
return self.inner.put(*args, **kwargs)
575+
576+
@copy_metadata(np.ndarray.ravel)
577+
def ravel(self, *args, **kwargs):
578+
return self.inner.ravel(*args, **kwargs)
579+
580+
@copy_metadata(np.ndarray.repeat)
581+
def repeat(self, *args, **kwargs):
582+
return self.inner.repeat(*args, **kwargs)
583+
584+
@copy_metadata(np.ndarray.reshape)
585+
def reshape(self, *args, **kwargs):
586+
return self.inner.reshape(*args, **kwargs)
587+
588+
@copy_metadata(np.ndarray.resize)
589+
def resize(self, *args, **kwargs):
590+
return self.inner.resize(*args, **kwargs)
591+
592+
@copy_metadata(np.ndarray.squeeze)
593+
def squeeze(self, *args, **kwargs):
594+
return self.inner.squeeze(*args, **kwargs)
595+
596+
@copy_metadata(np.ndarray.sum)
597+
def sum(self, *args, **kwargs):
598+
return self.inner.sum(*args, **kwargs)
599+
600+
@copy_metadata(np.ndarray.swapaxes)
601+
def swapaxes(self, *args, **kwargs):
602+
return self.inner.swapaxes(*args, **kwargs)
603+
604+
@copy_metadata(np.ndarray.take)
605+
def take(self, *args, **kwargs):
606+
return self.inner.take(*args, **kwargs)
607+
608+
@copy_metadata(np.ndarray.tolist)
609+
def tolist(self, *args, **kwargs):
610+
return self.inner.tolist(*args, **kwargs)
611+
612+
@copy_metadata(np.ndarray.trace)
613+
def trace(self, *args, **kwargs):
614+
return self.inner.trace(*args, **kwargs)
615+
616+
@copy_metadata(np.ndarray.transpose)
617+
def transpose(self, *args, **kwargs):
618+
return self.inner.transpose(*args, **kwargs)
619+
620+
@property
621+
@copy_metadata(np.ndarray.base)
622+
def base(self):
623+
return self.inner.base
624+
625+
@property
626+
@copy_metadata(np.ndarray.data)
627+
def data(self):
628+
return self.inner.data
629+
630+
@property
631+
@copy_metadata(np.ndarray.flags)
632+
def flags(self):
633+
return self.inner.flags
634+
635+
@property
636+
@copy_metadata(np.ndarray.flat)
637+
def flat(self):
638+
return self.inner.flat
639+
640+
@property
641+
@copy_metadata(np.ndarray.itemsize)
642+
def itemsize(self):
643+
return self.inner.itemsize
644+
645+
@property
646+
@copy_metadata(np.ndarray.nbytes)
647+
def nbytes(self):
648+
return self.inner.nbytes
649+
650+
@property
651+
@copy_metadata(np.ndarray.ndim)
652+
def ndim(self):
653+
return self.inner.ndim
654+
655+
@property
656+
@copy_metadata(np.ndarray.shape)
657+
def shape(self):
658+
return self.inner.shape
659+
660+
@property
661+
@copy_metadata(np.ndarray.size)
662+
def size(self):
663+
return self.inner.size
664+
665+
@property
666+
@copy_metadata(np.ndarray.strides)
667+
def strides(self):
668+
return self.inner.strides
669+
670+
@property
671+
@copy_metadata(np.ndarray.T)
672+
def T(self):
673+
return self.inner.T

0 commit comments

Comments
 (0)