Skip to content

Commit 420a549

Browse files
committed
Cover binary cases with two unary conds for one array
1 parent 9c7c051 commit 420a549

File tree

1 file changed

+64
-66
lines changed

1 file changed

+64
-66
lines changed

array_api_tests/test_special_cases.py

Lines changed: 64 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,7 @@
66
from dataclasses import dataclass
77
from decimal import ROUND_HALF_EVEN, Decimal
88
from enum import Enum, auto
9-
from typing import (
10-
Any,
11-
Callable,
12-
Dict,
13-
List,
14-
Match,
15-
Optional,
16-
Protocol,
17-
Tuple,
18-
)
9+
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
1910
from warnings import warn
2011

2112
import pytest
@@ -178,7 +169,8 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
178169
@dataclass
179170
class BoundFromDtype(FromDtypeFunc):
180171
kwargs: Dict[str, Any]
181-
filter_: Optional[Callable[[Array], bool]]
172+
filter_: Optional[Callable[[Array], bool]] = None
173+
base_func: Optional[FromDtypeFunc] = None
182174

183175
def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
184176
for k in self.kwargs.keys():
@@ -189,17 +181,28 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
189181
if self.filter_ is not None and other.filter_ is not None:
190182
filter_ = lambda i: self.filter_(i) and other.filter_(i)
191183
else:
192-
try:
193-
filter_ = next(
194-
f for f in [self.filter_, other.filter_] if f is not None
195-
)
196-
except StopIteration:
184+
if self.filter_ is not None:
185+
filter_ = self.filter_
186+
elif other.filter_ is not None:
187+
filter_ = other.filter_
188+
else:
197189
filter_ = None
198190

199-
return BoundFromDtype(kwargs, filter_)
191+
# sanity check
192+
assert not (self.base_func is not None and other.base_func is not None)
193+
if self.base_func is not None:
194+
base_func = self.base_func
195+
elif other.base_func is not None:
196+
base_func = other.base_func
197+
else:
198+
base_func = None
199+
200+
return BoundFromDtype(kwargs, filter_, base_func)
200201

201-
def __call__(self, dtype: DataType) -> st.SearchStrategy[float]:
202-
strat = xps.from_dtype(dtype, **self.kwargs)
202+
def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
203+
assert len(kw) == 0 # sanity check
204+
from_dtype = self.base_func or xps.from_dtype
205+
strat = from_dtype(dtype, **self.kwargs)
203206
if self.filter_ is not None:
204207
strat = strat.filter(self.filter_)
205208
return strat
@@ -295,22 +298,18 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
295298
if not not_cond:
296299
kwargs = {"allow_nan": False, "allow_infinity": False}
297300
filter_ = lambda n: n != 0
298-
elif "integer value" in cond_str:
299-
raise ValueError(
300-
"integer values are only specified in dual cases, "
301-
"which cannot be handled in parse_cond()"
302-
)
303-
# elif cond_str == "an integer value":
304-
# cond = lambda i: i.is_integer()
305-
# expr_template = "{}.is_integer()"
306-
# if not not_cond:
307-
# from_dtype = integers_from_dtype # type: ignore
308-
# elif cond_str == "an odd integer value":
309-
# cond = lambda i: i.is_integer() and i % 2 == 1
310-
# expr_template = "{}.is_integer() and {} % 2 == 1"
311-
# if not not_cond:
312-
# def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
313-
# return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1)
301+
elif cond_str == "an integer value":
302+
cond = lambda i: i.is_integer()
303+
expr_template = "{}.is_integer()"
304+
if not not_cond:
305+
from_dtype = integers_from_dtype # type: ignore
306+
elif cond_str == "an odd integer value":
307+
cond = lambda i: i.is_integer() and i % 2 == 1
308+
expr_template = "{}.is_integer() and {} % 2 == 1"
309+
if not not_cond:
310+
311+
def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
312+
return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1)
314313

315314
else:
316315
raise ValueParseError(cond_str)
@@ -329,7 +328,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
329328
kwargs = {}
330329
filter_ = cond
331330
assert kwargs is not None
332-
return cond, expr_template, BoundFromDtype(kwargs, filter_)
331+
return cond, expr_template, BoundFromDtype(kwargs, filter_, from_dtype)
333332

334333

335334
def parse_result(result_str: str) -> Tuple[UnaryCheck, str]:
@@ -531,25 +530,9 @@ def noop(n: float) -> float:
531530
return n
532531

533532

