Skip to content

Commit f2e38db

Browse files
authored
ENH: stats.special_ortho_group: speed up, allow 1x1 and 0x0 ortho and unitary groups (scipy#22304)
* ENH: streamline implementation of stats.special_ortho_group * ENH: allow 1x1 and 0x0 ortho and unitary groups * TST: switch from random.seed to random.default_rng
1 parent 3b386f4 commit f2e38db

File tree

2 files changed

+60
-81
lines changed

2 files changed

+60
-81
lines changed

scipy/stats/_multivariate.py

Lines changed: 12 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3620,9 +3620,9 @@ def __call__(self, dim=None, seed=None):
36203620

36213621
def _process_parameters(self, dim):
36223622
"""Dimension N must be specified; it cannot be inferred."""
3623-
if dim is None or not np.isscalar(dim) or dim <= 1 or dim != int(dim):
3623+
if dim is None or not np.isscalar(dim) or dim < 0 or dim != int(dim):
36243624
raise ValueError("""Dimension of rotation must be specified,
3625-
and must be a scalar greater than 1.""")
3625+
and must be a scalar nonnegative integer.""")
36263626

36273627
return dim
36283628

@@ -3644,55 +3644,11 @@ def rvs(self, dim, size=1, random_state=None):
36443644
"""
36453645
random_state = self._get_random_state(random_state)
36463646

3647-
size = int(size)
3648-
size = (size,) if size > 1 else ()
3649-
3650-
dim = self._process_parameters(dim)
3651-
3652-
# H represents a (dim, dim) matrix, while D represents the diagonal of
3653-
# a (dim, dim) diagonal matrix. The algorithm that follows is
3654-
# broadcasted on the leading shape in `size` to vectorize along
3655-
# samples.
3656-
H = np.empty(size + (dim, dim))
3657-
H[..., :, :] = np.eye(dim)
3658-
D = np.empty(size + (dim,))
3659-
3660-
for n in range(dim-1):
3661-
3662-
# x is a vector with length dim-n, xrow and xcol are views of it as
3663-
# a row vector and column vector respectively. It's important they
3664-
# are views and not copies because we are going to modify x
3665-
# in-place.
3666-
x = random_state.normal(size=size + (dim-n,))
3667-
xrow = x[..., None, :]
3668-
xcol = x[..., :, None]
3669-
3670-
# This is the squared norm of x, without vectorization it would be
3671-
# dot(x, x), to have proper broadcasting we use matmul and squeeze
3672-
# out (convert to scalar) the resulting 1x1 matrix
3673-
norm2 = np.matmul(xrow, xcol).squeeze((-2, -1))
3674-
3675-
x0 = x[..., 0].copy()
3676-
D[..., n] = np.where(x0 != 0, np.sign(x0), 1)
3677-
x[..., 0] += D[..., n]*np.sqrt(norm2)
3678-
3679-
# In renormalizing x we have to append an additional axis with
3680-
# [..., None] to broadcast the scalar against the vector x
3681-
x /= np.sqrt((norm2 - x0**2 + x[..., 0]**2) / 2.)[..., None]
3682-
3683-
# Householder transformation, without vectorization the RHS can be
3684-
# written as outer(H @ x, x) (apart from the slicing)
3685-
H[..., :, n:] -= np.matmul(H[..., :, n:], xcol) * xrow
3686-
3687-
D[..., -1] = (-1)**(dim-1)*D[..., :-1].prod(axis=-1)
3688-
3689-
# Without vectorization this could be written as H = diag(D) @ H,
3690-
# left-multiplication by a diagonal matrix amounts to multiplying each
3691-
# row of H by an element of the diagonal, so we add a dummy axis for
3692-
# the column index
3693-
H *= D[..., :, None]
3694-
return H
3695-
3647+
q = ortho_group.rvs(dim, size, random_state)
3648+
dets = np.linalg.det(q)
3649+
if dim:
3650+
q[..., 0, :] /= dets[..., np.newaxis]
3651+
return q
36963652

36973653
special_ortho_group = special_ortho_group_gen()
36983654

@@ -3807,9 +3763,9 @@ def __call__(self, dim=None, seed=None):
38073763

38083764
def _process_parameters(self, dim):
38093765
"""Dimension N must be specified; it cannot be inferred."""
3810-
if dim is None or not np.isscalar(dim) or dim <= 1 or dim != int(dim):
3766+
if dim is None or not np.isscalar(dim) or dim < 0 or dim != int(dim):
38113767
raise ValueError("Dimension of rotation must be specified,"
3812-
"and must be a scalar greater than 1.")
3768+
"and must be a scalar nonnegative integer.")
38133769

