Skip to content

Commit bb03810

Browse files
committed
ENH: Add support fo ziggurat method for all normal
Add support for both zigurat and inv method for all normal generators
1 parent 601c87b commit bb03810

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

randomstate/interface.pyx

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ cdef class RandomState:
11121112

11131113
def randn(self, *args, method=__normal_method):
11141114
"""
1115-
randn(d0, d1, ..., dn)
1115+
randn(d0, d1, ..., dn, method='inv')
11161116
11171117
Return a sample (or samples) from the "standard normal" distribution.
11181118
@@ -1132,6 +1132,9 @@ cdef class RandomState:
11321132
d0, d1, ..., dn : int, optional
11331133
The dimensions of the returned array, should be all positive.
11341134
If no argument is given a single Python float is returned.
1135+
method : str, optional
1136+
Either 'inv' or 'zig'. 'inv' uses the default FIXME method. 'zig' uses
1137+
the much faster ziggurat method of FIXME.
11351138
11361139
Returns
11371140
-------
@@ -3674,9 +3677,9 @@ cdef class RandomState:
36743677
0.0, '', CONS_NONE)
36753678

36763679
# Multivariate distributions:
3677-
def multivariate_normal(self, mean, cov, size=None):
3680+
def multivariate_normal(self, mean, cov, size=None, method=__normal_method):
36783681
"""
3679-
multivariate_normal(mean, cov[, size])
3682+
multivariate_normal(mean, cov, size=None, method='inv')
36803683
36813684
Draw random samples from a multivariate normal distribution.
36823685
@@ -3699,6 +3702,9 @@ cdef class RandomState:
36993702
generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because
37003703
each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``.
37013704
If no shape is specified, a single (`N`-D) sample is returned.
3705+
method : str, optional
3706+
Either 'inv' or 'zig'. 'inv' uses the default FIXME method. 'zig' uses
3707+
the much faster ziggurat method of FIXME.
37023708
37033709
Returns
37043710
-------
@@ -3795,7 +3801,7 @@ cdef class RandomState:
37953801
# form a matrix of shape final_shape.
37963802
final_shape = tuple(shape[:])
37973803
final_shape += (mean.shape[0],)
3798-
x = self.standard_normal(final_shape).reshape(-1, mean.shape[0])
3804+
x = self.standard_normal(final_shape, method=method).reshape(-1, mean.shape[0])
37993805

38003806
# Transform matrix of standard normals into matrix where each row
38013807
# contains multivariate normals with the desired covariance.

randomstate/tests/test_smoke.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,22 @@ def test_rand(self):
304304

305305
def test_randn(self):
306306
state = self.rs.get_state()
307-
print(state)
308307
vals = self.rs.randn(10, 10, 10)
309308
self.rs.set_state(state)
310-
print(self.rs.get_state())
311309
assert_equal(vals, self.rs.standard_normal((10, 10, 10)))
312310
assert_equal(vals.shape, (10, 10, 10))
313311

312+
state = self.rs.get_state()
313+
vals = self.rs.randn(10, 10, 10, method='inv')
314+
self.rs.set_state(state)
315+
assert_equal(vals, self.rs.standard_normal((10, 10, 10), method='inv'))
316+
317+
state = self.rs.get_state()
318+
vals_inv = self.rs.randn(10, 10, 10, method='inv')
319+
self.rs.set_state(state)
320+
vals_zig = self.rs.randn(10, 10, 10, method='zig')
321+
assert (vals_zig != vals_inv).any()
322+
314323
def test_noncentral_chisquare(self):
315324
vals = self.rs.noncentral_chisquare(10, 2, 10)
316325
assert len(vals) == 10
@@ -400,6 +409,11 @@ def test_multivariate_normal(self):
400409
cov = [[1, 0], [0, 100]] # diagonal covariance
401410
x = self.rs.multivariate_normal(mean, cov, 5000)
402411
assert x.shape == (5000, 2)
412+
x_zig = self.rs.multivariate_normal(mean, cov, 5000, method='zig')
413+
assert x.shape == (5000, 2)
414+
x_inv = self.rs.multivariate_normal(mean, cov, 5000, method='inv')
415+
assert x.shape == (5000, 2)
416+
assert (x_zig != x_inv).any()
403417

404418
def test_multinomial(self):
405419
vals = self.rs.multinomial(100, [1.0 / 3, 2.0 / 3])

0 commit comments

Comments
 (0)