2
2
3
3
import numpy as np
4
4
5
+ import nada_dsl as dsl
6
+
5
7
from nada_dsl import (
6
8
Input ,
7
9
Party ,
8
10
UnsignedInteger ,
9
11
Integer ,
10
12
NadaType ,
11
13
SecretInteger ,
12
- SecretBoolean ,
13
- PublicBoolean ,
14
+ SecretUnsignedInteger ,
14
15
PublicInteger ,
16
+ PublicUnsignedInteger ,
15
17
)
18
+
19
+
16
20
from typing import Union
17
21
18
22
25
29
26
30
_NadaRational = Union ["Rational" , "SecretRational" ]
27
31
32
+ _NadaType = Union [
33
+ Integer ,
34
+ PublicInteger ,
35
+ PublicUnsignedInteger ,
36
+ SecretInteger ,
37
+ SecretUnsignedInteger ,
38
+ UnsignedInteger ,
39
+ ]
40
+
41
+
42
+ class SecretBoolean (dsl .SecretBoolean ):
43
+
44
+ def __init__ (self , value ):
45
+ super ().__init__ (value .inner )
46
+
47
+ def if_else (
48
+ self : dsl .SecretBoolean ,
49
+ arg_0 : _NadaType | "SecretRational" | "Rational" ,
50
+ arg_1 : _NadaType | "SecretRational" | "Rational" ,
51
+ ) -> Union [SecretInteger , SecretUnsignedInteger ]:
52
+ first_arg = arg_0
53
+ second_arg = arg_1
54
+ if isinstance (arg_0 , (SecretRational , Rational )) and isinstance (
55
+ arg_1 , (SecretRational , Rational )
56
+ ):
57
+ # Both are SecretRational or Rational objects
58
+ if arg_0 .log_scale != arg_1 .log_scale :
59
+ raise ValueError ("Cannot output values with different scales." )
60
+ first_arg = arg_0 .value
61
+ second_arg = arg_1 .value
62
+ elif isinstance (arg_0 , (Rational , SecretRational )) or isinstance (
63
+ arg_1 , (Rational , SecretRational )
64
+ ):
65
+ # Both are SecretRational or Rational objects
66
+ raise TypeError (f"Invalid operation: { self } .IfElse({ arg_0 } , { arg_1 } )" )
67
+
68
+ result = super ().if_else (first_arg , second_arg )
69
+
70
+ if isinstance (arg_0 , (SecretRational , Rational )):
71
+ # If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
72
+ return SecretRational .from_parts (result , arg_0 .log_scale )
73
+ else :
74
+ return result
75
+
76
+
77
+ class PublicBoolean (dsl .PublicBoolean ):
78
+
79
+ def __init__ (self , value ):
80
+ super ().__init__ (value .inner )
81
+
82
+ def if_else (
83
+ self : dsl .SecretBoolean ,
84
+ arg_0 : _NadaType | "SecretRational" | "Rational" ,
85
+ arg_1 : _NadaType | "SecretRational" | "Rational" ,
86
+ ) -> Union [SecretInteger , SecretUnsignedInteger ]:
87
+ first_arg = arg_0
88
+ second_arg = arg_1
89
+ if isinstance (arg_0 , (SecretRational , Rational )) and isinstance (
90
+ arg_1 , (SecretRational , Rational )
91
+ ):
92
+ # Both are SecretRational or Rational objects
93
+ if arg_0 .log_scale != arg_1 .log_scale :
94
+ raise ValueError ("Cannot output values with different scales." )
95
+ first_arg = arg_0 .value
96
+ second_arg = arg_1 .value
97
+ elif isinstance (arg_0 , (Rational , SecretRational )) or isinstance (
98
+ arg_1 , (Rational , SecretRational )
99
+ ):
100
+ # Both are SecretRational or Rational objects but of different type
101
+ raise TypeError (f"Invalid operation: { self } .IfElse({ arg_0 } , { arg_1 } )" )
102
+
103
+ result = super ().if_else (first_arg , second_arg )
104
+
105
+ if isinstance (arg_0 , (SecretRational , Rational )):
106
+ # If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
107
+ return Rational .from_parts (result , arg_0 .log_scale )
108
+ else :
109
+ return result
110
+
28
111
29
112
class RationalConfig (object ):
30
113
@@ -828,7 +911,7 @@ def __lt__(self, other: _NadaRational) -> SecretBoolean:
828
911
"""
829
912
if self .log_scale != other .log_scale :
830
913
raise ValueError ("Cannot compare values with different scales." )
831
- return self .value < other .value
914
+ return SecretBoolean ( self .value < other .value )
832
915
833
916
def __gt__ (self , other : _NadaRational ) -> SecretBoolean :
834
917
"""Check if this SecretRational is greater than another.
@@ -844,7 +927,7 @@ def __gt__(self, other: _NadaRational) -> SecretBoolean:
844
927
"""
845
928
if self .log_scale != other .log_scale :
846
929
raise ValueError ("Cannot compare values with different scales." )
847
- return self .value > other .value
930
+ return SecretBoolean ( self .value > other .value )
848
931
849
932
def __le__ (self , other : _NadaRational ) -> SecretBoolean :
850
933
"""Check if this SecretRational is less than or equal to another.
@@ -860,7 +943,7 @@ def __le__(self, other: _NadaRational) -> SecretBoolean:
860
943
"""
861
944
if self .log_scale != other .log_scale :
862
945
raise ValueError ("Cannot compare values with different scales." )
863
- return self .value <= other .value
946
+ return SecretBoolean ( self .value <= other .value )
864
947
865
948
def __ge__ (self , other : _NadaRational ) -> SecretBoolean :
866
949
"""Check if this SecretRational is greater than or equal to another.
@@ -876,7 +959,7 @@ def __ge__(self, other: _NadaRational) -> SecretBoolean:
876
959
"""
877
960
if self .log_scale != other .log_scale :
878
961
raise ValueError ("Cannot compare values with different scales." )
879
- return self .value >= other .value
962
+ return SecretBoolean ( self .value >= other .value )
880
963
881
964
def __eq__ (self , other : _NadaRational ) -> SecretBoolean :
882
965
"""Check if this SecretRational is equal to another.
@@ -892,7 +975,7 @@ def __eq__(self, other: _NadaRational) -> SecretBoolean:
892
975
"""
893
976
if self .log_scale != other .log_scale :
894
977
raise ValueError ("Cannot compare values with different scales." )
895
- return self .value == other .value
978
+ return SecretBoolean ( self .value == other .value )
896
979
897
980
def __ne__ (self , other : _NadaRational ) -> SecretBoolean :
898
981
"""Check if this SecretRational is not equal to another.
0 commit comments