38143770
return dim
38153771

@@ -4162,7 +4118,7 @@ class unitary_group_gen(multi_rv_generic):
41624118
Parameters
41634119
----------
41644120
dim : scalar
4165-
Dimension of matrices, must be greater than 1.
4121+
Dimension of matrices.
41664122
seed : {None, int, np.random.RandomState, np.random.Generator}, optional
41674123
Used for drawing random variates.
41684124
If `seed` is `None`, the `~np.random.RandomState` singleton is used.
@@ -4219,9 +4175,9 @@ def __call__(self, dim=None, seed=None):
42194175

42204176
def _process_parameters(self, dim):
42214177
"""Dimension N must be specified; it cannot be inferred."""
4222-
if dim is None or not np.isscalar(dim) or dim <= 1 or dim != int(dim):
4178+
if dim is None or not np.isscalar(dim) or dim < 0 or dim != int(dim):
42234179
raise ValueError("Dimension of rotation must be specified,"
4224-
"and must be a scalar greater than 1.")
4180+
"and must be a scalar nonnegative integer.")
42254181

42264182
return dim
42274183

scipy/stats/tests/test_multivariate.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,21 +1930,16 @@ def test_logpdf_4x4(self):
19301930

19311931
class TestSpecialOrthoGroup:
19321932
def test_reproducibility(self):
1933-
np.random.seed(514)
1934-
x = special_ortho_group.rvs(3)
1935-
expected = np.array([[-0.99394515, -0.04527879, 0.10011432],
1936-
[0.04821555, -0.99846897, 0.02711042],
1937-
[0.09873351, 0.03177334, 0.99460653]])
1938-
assert_array_almost_equal(x, expected)
1939-
1940-
random_state = np.random.RandomState(seed=514)
1941-
x = special_ortho_group.rvs(3, random_state=random_state)
1933+
x = special_ortho_group.rvs(3, random_state=np.random.default_rng(514))
1934+
expected = np.array([[-0.93200988, 0.01533561, -0.36210826],
1935+
[0.35742128, 0.20446501, -0.91128705],
1936+
[0.06006333, -0.97875374, -0.19604469]])
19421937
assert_array_almost_equal(x, expected)
19431938

19441939
def test_invalid_dim(self):
19451940
assert_raises(ValueError, special_ortho_group.rvs, None)
19461941
assert_raises(ValueError, special_ortho_group.rvs, (2, 2))
1947-
assert_raises(ValueError, special_ortho_group.rvs, 1)
1942+
assert_raises(ValueError, special_ortho_group.rvs, -1)
19481943
assert_raises(ValueError, special_ortho_group.rvs, 2.5)
19491944

19501945
def test_frozen_matrix(self):
@@ -1979,8 +1974,9 @@ def test_haar(self):
19791974
dim = 5
19801975
samples = 1000 # Not too many, or the test takes too long
19811976
ks_prob = .05
1982-
np.random.seed(514)
1983-
xs = special_ortho_group.rvs(dim, size=samples)
1977+
xs = special_ortho_group.rvs(
1978+
dim, size=samples, random_state=np.random.default_rng(513)
1979+
)
19841980

19851981
# Dot a few rows (0, 1, 2) with unit vectors (0, 2, 4, 3),
19861982
# effectively picking off entries in the matrices of xs.
@@ -1997,6 +1993,14 @@ def test_haar(self):
19971993
ks_tests = [ks_2samp(proj[p0], proj[p1])[1] for (p0, p1) in pairs]
19981994
assert_array_less([ks_prob]*len(pairs), ks_tests)
19991995

1996+
def test_one_by_one(self):
1997+
# Test that the distribution is a delta function at the identity matrix
1998+
# when dim=1
1999+
assert_allclose(special_ortho_group.rvs(1, size=1000), 1, rtol=1e-13)
2000+
2001+
def test_zero_by_zero(self):
2002+
assert_equal(special_ortho_group.rvs(0, size=4).shape, (4, 0, 0))
2003+
20002004

