Skip to content

Commit a7fabc6

Browse files
manel1874jfdreis
andauthored
feat: improve tests for shuffle feature (#72)
Co-authored-by: jfdreis <josevtnreis@gmail.com>
1 parent 26d06a7 commit a7fabc6

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

tests/nada-tests/src/shuffle.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,105 @@
22

33
from typing import List
44

5-
from nada_dsl import Output, PublicInteger, SecretInteger
5+
from nada_dsl import Integer, Output, PublicInteger, SecretInteger
66

7+
import numpy as np
78
import nada_numpy as na
8-
# Step 0: Nada Numpy is imported with this line
99
from nada_numpy import shuffle
1010

11+
def bool_to_int(bool):
12+
"""Casting bool to int"""
13+
return bool.if_else(Integer(0), Integer(1))
14+
15+
def count(vec, element):
16+
"""
17+
Counts the number of times element is in vec.
18+
"""
19+
20+
result = Integer(0)
21+
for e in vec:
22+
b = ~(element == e)
23+
int_b = bool_to_int(b)
24+
result += int_b
25+
26+
return result
27+
1128

1229
def nada_main() -> List[Output]:
1330

31+
n = 8
32+
1433
parties = na.parties(2)
15-
a = na.array([8], parties[0], "A", na.Rational)
16-
b = na.array([8], parties[0], "B", na.SecretRational)
17-
c = na.array([8], parties[0], "C", PublicInteger)
18-
d = na.array([8], parties[0], "D", SecretInteger)
34+
a = na.array([n], parties[0], "A", na.Rational)
35+
b = na.array([n], parties[0], "B", na.SecretRational)
36+
c = na.array([n], parties[0], "C", PublicInteger)
37+
d = na.array([n], parties[0], "D", SecretInteger)
1938

39+
2040
# As a function
2141

2242
shuffled_a = shuffle(a)
2343
shuffled_b = shuffle(b)
2444
shuffled_c = shuffle(c)
2545
shuffled_d = shuffle(d)
2646

47+
# 1. Show shuffle works for Rational, SecretRational, PublicInteger and SecretInteger
2748
result_a = shuffled_a - shuffled_a
2849
result_b = shuffled_b - shuffled_b
2950
result_c = shuffled_c - shuffled_c
3051
result_d = shuffled_d - shuffled_d
3152

53+
# 2. Randomness: show at least one element is in a different position
54+
# true if equal
55+
diff_position_bool = [a[i] == shuffled_a[i] for i in range(n)]
56+
# cast to int (true -> 0 and false -> 1)
57+
diff_position = np.array([bool_to_int(element) for element in diff_position_bool])
58+
# add them
59+
sum = diff_position.sum()
60+
# if all are equal => all are 0 => sum is zero
61+
at_least_one_diff_element = sum > Integer(0)
62+
63+
# 3. Show elements are preserved:
64+
check = Integer(0)
65+
for ai in a:
66+
nr_ai_in_shufled_a = count(shuffled_a, ai)
67+
nr_ai_in_a = count(a, ai)
68+
check += bool_to_int(nr_ai_in_shufled_a == nr_ai_in_a)
69+
elements_are_preserved = check == Integer(0)
70+
71+
3272
# As a method
3373

3474
shuffled_method_a = a.shuffle()
3575
shuffled_method_b = b.shuffle()
3676
shuffled_method_c = c.shuffle()
3777
shuffled_method_d = d.shuffle()
3878

79+
# 1. Show shuffle works for Rational, SecretRational, PublicInteger and SecretInteger
3980
result_method_a = shuffled_method_a - shuffled_method_a
4081
result_method_b = shuffled_method_b - shuffled_method_b
4182
result_method_c = shuffled_method_c - shuffled_method_c
4283
result_method_d = shuffled_method_d - shuffled_method_d
4384

85+
# 2. Randomness: show at least one element is in a different position
86+
# true if equal
87+
diff_position_bool_method = [a[i] == shuffled_method_a[i] for i in range(n)]
88+
# cast to int (true -> 0 and false -> 1)
89+
diff_position_method = np.array([bool_to_int(element) for element in diff_position_bool_method])
90+
# add them
91+
sum_method = diff_position_method.sum()
92+
# if all are equal => all are 0 => sum is zero
93+
at_least_one_diff_element_method = sum_method > Integer(0)
94+
95+
# 3. Show elements are preserved:
96+
check = Integer(0)
97+
for ai in a:
98+
nr_ai_in_shufled_a = count(shuffled_method_a, ai)
99+
nr_ai_in_a = count(a, ai)
100+
check += bool_to_int(nr_ai_in_shufled_a == nr_ai_in_a)
101+
elements_are_preserved_method = check == Integer(0)
102+
103+
44104
return (
45105
na.output(result_a, parties[1], "my_output_a")
46106
+ na.output(result_b, parties[1], "my_output_b")
@@ -50,4 +110,8 @@ def nada_main() -> List[Output]:
50110
+ na.output(result_method_b, parties[1], "my_output_method_b")
51111
+ na.output(result_method_c, parties[1], "my_output_method_c")
52112
+ na.output(result_method_d, parties[1], "my_output_method_d")
113+
+ na.output(at_least_one_diff_element, parties[1], "at_least_one_diff_element")
114+
+ na.output(at_least_one_diff_element_method, parties[1], "at_least_one_diff_element_method")
115+
+ na.output(elements_are_preserved, parties[1], "elements_are_preserved")
116+
+ na.output(elements_are_preserved_method, parties[1], "elements_are_preserved_method")
53117
)

tests/nada-tests/tests/shuffle.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,7 @@ expected_outputs:
9898
my_output_method_d_5: 0
9999
my_output_method_d_6: 0
100100
my_output_method_d_7: 0
101+
at_least_one_diff_element: true
102+
at_least_one_diff_element_method: true
103+
elements_are_preserved: true
104+
elements_are_preserved_method: true

0 commit comments

Comments
 (0)