Skip to content

Commit fd6367f

Browse files
committed
Use a more robust fallback helper for matrix_transpose
1 parent a96a5df commit fd6367f

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

array_api_tests/array_helpers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# These are exported here so that they can be included in the special cases
88
# tests from this file.
99
from ._array_module import logical_not, subtract, floor, ceil, where
10+
from . import _array_module as xp
1011
from . import dtype_helpers as dh
1112

1213
from ndindex import iter_indices
@@ -345,3 +346,14 @@ def same_sign(x, y):
345346

346347
def assert_same_sign(x, y):
347348
assert all(same_sign(x, y)), "The input arrays do not have the same sign"
349+
350+
def _matrix_transpose(x):
351+
if not isinstance(xp.matrix_transpose, xp._UndefinedStub):
352+
return xp.matrix_transpose(x)
353+
if hasattr(x, 'mT'):
354+
return x.mT
355+
if not isinstance(xp.permute_dims, xp._UndefinedStub):
356+
perm = list(range(x.ndim))
357+
perm[-1], perm[-2] = perm[-2], perm[-1]
358+
return xp.permute_dims(x, axes=tuple(perm))
359+
raise NotImplementedError("No way to compute matrix transpose")

array_api_tests/hypothesis_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
sampled_from, shared)
1111

1212
from . import _array_module as xp
13+
from . import array_helpers as ah
1314
from . import dtype_helpers as dh
1415
from . import shape_helpers as sh
1516
from . import xps
@@ -212,7 +213,8 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
212213
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
213214
a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements))
214215
upper = xp.triu(a)
215-
lower = xp.triu(a, k=1).mT
216+
# mT and matrix_transpose are more likely to not be implemented
217+
lower = ah._matrix_transpose(xp.triu(a, k=1))
216218
return upper + lower
217219

218220
@composite

0 commit comments

Comments
 (0)