Skip to content

Commit 7017797

Browse files
committed
Be more constrained about constructing symmetric matrices
I think this still might be able to construct an ill-conditioned matrix, but it didn't come up with any after thousands of example runs, let's keep it for now. Also construct symmetric matrices using (a + a.T)/2 instead of using triu/tril, and, add a meta-test for it.
1 parent fd6367f commit 7017797

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,17 @@ def mutually_broadcastable_shapes(
207207

208208
# Note: This should become hermitian_matrices when complex dtypes are added
209209
@composite
210-
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
210+
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.):
211211
shape = draw(square_matrix_shapes)
212212
dtype = draw(dtypes)
213213
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
214214
a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements))
215-
upper = xp.triu(a)
216-
# mT and matrix_transpose are more likely to not be implemented
217-
lower = ah._matrix_transpose(xp.triu(a, k=1))
218-
return upper + lower
215+
at = ah._matrix_transpose(a)
216+
H = (a + at)*0.5
217+
if finite:
218+
assume(not xp.any(xp.isinf(H)))
219+
assume(xp.all((H == 0.) | ((1/bound <= xp.abs(H)) & (xp.abs(H) <= bound))))
220+
return H
219221

220222
@composite
221223
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):

array_api_tests/meta/test_linalg.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
3+
from hypothesis import given
4+
5+
from ..hypothesis_helpers import symmetric_matrices
6+
from .. import array_helpers as ah
7+
from .. import _array_module as xp
8+
9+
@pytest.mark.xp_extension('linalg')
10+
@given(x=symmetric_matrices(finite=True))
11+
def test_symmetric_matrices(x):
12+
upper = xp.triu(x)
13+
lower = xp.tril(x)
14+
lowerT = ah._matrix_transpose(lower)
15+
16+
ah.assert_exactly_equal(upper, lowerT)

0 commit comments

Comments
 (0)