Skip to content

Commit 252eb34

Browse files
authored
feat: interoperability of booleans with rationals for if_else (#18)
1 parent 69c7c1d commit 252eb34

File tree

8 files changed

+393
-10
lines changed

8 files changed

+393
-10
lines changed

nada_algebra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""This is the __init__.py module"""
22

33
from nada_algebra.array import NadaArray
4+
from nada_algebra.types import RationalConfig, Rational, SecretRational
45
from nada_algebra.funcs import *
5-
from nada_algebra.types import Rational, SecretRational

nada_algebra/types.py

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@
22

33
import numpy as np
44

5+
import nada_dsl as dsl
6+
57
from nada_dsl import (
68
Input,
79
Party,
810
UnsignedInteger,
911
Integer,
1012
NadaType,
1113
SecretInteger,
12-
SecretBoolean,
13-
PublicBoolean,
14+
SecretUnsignedInteger,
1415
PublicInteger,
16+
PublicUnsignedInteger,
1517
)
18+
19+
1620
from typing import Union
1721

1822

@@ -25,6 +29,85 @@
2529

2630
_NadaRational = Union["Rational", "SecretRational"]
2731

32+
_NadaType = Union[
33+
Integer,
34+
PublicInteger,
35+
PublicUnsignedInteger,
36+
SecretInteger,
37+
SecretUnsignedInteger,
38+
UnsignedInteger,
39+
]
40+
41+
42+
class SecretBoolean(dsl.SecretBoolean):
43+
44+
def __init__(self, value):
45+
super().__init__(value.inner)
46+
47+
def if_else(
48+
self: dsl.SecretBoolean,
49+
arg_0: _NadaType | "SecretRational" | "Rational",
50+
arg_1: _NadaType | "SecretRational" | "Rational",
51+
) -> Union[SecretInteger, SecretUnsignedInteger]:
52+
first_arg = arg_0
53+
second_arg = arg_1
54+
if isinstance(arg_0, (SecretRational, Rational)) and isinstance(
55+
arg_1, (SecretRational, Rational)
56+
):
57+
# Both are SecretRational or Rational objects
58+
if arg_0.log_scale != arg_1.log_scale:
59+
raise ValueError("Cannot output values with different scales.")
60+
first_arg = arg_0.value
61+
second_arg = arg_1.value
62+
elif isinstance(arg_0, (Rational, SecretRational)) or isinstance(
63+
arg_1, (Rational, SecretRational)
64+
):
65+
# Both are SecretRational or Rational objects
66+
raise TypeError(f"Invalid operation: {self}.IfElse({arg_0}, {arg_1})")
67+
68+
result = super().if_else(first_arg, second_arg)
69+
70+
if isinstance(arg_0, (SecretRational, Rational)):
71+
# If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
72+
return SecretRational.from_parts(result, arg_0.log_scale)
73+
else:
74+
return result
75+
76+
77+
class PublicBoolean(dsl.PublicBoolean):
78+
79+
def __init__(self, value):
80+
super().__init__(value.inner)
81+
82+
def if_else(
83+
self: dsl.SecretBoolean,
84+
arg_0: _NadaType | "SecretRational" | "Rational",
85+
arg_1: _NadaType | "SecretRational" | "Rational",
86+
) -> Union[SecretInteger, SecretUnsignedInteger]:
87+
first_arg = arg_0
88+
second_arg = arg_1
89+
if isinstance(arg_0, (SecretRational, Rational)) and isinstance(
90+
arg_1, (SecretRational, Rational)
91+
):
92+
# Both are SecretRational or Rational objects
93+
if arg_0.log_scale != arg_1.log_scale:
94+
raise ValueError("Cannot output values with different scales.")
95+
first_arg = arg_0.value
96+
second_arg = arg_1.value
97+
elif isinstance(arg_0, (Rational, SecretRational)) or isinstance(
98+
arg_1, (Rational, SecretRational)
99+
):
100+
# Both are SecretRational or Rational objects but of different type
101+
raise TypeError(f"Invalid operation: {self}.IfElse({arg_0}, {arg_1})")
102+
103+
result = super().if_else(first_arg, second_arg)
104+
105+
if isinstance(arg_0, (SecretRational, Rational)):
106+
# If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
107+
return Rational.from_parts(result, arg_0.log_scale)
108+
else:
109+
return result
110+
28111

29112
class RationalConfig(object):
30113

@@ -828,7 +911,7 @@ def __lt__(self, other: _NadaRational) -> SecretBoolean:
828911
"""
829912
if self.log_scale != other.log_scale:
830913
raise ValueError("Cannot compare values with different scales.")
831-
return self.value < other.value
914+
return SecretBoolean(self.value < other.value)
832915

833916
def __gt__(self, other: _NadaRational) -> SecretBoolean:
834917
"""Check if this SecretRational is greater than another.
@@ -844,7 +927,7 @@ def __gt__(self, other: _NadaRational) -> SecretBoolean:
844927
"""
845928
if self.log_scale != other.log_scale:
846929
raise ValueError("Cannot compare values with different scales.")
847-
return self.value > other.value
930+
return SecretBoolean(self.value > other.value)
848931

849932
def __le__(self, other: _NadaRational) -> SecretBoolean:
850933
"""Check if this SecretRational is less than or equal to another.
@@ -860,7 +943,7 @@ def __le__(self, other: _NadaRational) -> SecretBoolean:
860943
"""
861944
if self.log_scale != other.log_scale:
862945
raise ValueError("Cannot compare values with different scales.")
863-
return self.value <= other.value
946+
return SecretBoolean(self.value <= other.value)
864947

865948
def __ge__(self, other: _NadaRational) -> SecretBoolean:
866949
"""Check if this SecretRational is greater than or equal to another.
@@ -876,7 +959,7 @@ def __ge__(self, other: _NadaRational) -> SecretBoolean:
876959
"""
877960
if self.log_scale != other.log_scale:
878961
raise ValueError("Cannot compare values with different scales.")
879-
return self.value >= other.value
962+
return SecretBoolean(self.value >= other.value)
880963

881964
def __eq__(self, other: _NadaRational) -> SecretBoolean:
882965
"""Check if this SecretRational is equal to another.
@@ -892,7 +975,7 @@ def __eq__(self, other: _NadaRational) -> SecretBoolean:
892975
"""
893976
if self.log_scale != other.log_scale:
894977
raise ValueError("Cannot compare values with different scales.")
895-
return self.value == other.value
978+
return SecretBoolean(self.value == other.value)
896979

897980
def __ne__(self, other: _NadaRational) -> SecretBoolean:
898981
"""Check if this SecretRational is not equal to another.

0 commit comments

Comments
 (0)