Skip to content

Commit 47894ff

Browse files
committed
Add 2023.12 axis restrictions to vecdot() and cross()
1 parent 8572df3 commit 47894ff

File tree

4 files changed

+161
-2
lines changed

4 files changed

+161
-2
lines changed

array_api_strict/_linear_algebra_functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from __future__ import annotations
99

1010
from ._dtypes import _numeric_dtypes
11-
1211
from ._array_object import Array
12+
from ._flags import get_array_api_strict_flags
1313

1414
from typing import TYPE_CHECKING
1515
if TYPE_CHECKING:
@@ -54,6 +54,19 @@ def matrix_transpose(x: Array, /) -> Array:
5454
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
5555
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
5656
raise TypeError('Only numeric dtypes are allowed in vecdot')
57+
58+
if get_array_api_strict_flags()['api_version'] >= '2023.12':
59+
if axis >= 0:
60+
raise ValueError("axis must be negative in vecdot")
61+
elif axis < min(-1, -x1.ndim, -x2.ndim):
62+
raise ValueError("axis is out of bounds for x1 and x2")
63+
64+
# In versions if the standard prior to 2023.12, vecdot applied axis after
65+
# broadcasting. This is different from applying it before broadcasting
66+
# when axis is nonnegative. The below code keeps this behavior for
67+
# 2022.12, primarily for backwards compatibility. Note that the behavior
68+
# is unambiguous when axis is negative, so the below code should work
69+
# correctly in that case regardless of which version is used.
5770
ndim = max(x1.ndim, x2.ndim)
5871
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
5972
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)

array_api_strict/linalg.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
8080
# Note: this is different from np.cross(), which allows dimension 2
8181
if x1.shape[axis] != 3:
8282
raise ValueError('cross() dimension must equal 3')
83+
84+
if get_array_api_strict_flags()['api_version'] >= '2023.12':
85+
if axis >= 0:
86+
raise ValueError("axis must be negative in cross")
87+
elif axis < min(-1, -x1.ndim, -x2.ndim):
88+
raise ValueError("axis is out of bounds for x1 and x2")
89+
90+
# Prior to 2023.12, there was ambiguity in the standard about whether
91+
# positive axis applied before or after broadcasting. NumPy applies
92+
# the axis before broadcasting. Since that behavior is what has always
93+
# been implemented here, we keep it for backwards compatibility.
8394
return Array._new(np.cross(x1._array, x2._array, axis=axis))
8495

8596
@requires_extension('linalg')

