Skip to content

Commit e6f9548

Browse files
committed
Cover "equal to" cases (as opposed to "is" cases)
1 parent 420a549 commit e6f9548

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

array_api_tests/test_special_cases.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
BinaryCheck = Callable[[float, float], bool]
3030

3131

32-
def make_eq(v: float) -> UnaryCheck:
32+
def make_strict_eq(v: float) -> UnaryCheck:
3333
if math.isnan(v):
3434
return math.isnan
3535
if v == 0:
@@ -38,14 +38,14 @@ def make_eq(v: float) -> UnaryCheck:
3838
else:
3939
return ph.is_neg_zero
4040

41-
def eq(i: float) -> bool:
41+
def strict_eq(i: float) -> bool:
4242
return i == v
4343

44-
return eq
44+
return strict_eq
4545

4646

4747
def make_neq(v: float) -> UnaryCheck:
48-
eq = make_eq(v)
48+
eq = make_strict_eq(v)
4949

5050
def neq(i: float) -> bool:
5151
return not eq(i)
@@ -154,7 +154,8 @@ def parse_inline_code(inline_code: str) -> float:
154154
raise ValueParseError(inline_code)
155155

156156

157-
r_not = re.compile("not (?:equal to )?(.+)")
157+
r_not = re.compile("not (.+)")
158+
r_equal_to = re.compile(f"equal to {r_code.pattern}")
158159
r_array_element = re.compile(r"``([+-]?)x([12])_i``")
159160
r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}")
160161
r_gt = re.compile(f"greater than {r_code.pattern}")
@@ -217,9 +218,6 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
217218

218219

219220
def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
220-
if "equal to" in cond_str:
221-
raise ValueParseError(cond_str) # TODO
222-
223221
if m := r_not.match(cond_str):
224222
cond_str = m.group(1)
225223
not_cond = True
@@ -232,10 +230,15 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
232230
strat = None
233231
if m := r_code.match(cond_str):
234232
value = parse_value(m.group(1))
235-
cond = make_eq(value)
233+
cond = make_strict_eq(value)
236234
expr_template = "{} == " + m.group(1)
237235
if not not_cond:
238236
strat = st.just(value)
237+
elif m := r_equal_to.match(cond_str):
238+
value = parse_value(m.group(1))
239+
assert not math.isnan(value) # sanity check
240+
cond = lambda i: i == value
241+
expr_template = "{} == " + m.group(1)
239242
elif m := r_gt.match(cond_str):
240243
value = parse_value(m.group(1))
241244
cond = make_gt(value)
@@ -251,7 +254,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
251254
elif m := r_either_code.match(cond_str):
252255
v1 = parse_value(m.group(1))
253256
v2 = parse_value(m.group(2))
254-
cond = make_or(make_eq(v1), make_eq(v2))
257+
cond = make_or(make_strict_eq(v1), make_strict_eq(v2))
255258
expr_template = "{} == " + m.group(1) + " or {} == " + m.group(2)
256259
if not not_cond:
257260
strat = st.sampled_from([v1, v2])
@@ -334,7 +337,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
334337
def parse_result(result_str: str) -> Tuple[UnaryCheck, str]:
335338
if m := r_code.match(result_str):
336339
value = parse_value(m.group(1))
337-
check_result = make_eq(value) # type: ignore
340+
check_result = make_strict_eq(value) # type: ignore
338341
expr = m.group(1)
339342
elif m := r_approx_value.match(result_str):
340343
value = parse_value(m.group(1))
@@ -573,13 +576,13 @@ def make_eq_other_input_cond(
573576
if eq_to == BinaryCondArg.FIRST:
574577

575578
def cond(i1: float, i2: float) -> bool:
576-
eq = make_eq(input_wrapper(i1))
579+
eq = make_strict_eq(input_wrapper(i1))
577580
return eq(i2)
578581

579582
elif eq_to == BinaryCondArg.SECOND:
580583

581584
def cond(i1: float, i2: float) -> bool:
582-
eq = make_eq(input_wrapper(i2))
585+
eq = make_strict_eq(input_wrapper(i2))
583586
return eq(i1)
584587

585588
else:
@@ -599,13 +602,13 @@ def make_eq_input_check_result(
599602
if eq_to == BinaryCondArg.FIRST:
600603

601604
def check_result(i1: float, i2: float, result: float) -> bool:
602-
eq = make_eq(input_wrapper(i1))
605+
eq = make_strict_eq(input_wrapper(i1))
603606
return eq(result)
604607

605608
elif eq_to == BinaryCondArg.SECOND:
606609

607610
def check_result(i1: float, i2: float, result: float) -> bool:
608-
eq = make_eq(input_wrapper(i2))
611+
eq = make_strict_eq(input_wrapper(i2))
609612
return eq(result)
610613

611614
else:

0 commit comments

Comments
 (0)