Skip to content

Commit 7910167

Browse files
Bugfix/type annotations (#21)
Fix rational if_else logic
1 parent 6902567 commit 7910167

File tree

6 files changed

+149
-33
lines changed

6 files changed

+149
-33
lines changed

nada_algebra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from nada_algebra.types import (
66
Rational,
77
SecretRational,
8+
PublicBoolean,
9+
SecretBoolean,
810
public_rational,
911
rational,
1012
secret_rational,

nada_algebra/array.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,3 +614,13 @@ def dtype(self) -> Type:
614614
if self.empty:
615615
return NoneType
616616
return type(self.inner.item(0))
617+
618+
@property
619+
def is_rational(self) -> bool:
620+
"""
621+
Returns whether or not the Array's type is a rational.
622+
623+
Returns:
624+
bool: Boolean output.
625+
"""
626+
return self.dtype in (Rational, SecretRational)

nada_algebra/types.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def if_else(
4848
self,
4949
arg_0: Union[_NadaType, "SecretRational", "Rational"],
5050
arg_1: Union[_NadaType, "SecretRational", "Rational"],
51-
) -> Union[SecretInteger, SecretUnsignedInteger]:
51+
) -> Union[SecretInteger, SecretUnsignedInteger, "SecretRational"]:
5252
"""
5353
If-else logic. If the boolean is True, arg_0 is returned. If not, arg_1 is returned.
5454
@@ -61,7 +61,7 @@ def if_else(
6161
TypeError: Raised when invalid operation is called.
6262
6363
Returns:
64-
Union[SecretInteger, SecretUnsignedInteger]: Return value.
64+
Union[SecretInteger, SecretUnsignedInteger, "SecretRational"]: Return value.
6565
"""
6666
first_arg = arg_0
6767
second_arg = arg_1
@@ -84,8 +84,7 @@ def if_else(
8484
if isinstance(arg_0, (SecretRational, Rational)):
8585
# If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
8686
return SecretRational(result, arg_0.log_scale, is_scaled=True)
87-
else:
88-
return result
87+
return result
8988

9089

9190
class PublicBoolean(dsl.PublicBoolean):
@@ -104,7 +103,14 @@ def if_else(
104103
self,
105104
arg_0: Union[_NadaType, "SecretRational", "Rational"],
106105
arg_1: Union[_NadaType, "SecretRational", "Rational"],
107-
) -> Union[SecretInteger, SecretUnsignedInteger]:
106+
) -> Union[
107+
PublicInteger,
108+
PublicUnsignedInteger,
109+
SecretInteger,
110+
SecretUnsignedInteger,
111+
"Rational",
112+
"SecretRational",
113+
]:
108114
"""
109115
If-else logic. If the boolean is True, arg_0 is returned. If not, arg_1 is returned.
110116
@@ -117,7 +123,8 @@ def if_else(
117123
TypeError: Raised when invalid operation is called.
118124
119125
Returns:
120-
Union[SecretInteger, SecretUnsignedInteger]: Return value.
126+
Union[PublicInteger, PublicUnsignedInteger, SecretInteger,
127+
SecretUnsignedInteger, "Rational", "SecretRational"]: Return value.
121128
"""
122129
first_arg = arg_0
123130
second_arg = arg_1
@@ -137,11 +144,11 @@ def if_else(
137144

138145
result = super().if_else(first_arg, second_arg)
139146

140-
if isinstance(arg_0, (SecretRational, Rational)):
141-
# If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
147+
if isinstance(arg_0, SecretRational) or isinstance(arg_1, SecretRational):
148+
return SecretRational(result, arg_0.log_scale, is_scaled=True)
149+
elif isinstance(arg_0, Rational) and isinstance(arg_1, Rational):
142150
return Rational(result, arg_0.log_scale, is_scaled=True)
143-
else:
144-
return result
151+
return result
145152

146153

147154
class Rational:
@@ -545,7 +552,9 @@ def __lt__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
545552
"""
546553
if self.log_scale != other.log_scale:
547554
raise ValueError("Cannot compare values with different scales.")
548-
return self.value < other.value
555+
if isinstance(other, SecretRational):
556+
return SecretBoolean(self.value < other.value)
557+
return PublicBoolean(self.value < other.value)
549558

550559
def __gt__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
551560
"""
@@ -562,7 +571,9 @@ def __gt__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
562571
"""
563572
if self.log_scale != other.log_scale:
564573
raise ValueError("Cannot compare values with different scales.")
565-
return self.value > other.value
574+
if isinstance(other, SecretRational):
575+
return SecretBoolean(self.value > other.value)
576+
return PublicBoolean(self.value > other.value)
566577

567578
def __le__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
568579
"""
@@ -579,7 +590,9 @@ def __le__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
579590
"""
580591
if self.log_scale != other.log_scale:
581592
raise ValueError("Cannot compare values with different scales.")
582-
return self.value <= other.value
593+
if isinstance(other, SecretRational):
594+
return SecretBoolean(self.value <= other.value)
595+
return PublicBoolean(self.value <= other.value)
583596

584597
def __ge__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
585598
"""
@@ -596,7 +609,9 @@ def __ge__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
596609
"""
597610
if self.log_scale != other.log_scale:
598611
raise ValueError("Cannot compare values with different scales.")
599-
return self.value >= other.value
612+
if isinstance(other, SecretRational):
613+
return SecretBoolean(self.value >= other.value)
614+
return PublicBoolean(self.value >= other.value)
600615

601616
def __eq__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
602617
"""
@@ -613,7 +628,9 @@ def __eq__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
613628
"""
614629
if self.log_scale != other.log_scale:
615630
raise ValueError("Cannot compare values with different scales.")
616-
return self.value == other.value
631+
if isinstance(other, SecretRational):
632+
return SecretBoolean(self.value == other.value)
633+
return PublicBoolean(self.value == other.value)
617634

618635
def __ne__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
619636
"""
@@ -630,7 +647,9 @@ def __ne__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
630647
"""
631648
if self.log_scale != other.log_scale:
632649
raise ValueError("Cannot compare values with different scales.")
633-
return SecretBoolean(self.value != other.value)
650+
if isinstance(other, SecretRational):
651+
return SecretBoolean(self.value != other.value)
652+
return PublicBoolean(self.value != other.value)
634653

635654
def rescale_up(self, log_scale: int = None) -> "Rational":
636655
"""
@@ -1287,6 +1306,9 @@ def rational(
12871306
Returns:
12881307
Rational: Instantiated Rational object.
12891308
"""
1309+
if value == 0: # no use in rescaling 0
1310+
return Rational(Integer(0), is_scaled=True)
1311+
12901312
if log_scale is None:
12911313
log_scale = get_log_scale()
12921314

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "nada-algebra"
3-
version = "0.3.0"
3+
version = "0.3.1"
44
description = "Nada-Algebra is a Python library designed for algebraic operations on NumPy-like array objects on top of Nada DSL and Nillion Network."
55
authors = ["José Cabrero-Holgueras <jose.cabrero@nillion.com>"]
66
readme = "README.md"
Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from nada_dsl import *
23
import nada_algebra as na
34

@@ -8,16 +9,77 @@ def nada_main():
89
a = na.secret_rational("A", parties[0])
910
b = na.secret_rational("B", parties[1])
1011
c = na.secret_rational("C", parties[2])
12+
d = SecretInteger(Input("D", parties[0]))
13+
e = SecretUnsignedInteger(Input("E", parties[0]))
1114

12-
out_0 = (a > b).if_else(na.rational(0), na.rational(1))
13-
out_1 = (a >= b).if_else(na.rational(1), na.rational(0))
14-
out_2 = (a < b).if_else(na.rational(2), na.rational(1))
15-
out_3 = (a <= b).if_else(na.rational(3), na.rational(0))
15+
out_0 = (a > b).if_else(a, b)
16+
assert isinstance(out_0, na.SecretRational), type(out_0).__name__
17+
out_1 = (a >= b).if_else(a, b)
18+
assert isinstance(out_1, na.SecretRational), type(out_1).__name__
19+
out_2 = (a < b).if_else(a, b)
20+
assert isinstance(out_2, na.SecretRational), type(out_2).__name__
21+
out_3 = (a <= b).if_else(a, b)
22+
assert isinstance(out_3, na.SecretRational), type(out_3).__name__
1623

17-
out_4 = (a > b).if_else(c, na.rational(2))
18-
out_5 = (a >= b).if_else(na.rational(2), c)
19-
out_6 = (a < b).if_else(c, na.rational(2))
20-
out_7 = (a <= b).if_else(c, na.rational(2))
24+
out_4 = (a > na.rational(1)).if_else(d, Integer(1))
25+
assert isinstance(out_4, SecretInteger), type(out_4).__name__
26+
out_5 = (a >= na.rational(1)).if_else(d, Integer(1))
27+
assert isinstance(out_5, SecretInteger), type(out_5).__name__
28+
out_6 = (a < na.rational(1)).if_else(d, Integer(1))
29+
assert isinstance(out_6, SecretInteger), type(out_6).__name__
30+
out_7 = (a <= na.rational(1)).if_else(d, Integer(1))
31+
assert isinstance(out_7, SecretInteger), type(out_7).__name__
32+
33+
out_8 = (na.rational(0) > na.rational(1)).if_else(na.rational(1), na.rational(2))
34+
assert isinstance(out_8, na.Rational), type(out_8).__name__
35+
out_9 = (na.rational(0) >= na.rational(1)).if_else(na.rational(2), na.rational(1))
36+
assert isinstance(out_9, na.Rational), type(out_9).__name__
37+
out_10 = (na.rational(0) < na.rational(1)).if_else(Integer(1), Integer(2))
38+
assert isinstance(out_10, PublicInteger), type(out_10).__name__
39+
out_11 = (na.rational(0) <= na.rational(1)).if_else(Integer(1), d)
40+
assert isinstance(out_11, SecretInteger), type(out_11).__name__
41+
out_12 = (na.rational(0) <= na.rational(1)).if_else(UnsignedInteger(1), e)
42+
assert isinstance(out_12, SecretUnsignedInteger), type(out_12).__name__
43+
out_13 = (na.rational(0) <= na.rational(1)).if_else(
44+
UnsignedInteger(1), UnsignedInteger(0)
45+
)
46+
assert isinstance(out_13, PublicUnsignedInteger), type(out_13).__name__
47+
48+
# Incompatible input types
49+
with pytest.raises(Exception):
50+
(a > Integer(1)).if_else(na.rational(0), na.rational(1))
51+
with pytest.raises(Exception):
52+
(Integer(1) > a).if_else(na.rational(0), na.rational(1))
53+
with pytest.raises(Exception):
54+
(a > d).if_else(na.rational(0), na.rational(1))
55+
with pytest.raises(Exception):
56+
(d > a).if_else(na.rational(0), na.rational(1))
57+
with pytest.raises(Exception):
58+
(na.rational(1) > Integer(1)).if_else(na.rational(0), na.rational(1))
59+
with pytest.raises(Exception):
60+
(Integer(1) > na.rational(1)).if_else(na.rational(0), na.rational(1))
61+
with pytest.raises(Exception):
62+
(na.rational(1) > d).if_else(na.rational(0), na.rational(1))
63+
with pytest.raises(Exception):
64+
(d > na.rational(1)).if_else(na.rational(0), na.rational(1))
65+
66+
# Incompatible return types
67+
with pytest.raises(Exception):
68+
(a > b).if_else(c, Integer(1))
69+
with pytest.raises(Exception):
70+
(a > b).if_else(Integer(1), c)
71+
with pytest.raises(Exception):
72+
(a > b).if_else(c, d)
73+
with pytest.raises(Exception):
74+
(a > b).if_else(d, c)
75+
with pytest.raises(Exception):
76+
(na.rational(0) > na.rational(1)).if_else(c, Integer(1))
77+
with pytest.raises(Exception):
78+
(na.rational(0) > na.rational(1)).if_else(Integer(1), c)
79+
with pytest.raises(Exception):
80+
(na.rational(0) > na.rational(1)).if_else(c, d)
81+
with pytest.raises(Exception):
82+
(na.rational(0) > na.rational(1)).if_else(d, c)
2183

2284
return (
2385
na.output(out_0, parties[2], "out_0")
@@ -28,4 +90,10 @@ def nada_main():
2890
+ na.output(out_5, parties[2], "out_5")
2991
+ na.output(out_6, parties[2], "out_6")
3092
+ na.output(out_7, parties[2], "out_7")
93+
+ na.output(out_8, parties[2], "out_8")
94+
+ na.output(out_9, parties[2], "out_9")
95+
+ na.output(out_10, parties[2], "out_10")
96+
+ na.output(out_11, parties[2], "out_11")
97+
+ na.output(out_12, parties[2], "out_12")
98+
+ na.output(out_13, parties[2], "out_13")
3199
)

tests/nada-tests/tests/rational_if_else.yaml

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,35 @@ inputs:
88
SecretInteger: "294912" # 3.2
99
C:
1010
SecretInteger: "65536" # 1
11+
D:
12+
SecretInteger: "1"
13+
E:
14+
SecretUnsignedInteger: "1"
1115
public_variables: {}
1216
expected_outputs:
1317
out_0_0:
14-
SecretInteger: "65536"
18+
SecretInteger: "294912"
1519
out_1_0:
16-
SecretInteger: "0"
20+
SecretInteger: "294912"
1721
out_2_0:
18-
SecretInteger: "131072"
22+
SecretInteger: "78643"
1923
out_3_0:
20-
SecretInteger: "196608"
24+
SecretInteger: "78643"
2125
out_4_0:
22-
SecretInteger: "131072"
26+
SecretInteger: "1"
2327
out_5_0:
24-
SecretInteger: "65536"
28+
SecretInteger: "1"
2529
out_6_0:
26-
SecretInteger: "65536"
30+
SecretInteger: "1"
2731
out_7_0:
28-
SecretInteger: "65536"
32+
SecretInteger: "1"
33+
out_8_0:
34+
Integer: "131072"
35+
out_9_0:
36+
Integer: "65536"
37+
out_10_0:
38+
Integer: "1"
39+
out_11_0:
40+
SecretInteger: "1"
41+
out_12_0:
42+
SecretUnsignedInteger: "1"

0 commit comments

Comments
 (0)