Skip to content

Commit eaa3b05

Browse files
committed
Document parse_binary_docstring()
1 parent 6c802c3 commit eaa3b05

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

array_api_tests/test_special_cases.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
590590
Parses a Sphinx-formatted docstring of a unary function to return a list of
591591
codified unary cases, e.g.
592592
593-
>>> def sqrt(x: array, /) -> array:
593+
>>> def sqrt(x):
594594
... '''
595595
... Calculates the square root
596596
...
@@ -607,14 +607,14 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
607607
... Parameters
608608
... ----------
609609
... x: array
610-
... input array. Should have a floating-point data type
610+
... input array
611611
...
612612
... Returns
613613
... -------
614614
... out: array
615615
... an array containing the square root of each element in ``x``
616616
... '''
617-
... ...
617+
...
618618
>>> unary_cases = parse_unary_docstring(sqrt.__doc__)
619619
>>> for case in unary_cases:
620620
... print(repr(case))
@@ -1014,6 +1014,46 @@ def cond(i1: float, i2: float) -> bool:
10141014

10151015

10161016
def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1017+
"""
1018+
Parses a Sphinx-formatted docstring of a binary function to return a list of
1019+
codified binary cases, e.g.
1020+
1021+
>>> def logaddexp(x1, x2):
1022+
... '''
1023+
... Calculates the logarithm of the sum of exponentiations
1024+
...
1025+
... **Special Cases**
1026+
...
1027+
... For floating-point operands,
1028+
...
1029+
... - If either ``x1_i`` or ``x2_i`` is ``NaN``, the result is ``NaN``.
1030+
... - If ``x1_i`` is ``+infinity`` and ``x2_i`` is not ``NaN``, the
1031+
... result is ``+infinity``.
1032+
... - If ``x1_i`` is not ``NaN`` and ``x2_i`` is ``+infinity``, the
1033+
... result is ``+infinity``.
1034+
...
1035+
... Parameters
1036+
... ----------
1037+
... x1: array
1038+
... first input array
1039+
... x2: array
1040+
... second input array
1041+
...
1042+
... Returns
1043+
... -------
1044+
... out: array
1045+
... an array containing the results
1046+
... '''
1047+
...
1048+
>>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
1049+
>>> for case in binary_cases:
1050+
... print(repr(case))
1051+
BinaryCase(x1_i == NaN or x2_i == NaN -> NaN)
1052+
BinaryCase(x1_i == +infinity and not x2_i == NaN -> +infinity)
1053+
BinaryCase(not x1_i == NaN and x2_i == +infinity -> +infinity)
1054+
1055+
"""
1056+
10171057
match = r_special_cases.search(docstring)
10181058
if match is None:
10191059
return []

0 commit comments

Comments
 (0)