Skip to content

Commit cb810b9

Browse files
committed
Document parse_cond
1 parent 171ee5a commit cb810b9

File tree

1 file changed

+40
-6
lines changed

1 file changed

+40
-6
lines changed

array_api_tests/test_special_cases.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class ParseError(ValueError):
134134

135135
def parse_value(value_str: str) -> float:
136136
"""
137-
Parse a value string to return a float, e.g.
137+
Parses a value string to return a float, e.g.
138138
139139
>>> parse_value('1')
140140
1.
@@ -169,7 +169,7 @@ def parse_value(value_str: str) -> float:
169169

170170
def parse_inline_code(inline_code: str) -> float:
171171
"""
172-
Parse a Sphinx code string to return a float, e.g.
172+
Parses a Sphinx code string to return a float, e.g.
173173
174174
>>> parse_value('``0``')
175175
0.
@@ -208,7 +208,7 @@ class BoundFromDtype(FromDtypeFunc):
208208
209209
We can bound:
210210
211-
1. Keyword arguments that xps.from_dtype() can use, e.g.
211+
1. Keyword arguments that xps.from_dtype() can use, e.g.
212212
213213
>>> from_dtype = BoundFromDtype(kwargs={'min_value': 0, 'allow_infinity': False})
214214
>>> strategy = from_dtype(xp.float64)
@@ -219,7 +219,7 @@ class BoundFromDtype(FromDtypeFunc):
219219
220220
i.e. a strategy that generates finite floats above 0
221221
222-
2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g.
222+
2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g.
223223
224224
>>> from_dtype = BoundFromDtype(filter=lambda i: i != 0)
225225
>>> strategy = from_dtype(xp.float64)
@@ -230,7 +230,7 @@ class BoundFromDtype(FromDtypeFunc):
230230
231231
i.e. a strategy that generates any floats except 0
232232
233-
3. The underlying function that returns an elements strategy from a dtype, e.g.
233+
3. The underlying function that returns an elements strategy from a dtype, e.g.
234234
235235
>>> from_dtype = BoundFromDtype(
236236
... from_dtype=lambda d: st.integers(
@@ -307,14 +307,48 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
307307

308308

309309
def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc:
310+
"""
311+
Wraps an elements strategy as a xps.from_dtype()-like function
312+
"""
310313
def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
311-
assert kw == {} # sanity check
314+
assert len(kw) == 0 # sanity check
312315
return strat
313316

314317
return from_dtype
315318

316319

317320
def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
321+
"""
322+
Parses a Sphinx-formatted condition string to return:
323+
324+
1. A function which takes an input and returns True if it meets the
325+
condition, otherwise False.
326+
2. A string template for expressing the condition.
327+
3. A xps.from_dtype()-like function which returns a strategy that generates
328+
elements which meet the condition.
329+
330+
e.g.
331+
332+
>>> cond_func, expr_template, cond_from_dtype = parse_cond(
333+
... 'greater than ``0``'
334+
... )
335+
>>> expr_template.replace('{}', 'x_i')
336+
>>> expr_template.replace('{}', 'x_i')
337+
'x_i > 0'
338+
>>> cond_func(42)
339+
True
340+
>>> cond_func(-128)
341+
False
342+
>>> strategy = cond_from_dtype(xp.float64)
343+
>>> for _ in range(5):
344+
... print(strategy.example())
345+
1.
346+
0.1
347+
1.7976931348623155e+179
348+
inf
349+
124.978
350+
351+
"""
318352
if m := r_not.match(cond_str):
319353
cond_str = m.group(1)
320354
not_cond = True

0 commit comments

Comments
 (0)