Skip to content

Commit ddb287a

Browse files
committed
Document parse_result()
1 parent cb810b9 commit ddb287a

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

array_api_tests/test_special_cases.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ class BoundFromDtype(FromDtypeFunc):
262262
263263
>>> gt_0_from_dtype = BoundFromDtype(kwargs={'min_value': 0})
264264
>>> not_42_from_dtype = BoundFromDtype(filter=lambda i: i != 42)
265-
>>> from_dtype = gt_0_from_dtype + not_42_from_dtype
265+
>>> gt_0_from_dtype + not_42_from_dtype
266+
BoundFromDtype(kwargs={'min_value': 0}, filter=<lambda>(i))
266267
267268
"""
268269

@@ -329,17 +330,14 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
329330
330331
e.g.
331332
332-
>>> cond_func, expr_template, cond_from_dtype = parse_cond(
333-
... 'greater than ``0``'
334-
... )
335-
>>> expr_template.replace('{}', 'x_i')
333+
>>> cond, expr_template, from_dtype = parse_cond('greater than ``0``')
336334
>>> expr_template.replace('{}', 'x_i')
337335
'x_i > 0'
338-
>>> cond_func(42)
336+
>>> cond(42)
339337
True
340-
>>> cond_func(-128)
338+
>>> cond(-128)
341339
False
342-
>>> strategy = cond_from_dtype(xp.float64)
340+
>>> strategy = from_dtype(xp.float64)
343341
>>> for _ in range(5):
344342
... print(strategy.example())
345343
1.
@@ -387,7 +385,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
387385
v1 = parse_value(m.group(1))
388386
v2 = parse_value(m.group(2))
389387
cond = make_or(make_strict_eq(v1), make_strict_eq(v2))
390-
expr_template = "{} == " + m.group(1) + " or {} == " + m.group(2)
388+
expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")"
391389
if not not_cond:
392390
strat = st.sampled_from([v1, v2])
393391
elif cond_str in ["finite", "a finite number"]:
@@ -469,6 +467,24 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
469467

470468

471469
def parse_result(result_str: str) -> Tuple[UnaryCheck, str]:
470+
"""
471+
Parses a Sphinx-formatted result string to return:
472+
473+
1. A function which takes an input and returns True if it is the expected
474+
result (or meets the condition of the expected result), otherwise False.
475+
2. A string that expresses the result.
476+
477+
e.g.
478+
479+
>>> check_result, expr = parse_result('``42``')
480+
>>> expr_template.replace('{}', 'x_i')
481+
'42'
482+
>>> check_result(7)
483+
False
484+
>>> check_result(42)
485+
True
486+
487+
"""
472488
if m := r_code.match(result_str):
473489
value = parse_value(m.group(1))
474490
check_result = make_strict_eq(value) # type: ignore

0 commit comments

Comments
 (0)