Skip to content

Commit cec24b7

Browse files
Fix type hinting (#40)
Enhance type hinting
1 parent b1dcd37 commit cec24b7

13 files changed

+62
-52
lines changed

examples/dot_product/tests/dot_product.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@ inputs:
1616
SecretInteger: "3"
1717
public_variables: {}
1818
expected_outputs:
19-
my_output_0:
19+
my_output:
2020
SecretInteger: "27"

nada_numpy/funcs.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from typing import Any, Callable, List, Sequence, Tuple, Union
77

88
import numpy as np
9-
from nada_dsl import (Boolean, Integer, Party, PublicInteger,
9+
from nada_dsl import (Boolean, Integer, Output, Party, PublicInteger,
1010
PublicUnsignedInteger, SecretInteger,
1111
SecretUnsignedInteger, UnsignedInteger)
1212

1313
from nada_numpy.array import NadaArray
14-
from nada_numpy.nada_typing import NadaCleartextNumber
14+
from nada_numpy.nada_typing import AnyNadaType, NadaCleartextNumber
1515
from nada_numpy.types import Rational, SecretRational, rational
1616
from nada_numpy.utils import copy_metadata
1717

@@ -60,7 +60,7 @@
6060
]
6161

6262

63-
def parties(num: int, prefix: str = "Party") -> list:
63+
def parties(num: int, prefix: str = "Party") -> List[Party]:
6464
"""
6565
Create a list of Party objects with specified names.
6666
@@ -69,12 +69,12 @@ def parties(num: int, prefix: str = "Party") -> list:
6969
prefix (str, optional): The prefix to use for party names. Defaults to "Party".
7070
7171
Returns:
72-
list: A list of Party objects with names in the format "{prefix}{i}".
72+
List[Party]: A list of Party objects with names in the format "{prefix}{i}".
7373
"""
7474
return [Party(name=f"{prefix}{i}") for i in range(num)]
7575

7676

77-
def __from_numpy(arr: np.ndarray, nada_type: NadaCleartextNumber) -> list:
77+
def __from_numpy(arr: np.ndarray, nada_type: NadaCleartextNumber) -> List:
7878
"""
7979
Recursively convert a n-dimensional NumPy array to a nested list of NadaInteger objects.
8080
@@ -83,7 +83,7 @@ def __from_numpy(arr: np.ndarray, nada_type: NadaCleartextNumber) -> list:
8383
nada_type (type): The type of NadaInteger objects to create.
8484
8585
Returns:
86-
list: A nested list of NadaInteger objects.
86+
List: A nested list of NadaInteger objects.
8787
"""
8888
if len(arr.shape) == 1:
8989
if isinstance(nada_type, Rational):
@@ -261,20 +261,24 @@ def random(
261261
return NadaArray.random(dims, nada_type)
262262

263263

264-
def output(arr: NadaArray, party: Party, prefix: str):
264+
def output(value: Union[NadaArray, AnyNadaType], party: Party, prefix: str) -> List[Output]:
265265
"""
266-
Generate a list of Output objects for each element in the input NadaArray.
266+
Generate a list of Output objects for some provided value.
267267
268268
Args:
269-
arr (NadaArray): The input NadaArray.
269+
value (Union[NadaArray, AnyNadaType]): The input NadaArray.
270270
party (Party): The party object.
271271
prefix (str): The prefix for naming the Output objects.
272272
273273
Returns:
274-
list: A list of Output objects.
275-
"""
276-
# pylint:disable=protected-access
277-
return NadaArray._output_array(arr, party, prefix)
274+
List[Output]: A list of Output objects.
275+
"""
276+
if isinstance(value, NadaArray):
277+
# pylint:disable=protected-access
278+
return NadaArray._output_array(value, party, prefix)
279+
if isinstance(value, (Rational, SecretRational)):
280+
value = value.value
281+
return [Output(value, prefix, party)]
278282

279283

280284
def vstack(arr_list: list) -> NadaArray:

nada_numpy/nada_typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,9 @@
4444
dsl.Boolean,
4545
Rational,
4646
]
47+
48+
AnyNadaType = Union[
49+
dsl.NadaType,
50+
Rational,
51+
SecretRational,
52+
]

tests/nada-tests/tests/array_statistics.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,23 @@ expected_outputs:
3232
SecretInteger: "4"
3333
b_mean_arr_0:
3434
SecretInteger: "3"
35-
b_sum_0:
35+
b_sum:
3636
SecretInteger: "21"
37-
b_mean_0:
37+
b_mean:
3838
SecretInteger: "3"
3939
a_mean_arr_0:
4040
SecretInteger: "3"
4141
b_mean_arr_1:
4242
SecretInteger: "4"
4343
b_sum_arr_0:
4444
SecretInteger: "9"
45-
a_sum_0:
45+
a_sum:
4646
SecretInteger: "21"
4747
a_sum_arr_0:
4848
SecretInteger: "9"
4949
b_sum_arr_1:
5050
SecretInteger: "12"
51-
a_mean_0:
51+
a_mean:
5252
SecretInteger: "3"
5353
a_sum_arr_1:
5454
SecretInteger: "12"

tests/nada-tests/tests/dot_product.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@ inputs:
1616
SecretInteger: "3"
1717
public_variables: {}
1818
expected_outputs:
19-
my_output_0:
19+
my_output:
2020
SecretInteger: "27"

tests/nada-tests/tests/dot_product_rational.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ inputs:
1616
SecretInteger: "196608"
1717
public_variables: {}
1818
expected_outputs:
19-
my_output_a_0:
19+
my_output_a:
2020
SecretInteger: "917504"
21-
my_output_b_0:
21+
my_output_b:
2222
SecretInteger: "917504"
23-
my_output_c_0:
23+
my_output_c:
2424
SecretInteger: "393216"
25-
my_output_d_0:
25+
my_output_d:
2626
SecretInteger: "393216"

tests/nada-tests/tests/get_attr.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ inputs:
1010
SecretInteger: "3"
1111
public_variables: {}
1212
expected_outputs:
13-
my_output_0:
13+
my_output:
1414
SecretInteger: "9"

tests/nada-tests/tests/get_item.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ inputs:
1010
SecretInteger: "3"
1111
public_variables: {}
1212
expected_outputs:
13-
my_output_0:
13+
my_output:
1414
SecretInteger: "9"

tests/nada-tests/tests/logistic_regression.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@ inputs:
1818
SecretInteger: "65536"
1919
public_variables: {}
2020
expected_outputs:
21-
my_output_0:
21+
my_output:
2222
SecretInteger: "12884967424"

tests/nada-tests/tests/rational_if_else.yaml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,29 @@ inputs:
1414
SecretUnsignedInteger: "1"
1515
public_variables: {}
1616
expected_outputs:
17-
out_0_0:
17+
out_0:
1818
SecretInteger: "294912"
19-
out_1_0:
19+
out_1:
2020
SecretInteger: "294912"
21-
out_2_0:
21+
out_2:
2222
SecretInteger: "78643"
23-
out_3_0:
23+
out_3:
2424
SecretInteger: "78643"
25-
out_4_0:
25+
out_4:
2626
SecretInteger: "1"
27-
out_5_0:
27+
out_5:
2828
SecretInteger: "1"
29-
out_6_0:
29+
out_6:
3030
SecretInteger: "1"
31-
out_7_0:
31+
out_7:
3232
SecretInteger: "1"
33-
out_8_0:
33+
out_8:
3434
Integer: "131072"
35-
out_9_0:
35+
out_9:
3636
Integer: "65536"
37-
out_10_0:
37+
out_10:
3838
Integer: "1"
39-
out_11_0:
39+
out_11:
4040
SecretInteger: "1"
41-
out_12_0:
41+
out_12:
4242
SecretUnsignedInteger: "1"

tests/nada-tests/tests/sum.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ inputs:
1010
SecretInteger: "3"
1111
public_variables: {}
1212
expected_outputs:
13-
my_output_0:
13+
my_output:
1414
SecretInteger: "9"

tests/nada-tests/tests/supported_operations.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ inputs:
8080
SecretInteger: "1"
8181
public_variables: {}
8282
expected_outputs:
83-
out_0_0:
83+
out_0:
8484
SecretInteger: 1
85-
out_1_0:
85+
out_1:
8686
SecretInteger: 362880
87-
out_2_0:
87+
out_2:
8888
SecretInteger: 31
89-
out_3_0:
89+
out_3:
9090
SecretInteger: 13
91-
out_4_0:
91+
out_4:
9292
SecretInteger: 21
93-
out_5_0:
93+
out_5:
9494
Integer: 42

tests/nada-tests/tests/supported_operations_return_types.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ inputs:
8080
SecretInteger: "1"
8181
public_variables: {}
8282
expected_outputs:
83-
out_0_0:
83+
out_0:
8484
SecretInteger: 1
85-
out_1_0:
85+
out_1:
8686
SecretInteger: 362880
87-
out_2_0:
87+
out_2:
8888
SecretInteger: 31
89-
out_3_0:
89+
out_3:
9090
SecretInteger: 13
91-
out_4_0:
91+
out_4:
9292
SecretInteger: 21
93-
out_5_0:
93+
out_5:
9494
Integer: 42

0 commit comments

Comments
 (0)