Skip to content

Commit 69c7c1d

Browse files
Bugfix/client side rationals (#17)
Fix client-side rational logic
1 parent 161072f commit 69c7c1d

File tree

5 files changed

+143
-25
lines changed

5 files changed

+143
-25
lines changed

nada_algebra/array.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ class NadaArray:
5353
"cumprod",
5454
"cumsum",
5555
"data",
56-
"dtype",
5756
"diagonal",
57+
"dtype",
5858
"fill",
5959
"flags",
6060
"flat",
@@ -64,11 +64,6 @@ class NadaArray:
6464
"itemsize",
6565
"nbytes",
6666
"ndim",
67-
"diagonal",
68-
"fill",
69-
"flatten",
70-
"item",
71-
"itemset",
7267
"prod",
7368
"put",
7469
"ravel",
@@ -77,8 +72,8 @@ class NadaArray:
7772
"resize",
7873
"shape",
7974
"size",
80-
"strides",
8175
"squeeze",
76+
"strides",
8277
"sum",
8378
"swapaxes",
8479
"T",

nada_algebra/client.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
PublicVariableUnsignedInteger,
1212
)
1313
import numpy as np
14-
from nada_algebra.types import RationalConfig, Rational, SecretRational
14+
from nada_algebra import types
1515

1616

1717
def parties(num: int, prefix: str = "Party") -> list:
@@ -36,8 +36,8 @@ def array(
3636
SecretUnsignedInteger,
3737
PublicVariableInteger,
3838
PublicVariableUnsignedInteger,
39-
Rational,
40-
SecretRational,
39+
types.Rational,
40+
types.SecretRational,
4141
] = SecretInteger,
4242
) -> dict:
4343
"""
@@ -47,15 +47,35 @@ def array(
4747
Args:
4848
arr (np.ndarray): The input array.
4949
prefix (str): The prefix to be added to the output names.
50-
nada_type (Union[type[SecretInteger], type[SecretUnsignedInteger], \
51-
type[PublicVariableInteger], type[PublicVariableUnsignedInteger]], optional):
52-
The type of the values introduced. Defaults to SecretInteger.
50+
nada_type (type, optional): The type of the values introduced. Defaults to SecretInteger.
5351
5452
Returns:
5553
dict: A dictionary mapping generated names to Nillion input objects.
5654
"""
55+
# TODO: remove check for zero values when pushing zero secrets is supported
5756
if len(arr.shape) == 1:
58-
return {f"{prefix}_{i}": nada_type(int(arr[i])) for i in range(arr.shape[0])}
57+
if nada_type == types.Rational:
58+
return {
59+
f"{prefix}_{i}": (PublicRational(arr[i])) for i in range(arr.shape[0])
60+
}
61+
if nada_type == types.SecretRational:
62+
return {
63+
f"{prefix}_{i}": (
64+
SecretRational(arr[i]) if arr[i] != 0 else SecretInteger(1)
65+
)
66+
for i in range(arr.shape[0])
67+
}
68+
return {
69+
f"{prefix}_{i}": (
70+
nada_type(int(arr[i]))
71+
if (
72+
nada_type in (PublicVariableInteger, PublicVariableUnsignedInteger)
73+
or int(arr[i]) != 0
74+
)
75+
else nada_type(1)
76+
)
77+
for i in range(arr.shape[0])
78+
}
5979
return {
6080
k: v
6181
for i in range(arr.shape[0])
@@ -88,7 +108,7 @@ def __rational(value: Union[float, int]) -> int:
88108
Returns:
89109
int: The integer representation of the input value.
90110
"""
91-
return round(value * (1 << RationalConfig.LOG_SCALE))
111+
return round(value * (1 << types.RationalConfig.LOG_SCALE))
92112

93113

94114
def PublicRational(value: Union[float, int]) -> PublicVariableInteger:
@@ -104,13 +124,12 @@ def PublicRational(value: Union[float, int]) -> PublicVariableInteger:
104124
return PublicVariableInteger(__rational(value))
105125

106126

107-
def SecretRational(value: Union[float, int], party: str) -> SecretInteger:
127+
def SecretRational(value: Union[float, int]) -> SecretInteger:
108128
"""
109129
Returns the integer representation of the given float value.
110130
111131
Args:
112132
value (Union[float, int]): The input value.
113-
party (str): The party name.
114133
115134
Returns:
116135
int: The integer representation of the input value.

nada_algebra/types.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@
99
Integer,
1010
NadaType,
1111
SecretInteger,
12-
SecretUnsignedInteger,
1312
SecretBoolean,
1413
PublicBoolean,
1514
PublicInteger,
16-
PublicUnsignedInteger,
1715
)
18-
from typing import Any, Callable, Union
16+
from typing import Union
1917

