6
6
from dataclasses import dataclass
7
7
from decimal import ROUND_HALF_EVEN , Decimal
8
8
from enum import Enum , auto
9
- from typing import (
10
- Any ,
11
- Callable ,
12
- Dict ,
13
- List ,
14
- Match ,
15
- Optional ,
16
- Protocol ,
17
- Tuple ,
18
- )
9
+ from typing import Any , Callable , Dict , List , Optional , Protocol , Tuple
19
10
from warnings import warn
20
11
21
12
import pytest
@@ -178,7 +169,8 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
178
169
@dataclass
179
170
class BoundFromDtype (FromDtypeFunc ):
180
171
kwargs : Dict [str , Any ]
181
- filter_ : Optional [Callable [[Array ], bool ]]
172
+ filter_ : Optional [Callable [[Array ], bool ]] = None
173
+ base_func : Optional [FromDtypeFunc ] = None
182
174
183
175
def __add__ (self , other : BoundFromDtype ) -> BoundFromDtype :
184
176
for k in self .kwargs .keys ():
@@ -189,17 +181,28 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
189
181
if self .filter_ is not None and other .filter_ is not None :
190
182
filter_ = lambda i : self .filter_ (i ) and other .filter_ (i )
191
183
else :
192
- try :
193
- filter_ = next (
194
- f for f in [ self . filter_ , other .filter_ ] if f is not None
195
- )
196
- except StopIteration :
184
+ if self . filter_ is not None :
185
+ filter_ = self . filter_
186
+ elif other .filter_ is not None :
187
+ filter_ = other . filter_
188
+ else :
197
189
filter_ = None
198
190
199
- return BoundFromDtype (kwargs , filter_ )
191
+ # sanity check
192
+ assert not (self .base_func is not None and other .base_func is not None )
193
+ if self .base_func is not None :
194
+ base_func = self .base_func
195
+ elif other .base_func is not None :
196
+ base_func = other .base_func
197
+ else :
198
+ base_func = None
199
+
200
+ return BoundFromDtype (kwargs , filter_ , base_func )
200
201
201
- def __call__ (self , dtype : DataType ) -> st .SearchStrategy [float ]:
202
- strat = xps .from_dtype (dtype , ** self .kwargs )
202
+ def __call__ (self , dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
203
+ assert len (kw ) == 0 # sanity check
204
+ from_dtype = self .base_func or xps .from_dtype
205
+ strat = from_dtype (dtype , ** self .kwargs )
203
206
if self .filter_ is not None :
204
207
strat = strat .filter (self .filter_ )
205
208
return strat
@@ -295,22 +298,18 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
295
298
if not not_cond :
296
299
kwargs = {"allow_nan" : False , "allow_infinity" : False }
297
300
filter_ = lambda n : n != 0
298
- elif "integer value" in cond_str :
299
- raise ValueError (
300
- "integer values are only specified in dual cases, "
301
- "which cannot be handled in parse_cond()"
302
- )
303
- # elif cond_str == "an integer value":
304
- # cond = lambda i: i.is_integer()
305
- # expr_template = "{}.is_integer()"
306
- # if not not_cond:
307
- # from_dtype = integers_from_dtype # type: ignore
308
- # elif cond_str == "an odd integer value":
309
- # cond = lambda i: i.is_integer() and i % 2 == 1
310
- # expr_template = "{}.is_integer() and {} % 2 == 1"
311
- # if not not_cond:
312
- # def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
313
- # return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1)
301
+ elif cond_str == "an integer value" :
302
+ cond = lambda i : i .is_integer ()
303
+ expr_template = "{}.is_integer()"
304
+ if not not_cond :
305
+ from_dtype = integers_from_dtype # type: ignore
306
+ elif cond_str == "an odd integer value" :
307
+ cond = lambda i : i .is_integer () and i % 2 == 1
308
+ expr_template = "{}.is_integer() and {} % 2 == 1"
309
+ if not not_cond :
310
+
311
+ def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
312
+ return integers_from_dtype (dtype , ** kw ).filter (lambda n : n % 2 == 1 )
314
313
315
314
else :
316
315
raise ValueParseError (cond_str )
@@ -329,7 +328,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
329
328
kwargs = {}
330
329
filter_ = cond
331
330
assert kwargs is not None
332
- return cond , expr_template , BoundFromDtype (kwargs , filter_ )
331
+ return cond , expr_template , BoundFromDtype (kwargs , filter_ , from_dtype )
333
332
334
333
335
334
def parse_result (result_str : str ) -> Tuple [UnaryCheck , str ]:
@@ -531,25 +530,9 @@ def noop(n: float) -> float:
531
530
return n
532
531
533
532
534
- def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
535
- for k in kw .keys ():
536
- # sanity check
537
- assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
538
- m , M = dh .dtype_ranges [dtype ]
539
- if "min_value" in kw .keys ():
540
- m = kw ["min_value" ]
541
- if "exclude_min" in kw .keys ():
542
- m += 1
543
- if "max_value" in kw .keys ():
544
- M = kw ["max_value" ]
545
- if "exclude_max" in kw .keys ():
546
- M -= 1
547
- return st .integers (math .ceil (m ), math .floor (M )).map (float )
548
-
549
-
550
533
def make_binary_cond (
551
534
cond_arg : BinaryCondArg ,
552
- unary_check : UnaryCheck ,
535
+ unary_cond : UnaryCheck ,
553
536
* ,
554
537
input_wrapper : Optional [Callable [[float ], float ]] = None ,
555
538
) -> BinaryCond :
@@ -559,22 +542,22 @@ def make_binary_cond(
559
542
if cond_arg == BinaryCondArg .FIRST :
560
543
561
544
def partial_cond (i1 : float , i2 : float ) -> bool :
562
- return unary_check (input_wrapper (i1 ))
545
+ return unary_cond (input_wrapper (i1 ))
563
546
564
547
elif cond_arg == BinaryCondArg .SECOND :
565
548
566
549
def partial_cond (i1 : float , i2 : float ) -> bool :
567
- return unary_check (input_wrapper (i2 ))
550
+ return unary_cond (input_wrapper (i2 ))
568
551
569
552
elif cond_arg == BinaryCondArg .BOTH :
570
553
571
554
def partial_cond (i1 : float , i2 : float ) -> bool :
572
- return unary_check (input_wrapper (i1 )) and unary_check (input_wrapper (i2 ))
555
+ return unary_cond (input_wrapper (i1 )) and unary_cond (input_wrapper (i2 ))
573
556
574
557
else :
575
558
576
559
def partial_cond (i1 : float , i2 : float ) -> bool :
577
- return unary_check (input_wrapper (i1 )) or unary_check (input_wrapper (i2 ))
560
+ return unary_cond (input_wrapper (i1 )) or unary_cond (input_wrapper (i2 ))
578
561
579
562
return partial_cond
580
563
@@ -631,11 +614,26 @@ def check_result(i1: float, i2: float, result: float) -> bool:
631
614
return check_result
632
615
633
616
634
- def parse_binary_case (case_m : Match ) -> BinaryCase :
635
- cond_strs = r_cond_sep .split (case_m .group (1 ))
617
+ def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
618
+ for k in kw .keys ():
619
+ # sanity check
620
+ assert k in ["min_value" , "max_value" , "exclude_min" , "exclude_max" ]
621
+ m , M = dh .dtype_ranges [dtype ]
622
+ if "min_value" in kw .keys ():
623
+ m = kw ["min_value" ]
624
+ if "exclude_min" in kw .keys ():
625
+ m += 1
626
+ if "max_value" in kw .keys ():
627
+ M = kw ["max_value" ]
628
+ if "exclude_max" in kw .keys ():
629
+ M -= 1
630
+ return st .integers (math .ceil (m ), math .floor (M )).map (float )
636
631
637
- if len (cond_strs ) > 2 :
638
- raise ValueParseError (", " .join (cond_strs ))
632
+
633
+ def parse_binary_case (case_str : str ) -> BinaryCase :
634
+ case_m = r_binary_case .match (case_str )
635
+ assert case_m is not None # sanity check
636
+ cond_strs = r_cond_sep .split (case_m .group (1 ))
639
637
640
638
partial_conds = []
641
639
partial_exprs = []
@@ -678,7 +676,7 @@ def partial_cond(i1: float, i2: float) -> bool:
678
676
return math .copysign (1 , i1 ) != math .copysign (1 , i2 )
679
677
680
678
else :
681
- unary_check , expr_template , cond_from_dtype = parse_cond (value_str )
679
+ unary_cond , expr_template , cond_from_dtype = parse_cond (value_str )
682
680
# Do not define partial_cond via the def keyword, as one
683
681
# partial_cond definition can mess up previous definitions
684
682
# in the partial_conds list. This is a hard-limitation of
@@ -707,7 +705,7 @@ def partial_cond(i1: float, i2: float) -> bool:
707
705
else :
708
706
raise ValueParseError (input_str )
709
707
partial_cond = make_binary_cond ( # type: ignore
710
- cond_arg , unary_check , input_wrapper = input_wrapper
708
+ cond_arg , unary_cond , input_wrapper = input_wrapper
711
709
)
712
710
if cond_arg == BinaryCondArg .FIRST :
713
711
x1_cond_from_dtypes .append (cond_from_dtype )
@@ -749,15 +747,15 @@ def cond(i1: float, i2: float) -> bool:
749
747
else :
750
748
# sanity check
751
749
assert all (isinstance (fd , BoundFromDtype ) for fd in x1_cond_from_dtypes )
752
- x1_cond_from_dtype = sum (x1_cond_from_dtypes )
750
+ x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ({}, None ) )
753
751
if len (x2_cond_from_dtypes ) == 0 :
754
752
x2_cond_from_dtype = xps .from_dtype
755
753
elif len (x2_cond_from_dtypes ) == 1 :
756
754
x2_cond_from_dtype = x2_cond_from_dtypes [0 ]
757
755
else :
758
756
# sanity check
759
757
assert all (isinstance (fd , BoundFromDtype ) for fd in x2_cond_from_dtypes )
760
- x2_cond_from_dtype = sum (x2_cond_from_dtypes )
758
+ x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ({}, None ) )
761
759
762
760
return BinaryCase (
763
761
cond_expr = cond_expr ,
@@ -788,7 +786,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
788
786
continue
789
787
if m := r_binary_case .match (case_str ):
790
788
try :
791
- case = parse_binary_case (m )
789
+ case = parse_binary_case (case_str )
792
790
cases .append (case )
793
791
except ValueParseError as e :
794
792
warn (f"not machine-readable: '{ e .value } '" )
0 commit comments