@@ -48,7 +48,7 @@ def if_else(
48
48
self ,
49
49
arg_0 : Union [_NadaType , "SecretRational" , "Rational" ],
50
50
arg_1 : Union [_NadaType , "SecretRational" , "Rational" ],
51
- ) -> Union [SecretInteger , SecretUnsignedInteger ]:
51
+ ) -> Union [SecretInteger , SecretUnsignedInteger , "SecretRational" ]:
52
52
"""
53
53
If-else logic. If the boolean is True, arg_0 is returned. If not, arg_1 is returned.
54
54
@@ -61,7 +61,7 @@ def if_else(
61
61
TypeError: Raised when invalid operation is called.
62
62
63
63
Returns:
64
- Union[SecretInteger, SecretUnsignedInteger]: Return value.
64
+ Union[SecretInteger, SecretUnsignedInteger, "SecretRational" ]: Return value.
65
65
"""
66
66
first_arg = arg_0
67
67
second_arg = arg_1
@@ -84,8 +84,7 @@ def if_else(
84
84
if isinstance (arg_0 , (SecretRational , Rational )):
85
85
# If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
86
86
return SecretRational (result , arg_0 .log_scale , is_scaled = True )
87
- else :
88
- return result
87
+ return result
89
88
90
89
91
90
class PublicBoolean (dsl .PublicBoolean ):
@@ -104,7 +103,14 @@ def if_else(
104
103
self ,
105
104
arg_0 : Union [_NadaType , "SecretRational" , "Rational" ],
106
105
arg_1 : Union [_NadaType , "SecretRational" , "Rational" ],
107
- ) -> Union [SecretInteger , SecretUnsignedInteger ]:
106
+ ) -> Union [
107
+ PublicInteger ,
108
+ PublicUnsignedInteger ,
109
+ SecretInteger ,
110
+ SecretUnsignedInteger ,
111
+ "Rational" ,
112
+ "SecretRational" ,
113
+ ]:
108
114
"""
109
115
If-else logic. If the boolean is True, arg_0 is returned. If not, arg_1 is returned.
110
116
@@ -117,7 +123,8 @@ def if_else(
117
123
TypeError: Raised when invalid operation is called.
118
124
119
125
Returns:
120
- Union[SecretInteger, SecretUnsignedInteger]: Return value.
126
+ Union[PublicInteger, PublicUnsignedInteger, SecretInteger,
127
+ SecretUnsignedInteger, "Rational", "SecretRational"]: Return value.
121
128
"""
122
129
first_arg = arg_0
123
130
second_arg = arg_1
@@ -137,11 +144,11 @@ def if_else(
137
144
138
145
result = super ().if_else (first_arg , second_arg )
139
146
140
- if isinstance (arg_0 , (SecretRational , Rational )):
141
- # If we have a SecretBoolean, the return type will be SecretInteger, thus promoted to SecretRational
147
+ if isinstance (arg_0 , SecretRational ) or isinstance (arg_1 , SecretRational ):
148
+ return SecretRational (result , arg_0 .log_scale , is_scaled = True )
149
+ elif isinstance (arg_0 , Rational ) and isinstance (arg_1 , Rational ):
142
150
return Rational (result , arg_0 .log_scale , is_scaled = True )
143
- else :
144
- return result
151
+ return result
145
152
146
153
147
154
class Rational :
@@ -545,7 +552,9 @@ def __lt__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
545
552
"""
546
553
if self .log_scale != other .log_scale :
547
554
raise ValueError ("Cannot compare values with different scales." )
548
- return self .value < other .value
555
+ if isinstance (other , SecretRational ):
556
+ return SecretBoolean (self .value < other .value )
557
+ return PublicBoolean (self .value < other .value )
549
558
550
559
def __gt__ (self , other : _NadaRational ) -> Union [PublicBoolean , SecretBoolean ]:
551
560
"""
@@ -562,7 +571,9 @@ def __gt__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
562
571
"""
563
572
if self .log_scale != other .log_scale :
564
573
raise ValueError ("Cannot compare values with different scales." )
565
- return self .value > other .value
574
+ if isinstance (other , SecretRational ):
575
+ return SecretBoolean (self .value > other .value )
576
+ return PublicBoolean (self .value > other .value )
566
577
567
578
def __le__ (self , other : _NadaRational ) -> Union [PublicBoolean , SecretBoolean ]:
568
579
"""
@@ -579,7 +590,9 @@ def __le__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
579
590
"""
580
591
if self .log_scale != other .log_scale :
581
592
raise ValueError ("Cannot compare values with different scales." )
582
- return self .value <= other .value
593
+ if isinstance (other , SecretRational ):
594
+ return SecretBoolean (self .value <= other .value )
595
+ return PublicBoolean (self .value <= other .value )
583
596
584
597
def __ge__ (self , other : _NadaRational ) -> Union [PublicBoolean , SecretBoolean ]:
585
598
"""
@@ -596,7 +609,9 @@ def __ge__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
596
609
"""
597
610
if self .log_scale != other .log_scale :
598
611
raise ValueError ("Cannot compare values with different scales." )
599
- return self .value >= other .value
612
+ if isinstance (other , SecretRational ):
613
+ return SecretBoolean (self .value >= other .value )
614
+ return PublicBoolean (self .value >= other .value )
600
615
601
616
def __eq__ (self , other : _NadaRational ) -> Union [PublicBoolean , SecretBoolean ]:
602
617
"""
@@ -613,7 +628,9 @@ def __eq__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
613
628
"""
614
629
if self .log_scale != other .log_scale :
615
630
raise ValueError ("Cannot compare values with different scales." )
616
- return self .value == other .value
631
+ if isinstance (other , SecretRational ):
632
+ return SecretBoolean (self .value == other .value )
633
+ return PublicBoolean (self .value == other .value )
617
634
618
635
def __ne__ (self , other : _NadaRational ) -> Union [PublicBoolean , SecretBoolean ]:
619
636
"""
@@ -630,7 +647,9 @@ def __ne__(self, other: _NadaRational) -> Union[PublicBoolean, SecretBoolean]:
630
647
"""
631
648
if self .log_scale != other .log_scale :
632
649
raise ValueError ("Cannot compare values with different scales." )
633
- return SecretBoolean (self .value != other .value )
650
+ if isinstance (other , SecretRational ):
651
+ return SecretBoolean (self .value != other .value )
652
+ return PublicBoolean (self .value != other .value )
634
653
635
654
def rescale_up (self , log_scale : int = None ) -> "Rational" :
636
655
"""
@@ -1287,6 +1306,9 @@ def rational(
1287
1306
Returns:
1288
1307
Rational: Instantiated Rational object.
1289
1308
"""
1309
+ if value == 0 : # no use in rescaling 0
1310
+ return Rational (Integer (0 ), is_scaled = True )
1311
+
1290
1312
if log_scale is None :
1291
1313
log_scale = get_log_scale ()
1292
1314
0 commit comments