2018

2119
_Number = Union[float, int, np.floating]
22-
_NadaSecretInteger = Union[SecretInteger]
2320
_NadaInteger = Union[
2421
Integer,
2522
PublicInteger,

tests/python-tests/test_client.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Nada algebra client unit tests"""
2+
3+
import numpy as np
4+
import nada_algebra.client as na_client
5+
import nada_algebra as na
6+
import py_nillion_client as nillion
7+
8+
9+
class TestClient:
10+
11+
def test_array_1(self):
12+
input_arr = np.random.randn(1)
13+
nada_array = na_client.array(input_arr, "test")
14+
15+
assert list(nada_array.keys()) == ["test_0"]
16+
17+
def test_array_2(self):
18+
input_arr = np.random.randn(1)
19+
nada_array = na_client.array(input_arr, "test", na.Rational)
20+
21+
assert list(nada_array.keys()) == ["test_0"]
22+
23+
def test_array_3(self):
24+
input_arr = np.random.randn(1)
25+
nada_array = na_client.array(input_arr, "test", na.SecretRational)
26+
27+
assert list(nada_array.keys()) == ["test_0"]
28+
29+
def test_array_4(self):
30+
input_arr = np.random.randn(1)
31+
nada_array = na_client.array(input_arr, "test", nillion.PublicVariableInteger)
32+
33+
assert list(nada_array.keys()) == ["test_0"]
34+
35+
def test_array_5(self):
36+
input_arr = np.random.randn(3)
37+
nada_array = na_client.array(input_arr, "test")
38+
39+
assert list(nada_array.keys()) == ["test_0", "test_1", "test_2"]
40+
41+
def test_array_6(self):
42+
input_arr = np.random.randn(2, 3)
43+
nada_array = na_client.array(input_arr, "test")
44+
45+
assert list(sorted(nada_array.keys())) == [
46+
"test_0_0",
47+
"test_0_1",
48+
"test_0_2",
49+
"test_1_0",
50+
"test_1_1",
51+
"test_1_2",
52+
]
53+
54+
def test_array_7(self):
55+
input_arr = np.array([])
56+
nada_array = na_client.array(input_arr, "test")
57+
58+
assert nada_array == {}
59+
60+
def test_concat(self):
61+
dict_1 = {"a": 1, "b": 2}
62+
dict_2 = {"c": 3}
63+
64+
dict_comb = na_client.concat([dict_1, dict_2])
65+
66+
assert dict_comb == {"a": 1, "b": 2, "c": 3}
67+
68+
def test_secret_rational_1(self):
69+
test_value = 1
70+
71+
rational = na_client.SecretRational(test_value)
72+
73+
assert isinstance(rational, nillion.SecretInteger)
74+
75+
rational_value = rational.value
76+
77+
assert rational_value == test_value * 2**na.types.RationalConfig.LOG_SCALE
78+
79+
def test_secret_rational_2(self):
80+
test_value = 2.5
81+
82+
rational = na_client.SecretRational(test_value)
83+
84+
assert isinstance(rational, nillion.SecretInteger)
85+
86+
rational_value = rational.value
87+
88+
assert rational_value == test_value * 2**na.types.RationalConfig.LOG_SCALE
89+
90+
def test_public_rational_1(self):
91+
test_value = 1
92+
93+
rational = na_client.PublicRational(test_value)
94+
95+
assert isinstance(rational, nillion.PublicVariableInteger)
96+
97+
rational_value = rational.value
98+
99+
assert rational_value == test_value * 2**na.types.RationalConfig.LOG_SCALE
100+
101+
def test_public_rational_2(self):
102+
test_value = 2.5
103+
104+
rational = na_client.PublicRational(test_value)
105+
106+
assert isinstance(rational, nillion.PublicVariableInteger)
107+
108+
rational_value = rational.value
109+
110+
assert rational_value == test_value * 2**na.types.RationalConfig.LOG_SCALE

tests/test_all.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,8 @@ def test_client():
114114
def test_rational_client():
115115
import nada_algebra.client as na_client # For use with Python Client
116116
import py_nillion_client as nillion
117-
import numpy as np
118-
119-
parties = na_client.parties(3)
120117

121-
secret_rational = na_client.SecretRational(3.2, parties[0])
118+
secret_rational = na_client.SecretRational(3.2)
122119

123120
assert type(secret_rational) == nillion.SecretInteger
124121

0 commit comments

Comments
 (0)