array_api_strict/tests/test_linalg.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import pytest
2+
3+
from .._flags import set_array_api_strict_flags
4+
5+
import array_api_strict as xp
6+
7+
# TODO: Maybe all of these exceptions should be IndexError?
8+
9+
# Technically this is linear_algebra, not linalg, but it's simpler to keep
10+
# both of these tests together
11+
def test_vecdot_2023_12():
12+
# Test the axis < 0 restriction for 2023.12, and also the 2022.12 axis >=
13+
# 0 behavior (which is primarily kept for backwards compatibility).
14+
15+
a = xp.ones((2, 3, 4, 5))
16+
b = xp.ones(( 3, 4, 1))
17+
18+
# 2022.12 behavior, which is to apply axis >= 0 after broadcasting
19+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
20+
assert xp.linalg.vecdot(a, b, axis=1).shape == (2, 4, 5)
21+
assert xp.linalg.vecdot(a, b, axis=2).shape == (2, 3, 5)
22+
# This is disallowed because the arrays must have the same values before
23+
# broadcasting
24+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1))
25+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4))
26+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3))
27+
28+
# Out-of-bounds axes even after broadcasting
29+
pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=4))
30+
pytest.raises(IndexError, lambda: xp.linalg.vecdot(a, b, axis=-5))
31+
32+
# negative axis behavior is unambiguous when it's within the bounds of
33+
# both arrays before broadcasting
34+
assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5)
35+
assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5)
36+
37+
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
38+
# min(x1.ndim, x2.ndim), which is unambiguous
39+
with pytest.warns(UserWarning):
40+
set_array_api_strict_flags(api_version='2023.12')
41+
42+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=0))
43+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=1))
44+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=2))
45+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=3))
46+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-1))
47+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-4))
48+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=4))
49+
pytest.raises(ValueError, lambda: xp.linalg.vecdot(a, b, axis=-5))
50+
51+
assert xp.linalg.vecdot(a, b, axis=-2).shape == (2, 3, 5)
52+
assert xp.linalg.vecdot(a, b, axis=-3).shape == (2, 4, 5)
53+
54+
@pytest.mark.parametrize('api_version', ['2021.12', '2022.12', '2023.12'])
55+
def test_cross(api_version):
56+
# This test tests everything that should be the same across all supported
57+
# API versions.
58+
59+
if api_version != '2022.12':
60+
with pytest.warns(UserWarning):
61+
set_array_api_strict_flags(api_version=api_version)
62+
else:
63+
set_array_api_strict_flags(api_version=api_version)
64+
65+
a = xp.ones((2, 4, 5, 3))
66+
b = xp.ones(( 4, 1, 3))
67+
assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3)
68+
69+
a = xp.ones((2, 4, 3, 5))
70+
b = xp.ones(( 4, 3, 1))
71+
assert xp.linalg.cross(a, b, axis=-2).shape == (2, 4, 3, 5)
72+
73+
# This is disallowed because the axes must equal 3 before broadcasting
74+
a = xp.ones((3, 2, 3, 5))
75+
b = xp.ones(( 2, 1, 1))
76+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-1))
77+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2))
78+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3))
79+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4))
80+
81+
# Out-of-bounds axes even after broadcasting
82+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4))
83+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5))
84+
85+
@pytest.mark.parametrize('api_version', ['2021.12', '2022.12'])
86+
def test_cross_2022_12(api_version):
87+
# Test the 2022.12 axis >= 0 behavior, which is primarily kept for
88+
# backwards compatibility. Note that unlike vecdot, array_api_strict
89+
# cross() never implemented the "after broadcasting" axis behavior, but
90+
# just reused NumPy cross(), which applies axes before broadcasting.
91+
if api_version != '2022.12':
92+
with pytest.warns(UserWarning):
93+
set_array_api_strict_flags(api_version=api_version)
94+
else:
95+
set_array_api_strict_flags(api_version=api_version)
96+
97+
a = xp.ones((3, 2, 4, 5))
98+
b = xp.ones((3, 2, 4, 1))
99+
assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5)
100+
101+
# ambiguous case
102+
a = xp.ones(( 3, 4, 5))
103+
b = xp.ones((3, 2, 4, 1))
104+
assert xp.linalg.cross(a, b, axis=0).shape == (3, 2, 4, 5)
105+
106+
def test_cross_2023_12():
107+
# 2023.12 behavior, which is to only allow axis < 0 and axis >=
108+
# min(x1.ndim, x2.ndim), which is unambiguous
109+
with pytest.warns(UserWarning):
110+
set_array_api_strict_flags(api_version='2023.12')
111+
112+
a = xp.ones((3, 2, 4, 5))
113+
b = xp.ones((3, 2, 4, 1))
114+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0))
115+
116+
a = xp.ones(( 3, 4, 5))
117+
b = xp.ones((3, 2, 4, 1))
118+
pytest.raises(ValueError, lambda: xp. linalg.cross(a, b, axis=0))
119+
120+
a = xp.ones((2, 4, 5, 3))
121+
b = xp.ones(( 4, 1, 3))
122+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=0))
123+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=1))
124+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=2))
125+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=3))
126+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-2))
127+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-3))
128+
pytest.raises(ValueError, lambda: xp.linalg.cross(a, b, axis=-4))
129+
130+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=4))
131+
pytest.raises(IndexError, lambda: xp.linalg.cross(a, b, axis=-5))
132+
133+
assert xp.linalg.cross(a, b, axis=-1).shape == (2, 4, 5, 3)

array_api_strict/tests/test_statistical_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22

3+
from .._flags import set_array_api_strict_flags
4+
35
import array_api_strict as xp
46

57
@pytest.mark.parametrize('func_name', ['sum', 'prod', 'trace'])
@@ -20,7 +22,7 @@ def test_sum_prod_trace_2023_12(func_name):
2022
assert func(a_int).dtype == xp.int64
2123

2224
with pytest.warns(UserWarning):
23-
xp.set_array_api_strict_flags(api_version='2023.12')
25+
set_array_api_strict_flags(api_version='2023.12')
2426

2527
assert func(a_real).dtype == xp.float32
2628
assert func(a_complex).dtype == xp.complex64

0 commit comments

Comments
 (0)