@@ -262,7 +262,8 @@ class BoundFromDtype(FromDtypeFunc):
262
262
263
263
>>> gt_0_from_dtype = BoundFromDtype(kwargs={'min_value': 0})
264
264
>>> not_42_from_dtype = BoundFromDtype(filter=lambda i: i != 42)
265
- >>> from_dtype = gt_0_from_dtype + not_42_from_dtype
265
+ >>> gt_0_from_dtype + not_42_from_dtype
266
+ BoundFromDtype(kwargs={'min_value': 0}, filter=<lambda>(i))
266
267
267
268
"""
268
269
@@ -329,17 +330,14 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
329
330
330
331
e.g.
331
332
332
- >>> cond_func, expr_template, cond_from_dtype = parse_cond(
333
- ... 'greater than ``0``'
334
- ... )
335
- >>> expr_template.replace('{}', 'x_i')
333
+ >>> cond, expr_template, from_dtype = parse_cond('greater than ``0``')
336
334
>>> expr_template.replace('{}', 'x_i')
337
335
'x_i > 0'
338
- >>> cond_func (42)
336
+ >>> cond (42)
339
337
True
340
- >>> cond_func (-128)
338
+ >>> cond (-128)
341
339
False
342
- >>> strategy = cond_from_dtype (xp.float64)
340
+ >>> strategy = from_dtype (xp.float64)
343
341
>>> for _ in range(5):
344
342
... print(strategy.example())
345
343
1.
@@ -387,7 +385,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
387
385
v1 = parse_value (m .group (1 ))
388
386
v2 = parse_value (m .group (2 ))
389
387
cond = make_or (make_strict_eq (v1 ), make_strict_eq (v2 ))
390
- expr_template = "{} == " + m .group (1 ) + " or {} == " + m .group (2 )
388
+ expr_template = "( {} == " + m .group (1 ) + " or {} == " + m .group (2 ) + ")"
391
389
if not not_cond :
392
390
strat = st .sampled_from ([v1 , v2 ])
393
391
elif cond_str in ["finite" , "a finite number" ]:
@@ -469,6 +467,24 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
469
467
470
468
471
469
def parse_result (result_str : str ) -> Tuple [UnaryCheck , str ]:
470
+ """
471
+ Parses a Sphinx-formatted result string to return:
472
+
473
+ 1. A function which takes an input and returns True if it is the expected
474
+ result (or meets the condition of the expected result), otherwise False.
475
+ 2. A string that expresses the result.
476
+
477
+ e.g.
478
+
479
+ >>> check_result, expr = parse_result('``42``')
480
+ >>> expr_template.replace('{}', 'x_i')
481
+ '42'
482
+ >>> check_result(7)
483
+ False
484
+ >>> check_result(42)
485
+ True
486
+
487
+ """
472
488
if m := r_code .match (result_str ):
473
489
value = parse_value (m .group (1 ))
474
490
check_result = make_strict_eq (value ) # type: ignore
0 commit comments