Skip to content

Commit 0dff731

Browse files
mathias-nillionjcabrero
authored andcommitted
feat: Add zeros and ones like and other numpy-nada compatible operations (#6)
1 parent ac4dc01 commit 0dff731

File tree

12 files changed

+370
-38
lines changed

12 files changed

+370
-38
lines changed

nada_algebra/array.py

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Callable, Union
7+
from typing import Any, Callable, Union
88

99
import numpy as np
1010
from nada_dsl import (
@@ -31,6 +31,35 @@ class NadaArray:
3131

3232
inner: np.ndarray
3333

34+
SUPPORTED_OPERATIONS = {
35+
"compress",
36+
"copy",
37+
"cumprod",
38+
"cumsum",
39+
"diagonal",
40+
"fill",
41+
"flatten",
42+
"item",
43+
"itemset",
44+
"prod",
45+
"put",
46+
"ravel",
47+
"repeat",
48+
"reshape",
49+
"resize",
50+
"shape",
51+
"size",
52+
"squeeze",
53+
"sum",
54+
"swapaxes",
55+
"T",
56+
"take",
57+
"tolist",
58+
"trace",
59+
"transpose",
60+
}
61+
62+
3463
def __getitem__(self, item):
3564
"""
3665
Get an item from the array.
@@ -59,18 +88,6 @@ def __setitem__(self, key, value):
5988
else:
6089
self.inner[key] = value
6190

62-
def __getattr__(self, name: str):
63-
"""
64-
Get an attribute from the array.
65-
66-
Args:
67-
name (str): The attribute name.
68-
69-
Returns:
70-
NadaArray: A new NadaArray representing the retrieved attribute.
71-
"""
72-
return getattr(self.inner, name)
73-
7491
def __add__(
7592
self,
7693
other: Union[
@@ -220,15 +237,6 @@ def dot(self, other: "NadaArray") -> "NadaArray":
220237
"""
221238
return NadaArray(self.inner.dot(other.inner))
222239

223-
def sum(self) -> Union[SecretInteger, SecretUnsignedInteger]:
224-
"""
225-
Compute the sum of the elements in the array.
226-
227-
Returns:
228-
Union[SecretInteger, SecretUnsignedInteger]: The sum of the array elements.
229-
"""
230-
return NadaArray(self.inner.sum())
231-
232240
def hstack(self, other: "NadaArray") -> "NadaArray":
233241
"""
234242
Horizontally stack two NadaArray objects.
@@ -424,3 +432,39 @@ def random(
424432
)
425433
)
426434
)
435+
436+
def __getattr__(self, name: str) -> Any:
437+
"""Routes other attributes to the inner NumPy array.
438+
439+
Args:
440+
name (str): Attribute name.
441+
442+
Raises:
443+
AttributeError: Raised if attribute not supported.
444+
445+
Returns:
446+
Any: Result of attribute.
447+
"""
448+
if name not in self.SUPPORTED_OPERATIONS:
449+
raise AttributeError("NumPy method `%s` is not (currently) supported by NadaArrays." % name)
450+
451+
attr = getattr(self.inner, name)
452+
453+
if callable(attr):
454+
def wrapper(*args, **kwargs):
455+
result = attr(*args, **kwargs)
456+
if isinstance(result, np.ndarray):
457+
return NadaArray(result)
458+
return result
459+
return wrapper
460+
461+
if isinstance(attr, np.ndarray):
462+
attr = NadaArray(attr)
463+
464+
return attr
465+
466+
def __setattr__(self, name, value):
467+
if name == 'inner':
468+
super().__setattr__(name, value)
469+
else:
470+
setattr(self.inner, name, value)

nada_algebra/funcs.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
and manipulation of arrays and party objects.
44
"""
55

6+
from typing import Any, Iterable
67
from nada_dsl import (
78
Party,
89
SecretInteger,
@@ -30,20 +31,20 @@ def parties(num: int, prefix: str = "Party") -> list:
3031
return [Party(name=f"{prefix}{i}") for i in range(num)]
3132

3233

33-
def __from_list(lst: list, nada_type: Integer | UnsignedInteger) -> list:
34+
def __from_numpy(arr: np.ndarray, nada_type: Integer | UnsignedInteger) -> list:
3435
"""
35-
Recursively convert a nested list to a list of NadaInteger objects.
36+
Recursively convert a n-dimensional NumPy array to a nested list of NadaInteger objects.
3637
3738
Args:
38-
lst (list): A nested list of integers.
39+
arr (np.ndarray): A NumPy array of integers.
3940
nada_type (type): The type of NadaInteger objects to create.
4041
4142
Returns:
4243
list: A nested list of NadaInteger objects.
4344
"""
44-
if len(lst.shape) == 1:
45-
return [nada_type(int(elem)) for elem in lst]
46-
return [__from_list(lst[i], nada_type) for i in range(len(lst))]
45+
if len(arr.shape) == 1:
46+
return [nada_type(int(elem)) for elem in arr]
47+
return [__from_numpy(arr[i], nada_type) for i in range(arr.shape[0])]
4748

4849

4950
def from_list(lst: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
@@ -59,15 +60,15 @@ def from_list(lst: list, nada_type: Integer | UnsignedInteger = Integer) -> Nada
5960
"""
6061
if not isinstance(lst, np.ndarray):
6162
lst = np.array(lst)
62-
return NadaArray(np.array(__from_list(lst, nada_type)))
63+
return NadaArray(np.array(__from_numpy(lst, nada_type)))
6364

6465

65-
def ones(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
66+
def ones(dims: Iterable[int], nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
6667
"""
6768
Create a cleartext NadaArray filled with ones.
6869
6970
Args:
70-
dims (list): A list of integers representing the dimensions of the array.
71+
dims (Iterable[int]): A list of integers representing the dimensions of the array.
7172
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
7273
7374
Returns:
@@ -76,12 +77,28 @@ def ones(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArra
7677
return from_list(np.ones(dims), nada_type)
7778

7879

79-
def zeros(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
80+
def ones_like(a: np.ndarray | NadaArray, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
81+
"""
82+
Create a cleartext NadaArray filled with one with the same shape and type as a given array.
83+
84+
Args:
85+
a (np.ndarray | NadaArray): A reference array.
86+
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
87+
88+
Returns:
89+
NadaArray: The created NadaArray filled with ones.
90+
"""
91+
if isinstance(a, NadaArray):
92+
a = a.inner
93+
return from_list(np.ones_like(a), nada_type)
94+
95+
96+
def zeros(dims: Iterable[int], nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
8097
"""
8198
Create a cleartext NadaArray filled with zeros.
8299
83100
Args:
84-
dims (list): A list of integers representing the dimensions of the array.
101+
dims (Iterable[int]): A list of integers representing the dimensions of the array.
85102
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
86103
87104
Returns:
@@ -90,8 +107,56 @@ def zeros(dims: list, nada_type: Integer | UnsignedInteger = Integer) -> NadaArr
90107
return from_list(np.zeros(dims), nada_type)
91108

92109

110+
def zeros_like(a: np.ndarray | NadaArray, nada_type: Integer | UnsignedInteger = Integer) -> NadaArray:
111+
"""
112+
Create a cleartext NadaArray filled with zeros with the same shape and type as a given array.
113+
114+
Args:
115+
a (np.ndarray | NadaArray): A reference array.
116+
nada_type (type, optional): The type of NadaInteger objects to create. Defaults to Integer.
117+
118+
Returns:
119+
NadaArray: The created NadaArray filled with zeros.
120+
"""
121+
if isinstance(a, NadaArray):
122+
a = a.inner
123+
return from_list(np.zeros_like(a), nada_type)
124+
125+
126+
def alphas(dims: Iterable[int], alpha: Any) -> NadaArray:
127+
"""
128+
Create a NadaArray filled with a certain constant value.
129+
130+
Args:
131+
dims (Iterable[int]): A list of integers representing the dimensions of the array.
132+
alpha (Any): Some constant value.
133+
134+
Returns:
135+
NadaArray: NadaArray filled with constant value.
136+
"""
137+
ones_array = np.ones(dims)
138+
return NadaArray(np.frompyfunc(lambda _: alpha, 1, 1)(ones_array))
139+
140+
141+
def alphas_like(a: np.ndarray | NadaArray, alpha: Any) -> NadaArray:
142+
"""
143+
Create a NadaArray filled with a certain constant value with the same shape and type as a given array.
144+
145+
Args:
146+
a (np.ndarray | NadaArray): Reference array.
147+
alpha (Any): Some constant value.
148+
149+
Returns:
150+
NadaArray: NadaArray filled with constant value.
151+
"""
152+
if isinstance(a, NadaArray):
153+
a = a.inner
154+
ones_array = np.ones_like(a)
155+
return NadaArray(np.frompyfunc(lambda _: alpha, 1, 1)(ones_array))
156+
157+
93158
def array(
94-
dims: list,
159+
dims: Iterable[int],
95160
party: Party,
96161
prefix: str,
97162
nada_type: (
@@ -102,7 +167,7 @@ def array(
102167
Create a NadaArray with the specified dimensions and elements of the given type.
103168
104169
Args:
105-
dims (list): A list of integers representing the dimensions of the array.
170+
dims (Iterable[int]): A list of integers representing the dimensions of the array.
106171
party (Party): The party object.
107172
prefix (str): A prefix for naming the array elements.
108173
nada_type (type, optional): The type of elements to create. Defaults to SecretInteger.
@@ -114,13 +179,13 @@ def array(
114179

115180

116181
def random(
117-
dims: list, nada_type: SecretInteger | SecretUnsignedInteger = SecretInteger
182+
dims: Iterable[int], nada_type: SecretInteger | SecretUnsignedInteger = SecretInteger
118183
) -> NadaArray:
119184
"""
120185
Create a random NadaArray with the specified dimensions.
121186
122187
Args:
123-
dims (list): A list of integers representing the dimensions of the array.
188+
dims (Iterable[int]): A list of integers representing the dimensions of the array.
124189
nada_type (type, optional): The type of elements to create. Defaults to SecretInteger.
125190
126191
Returns:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name = "generate_array"
2+
version = "0.1.0"
3+
authors = [""]
4+
5+
[[programs]]
6+
path = "src/main.py"
7+
prime_size = 128

tests/generate_array/src/main.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from nada_dsl import *
2+
import nada_algebra as na
3+
4+
5+
def nada_main():
6+
party = Party("party_0")
7+
8+
a = SecretInteger(Input("a", party))
9+
10+
ones1 = na.ones([2, 3])
11+
ones2 = na.ones_like(ones1)
12+
13+
zeros1 = na.zeros([2, 3])
14+
zeros2 = na.zeros_like(zeros1)
15+
16+
alphas1 = na.alphas([2, 3], alpha=a)
17+
alphas2 = na.alphas_like(alphas1, alpha=a)
18+
19+
two_a = alphas1 + alphas2
20+
21+
out = two_a + zeros1 + zeros2 + ones1 + ones2
22+
23+
return out.output(party, "my_output")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# This directory is kept purposely, so that no compilation errors arise.
2+
# Ignore everything in this directory
3+
*
4+
# Except this file
5+
!.gitignore

tests/generate_array/tests/base.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
program: main
3+
inputs:
4+
secrets:
5+
a:
6+
SecretInteger: "3"
7+
public_variables: {}
8+
expected_outputs:
9+
my_output_0_0:
10+
SecretInteger: "8"
11+
my_output_0_1:
12+
SecretInteger: "8"
13+
my_output_0_2:
14+
SecretInteger: "8"
15+
my_output_1_0:
16+
SecretInteger: "8"
17+
my_output_1_1:
18+
SecretInteger: "8"
19+
my_output_1_2:
20+
SecretInteger: "8"

tests/sum/src/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ def nada_main():
99

1010
result = a.sum()
1111

12-
return result.output(parties[1], "my_output")
12+
return [Output(result, "my_output", parties[1])]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name = "supported_operations"
2+
version = "0.1.0"
3+
authors = [""]
4+
5+
[[programs]]
6+
path = "src/main.py"
7+
prime_size = 128

0 commit comments

Comments
 (0)