Skip to content

Commit 7c89cf1

Browse files
authored
Merge pull request #219 from jakevdp/symmetric-matrices
symmetric_matrices: draw from finite if necessary
2 parents 6a38365 + 54557ae commit 7c89cf1

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ def mutually_broadcastable_shapes(
251251
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
252252
shape = draw(square_matrix_shapes)
253253
dtype = draw(dtypes)
254+
if not isinstance(finite, bool):
255+
finite = draw(finite)
254256
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
255257
a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements))
256258
upper = xp.triu(a)

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,9 @@ def run(n, d, data):
129129

130130

131131

132-
@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes,
133-
finite=st.shared(st.booleans(), key='finite')),
134-
dtype=hh.shared_floating_dtypes,
135-
finite=st.shared(st.booleans(), key='finite'))
136-
def test_symmetric_matrices(m, dtype, finite):
132+
@given(finite=st.booleans(), dtype=xps.floating_dtypes(), data=st.data())
133+
def test_symmetric_matrices(finite, dtype, data):
134+
m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite))
137135
assert m.dtype == dtype
138136
# TODO: This part of this test should be part of the .mT test
139137
ah.assert_exactly_equal(m, m.mT)

0 commit comments

Comments
 (0)