Skip to content

Commit 67891b8

Browse files
committed
Factor out UnaryCase.from_strings()
1 parent ddb287a commit 67891b8

File tree

1 file changed

+31
-28
lines changed

1 file changed

+31
-28
lines changed

array_api_tests/test_special_cases.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc:
311311
"""
312312
Wraps an elements strategy as a xps.from_dtype()-like function
313313
"""
314+
314315
def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
315316
assert len(kw) == 0 # sanity check
316317
return strat
@@ -553,23 +554,6 @@ class UnaryCase(Case):
553554
cond: UnaryCheck
554555
check_result: UnaryResultCheck
555556

556-
@classmethod
557-
def from_strings(cls, cond_str: str, result_str: str):
558-
cond, cond_expr_template, cond_from_dtype = parse_cond(cond_str)
559-
cond_expr = cond_expr_template.replace("{}", "x_i")
560-
_check_result, result_expr = parse_result(result_str)
561-
562-
def check_result(i: float, result: float) -> bool:
563-
return _check_result(result)
564-
565-
return cls(
566-
cond_expr=cond_expr,
567-
cond=cond,
568-
cond_from_dtype=cond_from_dtype,
569-
result_expr=result_expr,
570-
check_result=check_result,
571-
)
572-
573557

574558
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
575559
r_even_int_round_case = re.compile(
@@ -578,7 +562,7 @@ def check_result(i: float, result: float) -> bool:
578562
)
579563

580564

581-
def trailing_halves_from_dtype(dtype: DataType):
565+
def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
582566
m, M = dh.dtype_ranges[dtype]
583567
return st.integers(math.ceil(m) // 2, math.floor(M) // 2).map(lambda n: n * 0.5)
584568

@@ -594,6 +578,13 @@ def trailing_halves_from_dtype(dtype: DataType):
594578
)
595579

596580

581+
def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck:
582+
def check_result(i: float, result: float) -> bool:
583+
return check_just_result(result)
584+
585+
return check_result
586+
587+
597588
def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
598589
match = r_special_cases.search(docstring)
599590
if match is None:
@@ -608,10 +599,22 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
608599
continue
609600
if m := r_unary_case.search(case):
610601
try:
611-
case = UnaryCase.from_strings(*m.groups())
602+
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
603+
_check_result, result_expr = parse_result(m.group(2))
612604
except ParseError as e:
613605
warn(f"not machine-readable: '{e.value}'")
614606
continue
607+
cond_expr = cond_expr_template.replace("{}", "x_i")
608+
# Do not define check_result in this function's body - see
609+
# parse_binary_case comment.
610+
check_result = make_unary_check_result(_check_result)
611+
case = UnaryCase(
612+
cond_expr=cond_expr,
613+
cond=cond,
614+
cond_from_dtype=cond_from_dtype,
615+
result_expr=result_expr,
616+
check_result=check_result,
617+
)
615618
cases.append(case)
616619
elif m := r_even_int_round_case.search(case):
617620
cases.append(even_int_round_case)
@@ -741,7 +744,7 @@ def check_result(i1: float, i2: float, result: float) -> bool:
741744
return check_result
742745

743746

744-
def make_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck:
747+
def make_binary_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck:
745748
def check_result(i1: float, i2: float, result: float) -> bool:
746749
return check_just_result(result)
747750

@@ -843,12 +846,12 @@ def partial_cond(i1: float, i2: float) -> bool:
843846

844847
else:
845848
unary_cond, expr_template, cond_from_dtype = parse_cond(value_str)
846-
# Do not define partial_cond via the def keyword, as one
847-
# partial_cond definition can mess up previous definitions
848-
# in the partial_conds list. This is a hard-limitation of
849-
# using local functions with the same name and that use the same
850-
# outer variables (i.e. unary_cond). Use def in a called
851-
# function avoids this problem.
849+
# Do not define partial_cond via the def keyword or lambda
850+
# expressions, as one partial_cond definition can mess up
851+
# previous definitions in the partial_conds list. This is a
852+
# hard-limitation of using local functions with the same name
853+
# and that use the same outer variables (i.e. unary_cond). Use
854+
# def in a called function avoids this problem.
852855
input_wrapper = None
853856
if m := r_input.match(input_str):
854857
x_no = m.group(1)
@@ -924,7 +927,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
924927
if result_m is None:
925928
raise ParseError(case_m.group(2))
926929
result_str = result_m.group(1)
927-
# Like with partial_cond, do not define check_result via the def keyword
930+
# Like with partial_cond, do not define check_result in this function's body.
928931
if m := r_array_element.match(result_str):
929932
sign, x_no = m.groups()
930933
result_expr = f"{sign}x{x_no}_i"
@@ -933,7 +936,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
933936
)
934937
else:
935938
_check_result, result_expr = parse_result(result_m.group(1))
936-
check_result = make_check_result(_check_result)
939+
check_result = make_binary_check_result(_check_result)
937940

938941
cond_expr = " and ".join(partial_exprs)
939942

0 commit comments

Comments
 (0)