@@ -311,6 +311,7 @@ def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc:
311
311
"""
312
312
Wraps an elements strategy as a xps.from_dtype()-like function
313
313
"""
314
+
314
315
def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
315
316
assert len (kw ) == 0 # sanity check
316
317
return strat
@@ -553,23 +554,6 @@ class UnaryCase(Case):
553
554
cond : UnaryCheck
554
555
check_result : UnaryResultCheck
555
556
556
- @classmethod
557
- def from_strings (cls , cond_str : str , result_str : str ):
558
- cond , cond_expr_template , cond_from_dtype = parse_cond (cond_str )
559
- cond_expr = cond_expr_template .replace ("{}" , "x_i" )
560
- _check_result , result_expr = parse_result (result_str )
561
-
562
- def check_result (i : float , result : float ) -> bool :
563
- return _check_result (result )
564
-
565
- return cls (
566
- cond_expr = cond_expr ,
567
- cond = cond ,
568
- cond_from_dtype = cond_from_dtype ,
569
- result_expr = result_expr ,
570
- check_result = check_result ,
571
- )
572
-
573
557
574
558
r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
575
559
r_even_int_round_case = re .compile (
@@ -578,7 +562,7 @@ def check_result(i: float, result: float) -> bool:
578
562
)
579
563
580
564
581
- def trailing_halves_from_dtype (dtype : DataType ):
565
+ def trailing_halves_from_dtype (dtype : DataType ) -> st . SearchStrategy [ float ] :
582
566
m , M = dh .dtype_ranges [dtype ]
583
567
return st .integers (math .ceil (m ) // 2 , math .floor (M ) // 2 ).map (lambda n : n * 0.5 )
584
568
@@ -594,6 +578,13 @@ def trailing_halves_from_dtype(dtype: DataType):
594
578
)
595
579
596
580
581
+ def make_unary_check_result (check_just_result : UnaryCheck ) -> UnaryResultCheck :
582
+ def check_result (i : float , result : float ) -> bool :
583
+ return check_just_result (result )
584
+
585
+ return check_result
586
+
587
+
597
588
def parse_unary_docstring (docstring : str ) -> List [UnaryCase ]:
598
589
match = r_special_cases .search (docstring )
599
590
if match is None :
@@ -608,10 +599,22 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
608
599
continue
609
600
if m := r_unary_case .search (case ):
610
601
try :
611
- case = UnaryCase .from_strings (* m .groups ())
602
+ cond , cond_expr_template , cond_from_dtype = parse_cond (m .group (1 ))
603
+ _check_result , result_expr = parse_result (m .group (2 ))
612
604
except ParseError as e :
613
605
warn (f"not machine-readable: '{ e .value } '" )
614
606
continue
607
+ cond_expr = cond_expr_template .replace ("{}" , "x_i" )
608
+ # Do not define check_result in this function's body - see
609
+ # parse_binary_case comment.
610
+ check_result = make_unary_check_result (_check_result )
611
+ case = UnaryCase (
612
+ cond_expr = cond_expr ,
613
+ cond = cond ,
614
+ cond_from_dtype = cond_from_dtype ,
615
+ result_expr = result_expr ,
616
+ check_result = check_result ,
617
+ )
615
618
cases .append (case )
616
619
elif m := r_even_int_round_case .search (case ):
617
620
cases .append (even_int_round_case )
@@ -741,7 +744,7 @@ def check_result(i1: float, i2: float, result: float) -> bool:
741
744
return check_result
742
745
743
746
744
- def make_check_result (check_just_result : UnaryCheck ) -> BinaryResultCheck :
747
+ def make_binary_check_result (check_just_result : UnaryCheck ) -> BinaryResultCheck :
745
748
def check_result (i1 : float , i2 : float , result : float ) -> bool :
746
749
return check_just_result (result )
747
750
@@ -843,12 +846,12 @@ def partial_cond(i1: float, i2: float) -> bool:
843
846
844
847
else :
845
848
unary_cond , expr_template , cond_from_dtype = parse_cond (value_str )
846
- # Do not define partial_cond via the def keyword, as one
847
- # partial_cond definition can mess up previous definitions
848
- # in the partial_conds list. This is a hard-limitation of
849
- # using local functions with the same name and that use the same
850
- # outer variables (i.e. unary_cond). Use def in a called
851
- # function avoids this problem.
849
+ # Do not define partial_cond via the def keyword or lambda
850
+ # expressions, as one partial_cond definition can mess up
851
+ # previous definitions in the partial_conds list. This is a
852
+ # hard-limitation of using local functions with the same name
853
+ # and that use the same outer variables (i.e. unary_cond). Use
854
+ # def in a called function avoids this problem.
852
855
input_wrapper = None
853
856
if m := r_input .match (input_str ):
854
857
x_no = m .group (1 )
@@ -924,7 +927,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
924
927
if result_m is None :
925
928
raise ParseError (case_m .group (2 ))
926
929
result_str = result_m .group (1 )
927
- # Like with partial_cond, do not define check_result via the def keyword
930
+ # Like with partial_cond, do not define check_result in this function's body.
928
931
if m := r_array_element .match (result_str ):
929
932
sign , x_no = m .groups ()
930
933
result_expr = f"{ sign } x{ x_no } _i"
@@ -933,7 +936,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
933
936
)
934
937
else :
935
938
_check_result , result_expr = parse_result (result_m .group (1 ))
936
- check_result = make_check_result (_check_result )
939
+ check_result = make_binary_check_result (_check_result )
937
940
938
941
cond_expr = " and " .join (partial_exprs )
939
942
0 commit comments