29
29
BinaryCheck = Callable [[float , float ], bool ]
30
30
31
31
32
- def make_eq (v : float ) -> UnaryCheck :
32
+ def make_strict_eq (v : float ) -> UnaryCheck :
33
33
if math .isnan (v ):
34
34
return math .isnan
35
35
if v == 0 :
@@ -38,14 +38,14 @@ def make_eq(v: float) -> UnaryCheck:
38
38
else :
39
39
return ph .is_neg_zero
40
40
41
- def eq (i : float ) -> bool :
41
+ def strict_eq (i : float ) -> bool :
42
42
return i == v
43
43
44
- return eq
44
+ return strict_eq
45
45
46
46
47
47
def make_neq (v : float ) -> UnaryCheck :
48
- eq = make_eq (v )
48
+ eq = make_strict_eq (v )
49
49
50
50
def neq (i : float ) -> bool :
51
51
return not eq (i )
@@ -154,7 +154,8 @@ def parse_inline_code(inline_code: str) -> float:
154
154
raise ValueParseError (inline_code )
155
155
156
156
157
- r_not = re .compile ("not (?:equal to )?(.+)" )
157
+ r_not = re .compile ("not (.+)" )
158
+ r_equal_to = re .compile (f"equal to { r_code .pattern } " )
158
159
r_array_element = re .compile (r"``([+-]?)x([12])_i``" )
159
160
r_either_code = re .compile (f"either { r_code .pattern } or { r_code .pattern } " )
160
161
r_gt = re .compile (f"greater than { r_code .pattern } " )
@@ -217,9 +218,6 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
217
218
218
219
219
220
def parse_cond (cond_str : str ) -> Tuple [UnaryCheck , str , FromDtypeFunc ]:
220
- if "equal to" in cond_str :
221
- raise ValueParseError (cond_str ) # TODO
222
-
223
221
if m := r_not .match (cond_str ):
224
222
cond_str = m .group (1 )
225
223
not_cond = True
@@ -232,10 +230,15 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
232
230
strat = None
233
231
if m := r_code .match (cond_str ):
234
232
value = parse_value (m .group (1 ))
235
- cond = make_eq (value )
233
+ cond = make_strict_eq (value )
236
234
expr_template = "{} == " + m .group (1 )
237
235
if not not_cond :
238
236
strat = st .just (value )
237
+ elif m := r_equal_to .match (cond_str ):
238
+ value = parse_value (m .group (1 ))
239
+ assert not math .isnan (value ) # sanity check
240
+ cond = lambda i : i == value
241
+ expr_template = "{} == " + m .group (1 )
239
242
elif m := r_gt .match (cond_str ):
240
243
value = parse_value (m .group (1 ))
241
244
cond = make_gt (value )
@@ -251,7 +254,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
251
254
elif m := r_either_code .match (cond_str ):
252
255
v1 = parse_value (m .group (1 ))
253
256
v2 = parse_value (m .group (2 ))
254
- cond = make_or (make_eq (v1 ), make_eq (v2 ))
257
+ cond = make_or (make_strict_eq (v1 ), make_strict_eq (v2 ))
255
258
expr_template = "{} == " + m .group (1 ) + " or {} == " + m .group (2 )
256
259
if not not_cond :
257
260
strat = st .sampled_from ([v1 , v2 ])
@@ -334,7 +337,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
334
337
def parse_result (result_str : str ) -> Tuple [UnaryCheck , str ]:
335
338
if m := r_code .match (result_str ):
336
339
value = parse_value (m .group (1 ))
337
- check_result = make_eq (value ) # type: ignore
340
+ check_result = make_strict_eq (value ) # type: ignore
338
341
expr = m .group (1 )
339
342
elif m := r_approx_value .match (result_str ):
340
343
value = parse_value (m .group (1 ))
@@ -573,13 +576,13 @@ def make_eq_other_input_cond(
573
576
if eq_to == BinaryCondArg .FIRST :
574
577
575
578
def cond (i1 : float , i2 : float ) -> bool :
576
- eq = make_eq (input_wrapper (i1 ))
579
+ eq = make_strict_eq (input_wrapper (i1 ))
577
580
return eq (i2 )
578
581
579
582
elif eq_to == BinaryCondArg .SECOND :
580
583
581
584
def cond (i1 : float , i2 : float ) -> bool :
582
- eq = make_eq (input_wrapper (i2 ))
585
+ eq = make_strict_eq (input_wrapper (i2 ))
583
586
return eq (i1 )
584
587
585
588
else :
@@ -599,13 +602,13 @@ def make_eq_input_check_result(
599
602
if eq_to == BinaryCondArg .FIRST :
600
603
601
604
def check_result (i1 : float , i2 : float , result : float ) -> bool :
602
- eq = make_eq (input_wrapper (i1 ))
605
+ eq = make_strict_eq (input_wrapper (i1 ))
603
606
return eq (result )
604
607
605
608
elif eq_to == BinaryCondArg .SECOND :
606
609
607
610
def check_result (i1 : float , i2 : float , result : float ) -> bool :
608
- eq = make_eq (input_wrapper (i2 ))
611
+ eq = make_strict_eq (input_wrapper (i2 ))
609
612
return eq (result )
610
613
611
614
else :
0 commit comments