20012005
class TestOrthoGroup:
20022006
def test_reproducibility(self):
@@ -2015,7 +2019,7 @@ def test_reproducibility(self):
20152019
def test_invalid_dim(self):
20162020
assert_raises(ValueError, ortho_group.rvs, None)
20172021
assert_raises(ValueError, ortho_group.rvs, (2, 2))
2018-
assert_raises(ValueError, ortho_group.rvs, 1)
2022+
assert_raises(ValueError, ortho_group.rvs, -1)
20192023
assert_raises(ValueError, ortho_group.rvs, 2.5)
20202024

20212025
def test_frozen_matrix(self):
@@ -2085,6 +2089,20 @@ def test_haar(self):
20852089
ks_tests = [ks_2samp(proj[p0], proj[p1])[1] for (p0, p1) in pairs]
20862090
assert_array_less([ks_prob]*len(pairs), ks_tests)
20872091

2092+
def test_one_by_one(self):
2093+
# Test that the 1x1 distribution gives ±1 with equal probability.
2094+
dim = 1
2095+
xs = ortho_group.rvs(dim, size=5000, random_state=np.random.default_rng(514))
2096+
assert_allclose(np.abs(xs), 1, rtol=1e-13)
2097+
k = np.sum(xs > 0)
2098+
n = len(xs)
2099+
res = stats.binomtest(k, n)
2100+
low, high = res.proportion_ci(confidence_level=0.95)
2101+
assert low < 0.5 < high
2102+
2103+
def test_zero_by_zero(self):
2104+
assert_equal(special_ortho_group.rvs(0, size=4).shape, (4, 0, 0))
2105+
20882106
@pytest.mark.slow
20892107
def test_pairwise_distances(self):
20902108
# Test that the distribution of pairwise distances is close to correct.
@@ -2290,7 +2308,7 @@ def test_reproducibility(self):
22902308
def test_invalid_dim(self):
22912309
assert_raises(ValueError, unitary_group.rvs, None)
22922310
assert_raises(ValueError, unitary_group.rvs, (2, 2))
2293-
assert_raises(ValueError, unitary_group.rvs, 1)
2311+
assert_raises(ValueError, unitary_group.rvs, -1)
22942312
assert_raises(ValueError, unitary_group.rvs, 2.5)
22952313

22962314
def test_frozen_matrix(self):
@@ -2319,17 +2337,22 @@ def test_haar(self):
23192337
# the complex plane, are uncorrelated.
23202338

23212339
# Generate samples
2322-
dim = 5
2323-
samples = 1000 # Not too many, or the test takes too long
2324-
np.random.seed(514) # Note that the test is sensitive to seed too
2325-
xs = unitary_group.rvs(dim, size=samples)
2326-
2327-
# The angles "x" of the eigenvalues should be uniformly distributed
2328-
# Overall this seems to be a necessary but weak test of the distribution.
2329-
eigs = np.vstack([scipy.linalg.eigvals(x) for x in xs])
2330-
x = np.arctan2(eigs.imag, eigs.real)
2331-
res = kstest(x.ravel(), uniform(-np.pi, 2*np.pi).cdf)
2332-
assert_(res.pvalue > 0.05)
2340+
for dim in (1, 5):
2341+
samples = 1000 # Not too many, or the test takes too long
2342+
# Note that the test is sensitive to seed too
2343+
xs = unitary_group.rvs(
2344+
dim, size=samples, random_state=np.random.default_rng(514)
2345+
)
2346+
2347+
# The angles "x" of the eigenvalues should be uniformly distributed
2348+
# Overall this seems to be a necessary but weak test of the distribution.
2349+
eigs = np.vstack([scipy.linalg.eigvals(x) for x in xs])
2350+
x = np.arctan2(eigs.imag, eigs.real)
2351+
res = kstest(x.ravel(), uniform(-np.pi, 2*np.pi).cdf)
2352+
assert_(res.pvalue > 0.05)
2353+
2354+
def test_zero_by_zero(self):
2355+
assert_equal(unitary_group.rvs(0, size=4).shape, (4, 0, 0))
23332356

23342357

23352358
class TestMultivariateT:

0 commit comments

Comments
 (0)