Skip to content

Commit 879ef1f

Browse files
Feature/array statistics (#24)
Add arithmetic mean operation
1 parent 9567774 commit 879ef1f

File tree

6 files changed

+129
-0
lines changed

6 files changed

+129
-0
lines changed

nada_algebra/array.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from nada_algebra.types import (
2323
Rational,
2424
SecretRational,
25+
rational,
2526
public_rational,
2627
secret_rational,
2728
get_log_scale,
@@ -326,6 +327,34 @@ def apply(self, func: Callable[[Any], Any]) -> "NadaArray":
326327
"""
327328
return NadaArray(np.frompyfunc(func, 1, 1)(self.inner))
328329

330+
@copy_metadata(np.ndarray.mean)
331+
def mean(self, axis=None, dtype=None, out=None, keepdims=False) -> Any:
332+
sum_arr = self.inner.sum(axis=axis, dtype=dtype, keepdims=keepdims)
333+
334+
if self.dtype in (Rational, SecretRational):
335+
nada_type = rational
336+
else:
337+
nada_type = Integer
338+
339+
if axis is None:
340+
count = nada_type(self.size)
341+
else:
342+
if keepdims:
343+
count = np.expand_dims(count, axis=axis)
344+
count = np.frompyfunc(nada_type, 1, 1)(count)
345+
else:
346+
count = nada_type(self.shape[axis])
347+
348+
mean_arr = sum_arr / count
349+
350+
if out is not None:
351+
out[...] = mean_arr
352+
return out
353+
354+
if isinstance(mean_arr, np.ndarray):
355+
return NadaArray(mean_arr)
356+
return mean_arr
357+
329358
@staticmethod
330359
def output_array(array: np.ndarray, party: Party, prefix: str) -> list:
331360
"""

nada_algebra/funcs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,11 @@ def diagonal(a: NadaArray, *args, **kwargs):
419419
return a.diagonal(*args, **kwargs)
420420

421421

422+
@copy_metadata(np.diagonal)
423+
def mean(a: NadaArray, *args, **kwargs):
424+
return a.mean(*args, **kwargs)
425+
426+
422427
@copy_metadata(np.prod)
423428
def prod(a: NadaArray, *args, **kwargs):
424429
return a.prod(*args, **kwargs)

tests/nada-tests/nada-project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,7 @@ prime_size = 128
141141
[[programs]]
142142
path = "src/functional_operations.py"
143143
prime_size = 128
144+
145+
[[programs]]
146+
path = "src/array_statistics.py"
147+
prime_size = 128
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from nada_dsl import *
2+
import nada_algebra as na
3+
4+
5+
def nada_main():
6+
parties = na.parties(2)
7+
8+
a = na.array([3, 2], parties[0], "A", SecretInteger)
9+
b = na.array([3, 2], parties[0], "B", na.SecretRational)
10+
11+
a_sum = a.sum()
12+
b_sum = b.sum()
13+
14+
a_sum_arr = a.sum(axis=0)
15+
b_sum_arr = b.sum(axis=0)
16+
17+
a_mean = a.mean()
18+
b_mean = b.mean()
19+
20+
a_mean_arr = a.mean(axis=0)
21+
b_mean_arr = b.mean(axis=0)
22+
23+
output_1 = [
24+
Output(a_sum, "a_sum", parties[1]),
25+
Output(a_mean, "a_mean", parties[1]),
26+
Output(b_sum.value, "b_sum", parties[1]),
27+
Output(b_mean.value, "b_mean", parties[1]),
28+
]
29+
output_2 = (
30+
a_sum_arr.output(parties[1], "a_sum_arr")
31+
+ b_sum_arr.output(parties[1], "b_sum_arr")
32+
+ a_mean_arr.output(parties[1], "a_mean_arr")
33+
+ b_mean_arr.output(parties[1], "b_mean_arr")
34+
)
35+
36+
return output_1 + output_2
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
---
2+
program: array_statistics
3+
inputs:
4+
secrets:
5+
A_0_0:
6+
SecretInteger: "1"
7+
A_0_1:
8+
SecretInteger: "2"
9+
A_1_0:
10+
SecretInteger: "4"
11+
A_1_1:
12+
SecretInteger: "5"
13+
A_2_0:
14+
SecretInteger: "4"
15+
A_2_1:
16+
SecretInteger: "5"
17+
B_0_0:
18+
SecretInteger: "1"
19+
B_0_1:
20+
SecretInteger: "2"
21+
B_1_0:
22+
SecretInteger: "4"
23+
B_1_1:
24+
SecretInteger: "5"
25+
B_2_0:
26+
SecretInteger: "4"
27+
B_2_1:
28+
SecretInteger: "5"
29+
public_variables: {}
30+
expected_outputs:
31+
a_mean_arr_1:
32+
SecretInteger: "4"
33+
b_mean_arr_0:
34+
SecretInteger: "3"
35+
b_sum:
36+
SecretInteger: "21"
37+
b_mean:
38+
SecretInteger: "3"
39+
a_mean_arr_0:
40+
SecretInteger: "3"
41+
b_mean_arr_1:
42+
SecretInteger: "4"
43+
b_sum_arr_0:
44+
SecretInteger: "9"
45+
a_sum:
46+
SecretInteger: "21"
47+
a_sum_arr_0:
48+
SecretInteger: "9"
49+
b_sum_arr_1:
50+
SecretInteger: "12"
51+
a_mean:
52+
SecretInteger: "3"
53+
a_sum_arr_1:
54+
SecretInteger: "12"

tests/test_all.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"rational_advanced",
3737
"array_attributes",
3838
"functional_operations",
39+
"array_statistics",
3940
# Not supported yet
4041
# "unsigned_matrix_inverse",
4142
# "private_inverse"

0 commit comments

Comments
 (0)