Skip to content

Commit 5e62058

Browse files
committed
Print the function name for non-machine readable special cases
1 parent ebb4f37 commit 5e62058

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

array_api_tests/test_special_cases.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def check_result(i: float, result: float) -> bool:
629629
return check_result
630630

631631

632-
def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
632+
def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
633633
"""
634634
Parses a Sphinx-formatted docstring of a unary function to return a list of
635635
codified unary cases, e.g.
@@ -660,7 +660,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
660660
... '''
661661
...
662662
>>> case_block = r_case_block.search(sqrt.__doc__).group(1)
663-
>>> unary_cases = parse_unary_case_block(case_block)
663+
>>> unary_cases = parse_unary_case_block(case_block, 'sqrt')
664664
>>> for case in unary_cases:
665665
... print(repr(case))
666666
UnaryCase(<x_i < 0 -> NaN>)
@@ -691,7 +691,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
691691
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
692692
_check_result, result_expr = parse_result(m.group(2))
693693
except ParseError as e:
694-
warn(f"not machine-readable: '{e.value}'")
694+
warn(f"case for {func_name} not machine-readable: '{e.value}'")
695695
continue
696696
cond_expr = cond_expr_template.replace("{}", "x_i")
697697
# Do not define check_result in this function's body - see
@@ -708,7 +708,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
708708
cases.append(case)
709709
else:
710710
if not r_remaining_case.search(case_str):
711-
warn(f"case not machine-readable: '{case_str}'")
711+
warn(f"case for {func_name} not machine-readable: '{case_str}'")
712712
return cases
713713

714714

@@ -1102,7 +1102,7 @@ def cond(i1: float, i2: float) -> bool:
11021102
r_redundant_case = re.compile("result.+determined by the rule already stated above")
11031103

11041104

1105-
def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
1105+
def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]:
11061106
"""
11071107
Parses a Sphinx-formatted docstring of a binary function to return a list of
11081108
codified binary cases, e.g.
@@ -1133,7 +1133,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11331133
... '''
11341134
...
11351135
>>> case_block = r_case_block.search(logaddexp.__doc__).group(1)
1136-
>>> binary_cases = parse_binary_case_block(case_block)
1136+
>>> binary_cases = parse_binary_case_block(case_block, 'logaddexp')
11371137
>>> for case in binary_cases:
11381138
... print(repr(case))
11391139
BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
@@ -1151,10 +1151,10 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11511151
case = parse_binary_case(case_str)
11521152
cases.append(case)
11531153
except ParseError as e:
1154-
warn(f"not machine-readable: '{e.value}'")
1154+
warn(f"case for {func_name} not machine-readable: '{e.value}'")
11551155
else:
11561156
if not r_remaining_case.match(case_str):
1157-
warn(f"case not machine-readable: '{case_str}'")
1157+
warn(f"case for {func_name} not machine-readable: '{case_str}'")
11581158
return cases
11591159

11601160

@@ -1163,19 +1163,20 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11631163
iop_params = []
11641164
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
11651165
for stub in category_to_funcs["elementwise"]:
1166+
func_name = stub.__name__
11661167
if stub.__doc__ is None:
1167-
warn(f"{stub.__name__}() stub has no docstring")
1168+
warn(f"{func_name}() stub has no docstring")
11681169
continue
11691170
if m := r_case_block.search(stub.__doc__):
11701171
case_block = m.group(1)
11711172
else:
11721173
continue
11731174
marks = []
11741175
try:
1175-
func = getattr(xp, stub.__name__)
1176+
func = getattr(xp, func_name)
11761177
except AttributeError:
11771178
marks.append(
1178-
pytest.mark.skip(reason=f"{stub.__name__} not found in array module")
1179+
pytest.mark.skip(reason=f"{func_name} not found in array module")
11791180
)
11801181
func = None
11811182
sig = inspect.signature(stub)
@@ -1184,10 +1185,10 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11841185
warn(f"{func=} has no parameters")
11851186
continue
11861187
if param_names[0] == "x":
1187-
if cases := parse_unary_case_block(case_block):
1188-
name_to_func = {stub.__name__: func}
1189-
if stub.__name__ in func_to_op.keys():
1190-
op_name = func_to_op[stub.__name__]
1188+
if cases := parse_unary_case_block(case_block, func_name):
1189+
name_to_func = {func_name: func}
1190+
if func_name in func_to_op.keys():
1191+
op_name = func_to_op[func_name]
11911192
op = getattr(operator, op_name)
11921193
name_to_func[op_name] = op
11931194
for func_name, func in name_to_func.items():
@@ -1196,20 +1197,20 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11961197
p = pytest.param(func_name, func, case, id=id_)
11971198
unary_params.append(p)
11981199
else:
1199-
warn(f"Special cases found for {stub.__name__} but none were parsed")
1200+
warn(f"Special cases found for {func_name} but none were parsed")
12001201
continue
12011202
if len(sig.parameters) == 1:
12021203
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
12031204
continue
12041205
if param_names[0] == "x1" and param_names[1] == "x2":
1205-
if cases := parse_binary_case_block(case_block):
1206-
name_to_func = {stub.__name__: func}
1207-
if stub.__name__ in func_to_op.keys():
1208-
op_name = func_to_op[stub.__name__]
1206+
if cases := parse_binary_case_block(case_block, func_name):
1207+
name_to_func = {func_name: func}
1208+
if func_name in func_to_op.keys():
1209+
op_name = func_to_op[func_name]
12091210
op = getattr(operator, op_name)
12101211
name_to_func[op_name] = op
12111212
# We collect inplace operator test cases seperately
1212-
if "equal" in stub.__name__:
1213+
if "equal" in func_name:
12131214
continue
12141215
iop_name = "__i" + op_name[2:]
12151216
iop = getattr(operator, iop_name)
@@ -1223,7 +1224,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
12231224
p = pytest.param(func_name, func, case, id=id_)
12241225
binary_params.append(p)
12251226
else:
1226-
warn(f"Special cases found for {stub.__name__} but none were parsed")
1227+
warn(f"Special cases found for {func_name} but none were parsed")
12271228
continue
12281229
else:
12291230
warn(

0 commit comments

Comments
 (0)