Skip to content

Commit b86dbfc

Browse files
Feature/frompyfunc (#26)
Add vectorize and frompyfunc logic
1 parent e2403a3 commit b86dbfc

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

nada_algebra/funcs.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
and manipulation of arrays and party objects.
44
"""
55

6-
from typing import Any, Iterable, Tuple, Union
6+
from typing import Any, Callable, Iterable, Sequence, Tuple, Union
77
from nada_dsl import (
88
Party,
99
SecretInteger,
@@ -374,6 +374,47 @@ def pad(
374374
return NadaArray(padded_inner)
375375

376376

377+
class NadaCallable:
378+
"""Class that wraps a vectorized function to ensure all NumPy outputs are converted to NadaArray"""
379+
380+
def __init__(self, vfunc: Callable) -> None:
381+
"""
382+
Initialization.
383+
384+
Args:
385+
vfunc (Callable): Vectorized function to wrap.
386+
"""
387+
self.vfunc = vfunc
388+
389+
def __call__(self, *args, **kwargs) -> Any:
390+
"""
391+
Routes function call to wrapped vectorized function while
392+
ensuring any resulting NumPy arrays are converted to NadaArrays.
393+
394+
Returns:
395+
Any: Function result.
396+
"""
397+
result = self.vfunc(*args, **kwargs)
398+
if isinstance(result, np.ndarray):
399+
return NadaArray(result)
400+
if isinstance(result, Sequence):
401+
return type(result)(
402+
NadaArray(value) if isinstance(value, np.ndarray) else value
403+
for value in result
404+
)
405+
return result
406+
407+
408+
@copy_metadata(np.frompyfunc)
409+
def frompyfunc(*args, **kwargs) -> NadaCallable:
410+
return NadaCallable(np.frompyfunc(*args, **kwargs))
411+
412+
413+
@copy_metadata(np.vectorize)
414+
def vectorize(*args, **kwargs) -> NadaCallable:
415+
return NadaCallable(np.vectorize(*args, **kwargs))
416+
417+
377418
@copy_metadata(np.eye)
378419
def eye(*args, nada_type: _NadaCleartextType, **kwargs) -> NadaArray:
379420
return to_nada(np.eye(*args, **kwargs), nada_type)

tests/nada-tests/src/functional_operations.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,30 @@ def nada_main():
3333
_ = na.pad(a, 2, mode="wrap")
3434
_ = na.split(a, (1, 2))
3535

36+
pyfunc_out_1 = na.frompyfunc(lambda x: x + Integer(1), 1, 1)(a)
37+
assert isinstance(pyfunc_out_1, na.NadaArray), type(pyfunc_out_1).__name__
38+
39+
pyfunc_out_2, pyfunc_out_3 = na.frompyfunc(
40+
lambda x: (x + Integer(1), x + Integer(2)), 1, 2
41+
)(a)
42+
assert isinstance(pyfunc_out_2, na.NadaArray), type(pyfunc_out_2).__name__
43+
assert isinstance(pyfunc_out_3, na.NadaArray), type(pyfunc_out_3).__name__
44+
45+
pyfunc_out_4 = na.frompyfunc(lambda x, y: x + y, 2, 1)(a, a)
46+
assert isinstance(pyfunc_out_4, na.NadaArray), type(pyfunc_out_4).__name__
47+
48+
vectorize_out_1 = na.vectorize(lambda x: x + Integer(1))(a)
49+
assert isinstance(vectorize_out_1, na.NadaArray), type(vectorize_out_1).__name__
50+
51+
vectorize_out_2, vectorize_out_3 = na.vectorize(
52+
lambda x: (x + Integer(1), x + Integer(2))
53+
)(a)
54+
assert isinstance(vectorize_out_2, na.NadaArray), type(vectorize_out_2).__name__
55+
assert isinstance(vectorize_out_3, na.NadaArray), type(vectorize_out_3).__name__
56+
57+
vectorize_out_4 = na.vectorize(lambda x, y: x + y)(a, a)
58+
assert isinstance(vectorize_out_4, na.NadaArray), type(vectorize_out_4).__name__
59+
3660
# Test all for a Rational type
3761
_ = na.sum(b)
3862
_ = na.compress(b, [True, True, False], axis=0)
@@ -41,7 +65,7 @@ def nada_main():
4165
_ = na.cumsum(b, axis=0)
4266
_ = na.diagonal(b.reshape(1, 3))
4367
_ = na.prod(b)
44-
_ = na.put(b, 2, Integer(20))
68+
_ = na.put(b, 2, na.rational(20, is_scaled=True))
4569
_ = na.ravel(b)
4670
_ = na.repeat(b, 12)
4771
_ = na.reshape(b, (1, 3))
@@ -58,6 +82,30 @@ def nada_main():
5882
_ = na.pad(b, 2, mode="wrap")
5983
_ = na.split(b, (1, 2))
6084

85+
pyfunc_out_5 = na.frompyfunc(lambda x: x + na.rational(1), 1, 1)(b)
86+
assert isinstance(pyfunc_out_5, na.NadaArray), type(pyfunc_out_4).__name__
87+
88+
pyfunc_out_6, pyfunc_out_7 = na.frompyfunc(
89+
lambda x: (x + na.rational(1), x + na.rational(2)), 1, 2
90+
)(b)
91+
assert isinstance(pyfunc_out_6, na.NadaArray), type(pyfunc_out_6).__name__
92+
assert isinstance(pyfunc_out_7, na.NadaArray), type(pyfunc_out_7).__name__
93+
94+
pyfunc_out_8 = na.frompyfunc(lambda x, y: x + y, 2, 1)(b, b)
95+
assert isinstance(pyfunc_out_8, na.NadaArray), type(pyfunc_out_8).__name__
96+
97+
vectorize_out_5 = na.vectorize(lambda x: x + na.rational(1))(b)
98+
assert isinstance(vectorize_out_5, na.NadaArray), type(pyfunc_out_4).__name__
99+
100+
vectorize_out_6, vectorize_out_7 = na.vectorize(
101+
lambda x: (x + na.rational(1), x + na.rational(2))
102+
)(b)
103+
assert isinstance(vectorize_out_6, na.NadaArray), type(vectorize_out_6).__name__
104+
assert isinstance(vectorize_out_7, na.NadaArray), type(vectorize_out_7).__name__
105+
106+
vectorize_out_8 = na.vectorize(lambda x, y: x + y)(b, b)
107+
assert isinstance(vectorize_out_8, na.NadaArray), type(vectorize_out_8).__name__
108+
61109
# Generative functions
62110
_ = na.eye(3, nada_type=na.Rational)
63111
_ = na.eye(3, nada_type=Integer)

0 commit comments

Comments
 (0)