534-
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
535-
for k in kw.keys():
536-
# sanity check
537-
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
538-
m, M = dh.dtype_ranges[dtype]
539-
if "min_value" in kw.keys():
540-
m = kw["min_value"]
541-
if "exclude_min" in kw.keys():
542-
m += 1
543-
if "max_value" in kw.keys():
544-
M = kw["max_value"]
545-
if "exclude_max" in kw.keys():
546-
M -= 1
547-
return st.integers(math.ceil(m), math.floor(M)).map(float)
548-
549-
550533
def make_binary_cond(
551534
cond_arg: BinaryCondArg,
552-
unary_check: UnaryCheck,
535+
unary_cond: UnaryCheck,
553536
*,
554537
input_wrapper: Optional[Callable[[float], float]] = None,
555538
) -> BinaryCond:
@@ -559,22 +542,22 @@ def make_binary_cond(
559542
if cond_arg == BinaryCondArg.FIRST:
560543

561544
def partial_cond(i1: float, i2: float) -> bool:
562-
return unary_check(input_wrapper(i1))
545+
return unary_cond(input_wrapper(i1))
563546

564547
elif cond_arg == BinaryCondArg.SECOND:
565548

566549
def partial_cond(i1: float, i2: float) -> bool:
567-
return unary_check(input_wrapper(i2))
550+
return unary_cond(input_wrapper(i2))
568551

569552
elif cond_arg == BinaryCondArg.BOTH:
570553

571554
def partial_cond(i1: float, i2: float) -> bool:
572-
return unary_check(input_wrapper(i1)) and unary_check(input_wrapper(i2))
555+
return unary_cond(input_wrapper(i1)) and unary_cond(input_wrapper(i2))
573556

574557
else:
575558

576559
def partial_cond(i1: float, i2: float) -> bool:
577-
return unary_check(input_wrapper(i1)) or unary_check(input_wrapper(i2))
560+
return unary_cond(input_wrapper(i1)) or unary_cond(input_wrapper(i2))
578561

579562
return partial_cond
580563

@@ -631,11 +614,26 @@ def check_result(i1: float, i2: float, result: float) -> bool:
631614
return check_result
632615

633616

634-
def parse_binary_case(case_m: Match) -> BinaryCase:
635-
cond_strs = r_cond_sep.split(case_m.group(1))
617+
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
618+
for k in kw.keys():
619+
# sanity check
620+
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
621+
m, M = dh.dtype_ranges[dtype]
622+
if "min_value" in kw.keys():
623+
m = kw["min_value"]
624+
if "exclude_min" in kw.keys():
625+
m += 1
626+
if "max_value" in kw.keys():
627+
M = kw["max_value"]
628+
if "exclude_max" in kw.keys():
629+
M -= 1
630+
return st.integers(math.ceil(m), math.floor(M)).map(float)
636631

637-
if len(cond_strs) > 2:
638-
raise ValueParseError(", ".join(cond_strs))
632+
633+
def parse_binary_case(case_str: str) -> BinaryCase:
634+
case_m = r_binary_case.match(case_str)
635+
assert case_m is not None # sanity check
636+
cond_strs = r_cond_sep.split(case_m.group(1))
639637

640638
partial_conds = []
641639
partial_exprs = []
@@ -678,7 +676,7 @@ def partial_cond(i1: float, i2: float) -> bool:
678676
return math.copysign(1, i1) != math.copysign(1, i2)
679677

680678
else:
681-
unary_check, expr_template, cond_from_dtype = parse_cond(value_str)
679+
unary_cond, expr_template, cond_from_dtype = parse_cond(value_str)
682680
# Do not define partial_cond via the def keyword, as one
683681
# partial_cond definition can mess up previous definitions
684682
# in the partial_conds list. This is a hard-limitation of
@@ -707,7 +705,7 @@ def partial_cond(i1: float, i2: float) -> bool:
707705
else:
708706
raise ValueParseError(input_str)
709707
partial_cond = make_binary_cond( # type: ignore
710-
cond_arg, unary_check, input_wrapper=input_wrapper
708+
cond_arg, unary_cond, input_wrapper=input_wrapper
711709
)
712710
if cond_arg == BinaryCondArg.FIRST:
713711
x1_cond_from_dtypes.append(cond_from_dtype)
@@ -749,15 +747,15 @@ def cond(i1: float, i2: float) -> bool:
749747
else:
750748
# sanity check
751749
assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes)
752-
x1_cond_from_dtype = sum(x1_cond_from_dtypes)
750+
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({}, None))
753751
if len(x2_cond_from_dtypes) == 0:
754752
x2_cond_from_dtype = xps.from_dtype
755753
elif len(x2_cond_from_dtypes) == 1:
756754
x2_cond_from_dtype = x2_cond_from_dtypes[0]
757755
else:
758756
# sanity check
759757
assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes)
760-
x2_cond_from_dtype = sum(x2_cond_from_dtypes)
758+
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({}, None))
761759

762760
return BinaryCase(
763761
cond_expr=cond_expr,
@@ -788,7 +786,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
788786
continue
789787
if m := r_binary_case.match(case_str):
790788
try:
791-
case = parse_binary_case(m)
789+
case = parse_binary_case(case_str)
792790
cases.append(case)
793791
except ValueParseError as e:
794792
warn(f"not machine-readable: '{e.value}'")

0 commit comments

Comments
 (0)