Skip to content

Commit ca88431

Browse files
committed
Note np.array_api.can_cast() does not use np.can_cast()
Original NumPy Commit: 995f5464b6c5d8569e159a96c6af106721a4e6d5
1 parent 364b48c commit ca88431

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

array_api_strict/_data_type_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,15 @@ def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
5656
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
5757
if to not in _all_dtypes:
5858
raise TypeError(f"{to=}, but should be a dtype")
59+
# Note: We avoid np.can_cast() as it has discrepancies with the array API.
60+
# See https://github.com/numpy/numpy/issues/20870
5961
try:
62+
# We promote `from_` and `to` together. We then check if the promoted
63+
# dtype is `to`, which indicates if `from_` can (up)cast to `to`.
6064
dtype = _result_type(from_, to)
6165
return to == dtype
6266
except TypeError:
67+
# _result_type() raises if the dtypes don't promote together
6368
return False
6469

6570

array_api_strict/tests/test_data_type_functions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
[
99
(xp.int8, xp.int16, True),
1010
(xp.int16, xp.int8, False),
11-
# np.can_cast has discrepancies with the Array API
12-
# See https://github.com/numpy/numpy/issues/20870
1311
(xp.bool, xp.int8, False),
1412
(xp.asarray(0, dtype=xp.uint8), xp.int8, False),
1513
],

0 commit comments

Comments
 (0)