Skip to content

Commit e022538

Browse files
committed
TST: Add testing for complex normal
Add testing for complex normal for alternative paths
1 parent 8958b20 commit e022538

File tree

3 files changed

+83
-16
lines changed

3 files changed

+83
-16
lines changed

randomstate/randomstate.pyx

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,11 +1795,12 @@ cdef class RandomState:
17951795
17961796
>>> s = np.random.complex_normal(size=1000)
17971797
"""
1798+
if method != u'zig' or method != u'bm':
1799+
raise ValueError("method must be either 'bm' or 'zig'")
17981800
cdef np.ndarray ogamma, orelation, oloc, randoms, v_real, v_imag, rho
17991801
cdef double *randoms_data
18001802
cdef double fgamma_r, fgamma_i, frelation_r, frelation_i, frho, f_v_real , f_v_imag, \
18011803
floc_r, floc_i, f_real, f_imag, i_r_scale, r_scale, i_scale, f_rho
1802-
cdef complex cloc
18031804
cdef np.npy_intp i, j, n
18041805
cdef np.broadcast it
18051806

@@ -1830,8 +1831,13 @@ cdef class RandomState:
18301831
raise ValueError('Im(relation) ** 2 > Re(gamma ** 2 - relation** 2)')
18311832

18321833
if size is None:
1833-
random_gauss_zig_double_fill(&self.rng_state, 1, &f_real)
1834-
random_gauss_zig_double_fill(&self.rng_state, 1, &f_imag)
1834+
if method == u'zig':
1835+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_real)
1836+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_imag)
1837+
else:
1838+
random_gauss_fill(&self.rng_state, 1, &f_real)
1839+
random_gauss_fill(&self.rng_state, 1, &f_imag)
1840+
18351841
f_imag *= sqrt(1 - f_rho * f_rho)
18361842
f_imag += f_rho * f_real
18371843
f_real *= sqrt(0.5 * f_v_real)
@@ -1847,13 +1853,22 @@ cdef class RandomState:
18471853
r_scale = sqrt(0.5 * f_v_real)
18481854
i_scale = sqrt(0.5 * f_v_imag)
18491855
j = 0
1850-
with self.lock: # , nogil:
1851-
for i in range(n):
1852-
random_gauss_zig_double_fill(&self.rng_state, 1, &f_real)
1853-
random_gauss_zig_double_fill(&self.rng_state, 1, &f_imag)
1854-
randoms_data[j+1] = floc_i + i_scale * (f_rho * f_real + i_r_scale * f_imag)
1855-
randoms_data[j] = floc_r + r_scale * f_real
1856-
j += 2
1856+
with self.lock, nogil:
1857+
if method == u'zig':
1858+
for i in range(n):
1859+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_real)
1860+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_imag)
1861+
randoms_data[j+1] = floc_i + i_scale * (f_rho * f_real + i_r_scale * f_imag)
1862+
randoms_data[j] = floc_r + r_scale * f_real
1863+
j += 2
1864+
else:
1865+
for i in range(n):
1866+
random_gauss_fill(&self.rng_state, 1, &f_real)
1867+
random_gauss_fill(&self.rng_state, 1, &f_imag)
1868+
randoms_data[j+1] = floc_i + i_scale * (f_rho * f_real + i_r_scale * f_imag)
1869+
randoms_data[j] = floc_r + r_scale * f_real
1870+
j += 2
1871+
18571872

18581873
return randoms
18591874

@@ -1885,9 +1900,11 @@ cdef class RandomState:
18851900
n = np.PyArray_SIZE(randoms)
18861901

18871902
it = np.PyArray_MultiIterNew5(randoms, oloc, v_real, v_imag, rho)
1888-
# TODO: Box-Muller
18891903
with self.lock, nogil:
1890-
random_gauss_zig_double_fill(&self.rng_state, 2 * n, randoms_data)
1904+
if method == u'zig':
1905+
random_gauss_zig_double_fill(&self.rng_state, 2 * n, randoms_data)
1906+
else:
1907+
random_gauss_fill(&self.rng_state, 2 * n, randoms_data)
18911908
with nogil:
18921909
j = 0
18931910
for i in range(n):

randomstate/tests/test_against_numpy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,9 @@ def compare_2_input(f1, f2, is_np=False, is_scalar=False):
4545
inputs = inputs[:3]
4646

4747
for i in inputs:
48-
print(i[0], i[1])
4948
v1 = f1(*i[0], **i[1])
5049
v2 = f2(*i[0], **i[1])
5150
assert_allclose(v1, v2)
52-
print('OK!' * 20)
5351

5452

5553
def compare_3_input(f1, f2, is_np=False):

randomstate/tests/test_smoke.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,60 @@ def test_chisquare(self):
287287
params_1(self.rs.chisquare)
288288

289289
def test_complex_normal(self):
290-
vals = self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, size=10)
291-
assert_(len(vals) == 10)
290+
st = self.rs.get_state()
291+
vals = self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, size=10, method='zig')
292+
assert_(len(vals) == 10)
293+
294+
self.rs.set_state(st)
295+
vals2 = [self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, method='zig') for _ in range(10)]
296+
np.testing.assert_allclose(vals, vals2)
297+
298+
self.rs.set_state(st)
299+
vals3 = self.rs.complex_normal(2.0 + 7.0j * np.ones(10), 10.0 * np.ones(1), 5.0 - 5.0j, method='zig')
300+
np.testing.assert_allclose(vals, vals3)
301+
302+
self.rs.set_state(st)
303+
norms = self.rs.standard_normal(size=20, method='zig')
304+
norms = np.reshape(norms, (10, 2))
305+
cov = 0.5 * (-5.0)
306+
v_real = 7.5
307+
v_imag = 2.5
308+
rho = cov / np.sqrt(v_real * v_imag)
309+
imag = 7 + np.sqrt(v_imag) * (rho * norms[:, 0] + np.sqrt(1 - rho ** 2) * norms[:, 1])
310+
real = 2 + np.sqrt(v_real) * norms[:, 0]
311+
vals4 = [re + im * (0 + 1.0j) for re, im in zip(real, imag)]
312+
313+
np.testing.assert_allclose(vals4, vals)
314+
315+
def test_complex_normal_bm(self):
316+
st = self.rs.get_state()
317+
vals = self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, size=10, method='bm')
318+
assert_(len(vals) == 10)
319+
320+
self.rs.set_state(st)
321+
vals2 = [self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, method='bm') for _ in range(10)]
322+
np.testing.assert_allclose(vals, vals2)
323+
324+
self.rs.set_state(st)
325+
vals3 = self.rs.complex_normal(2.0 + 7.0j * np.ones(10), 10.0 * np.ones(1), 5.0 - 5.0j, method='bm')
326+
np.testing.assert_allclose(vals, vals3)
327+
328+
def test_complex_normal_zero_variance(self):
329+
st = self.rs.get_state()
330+
c = self.rs.complex_normal(0, 1.0, 1.0)
331+
assert_almost_equal(c.imag, 0.0)
332+
self.rs.set_state(st)
333+
n = self.rs.standard_normal()
334+
np.testing.assert_allclose(c, n, atol=1e-8)
335+
336+
st = self.rs.get_state()
337+
c = self.rs.complex_normal(0, 1.0, -1.0)
338+
assert_almost_equal(c.real, 0.0)
339+
self.rs.set_state(st)
340+
self.rs.standard_normal()
341+
n = self.rs.standard_normal()
342+
assert_almost_equal(c.real, 0.0)
343+
np.testing.assert_allclose(c.imag, n, atol=1e-8)
292344

293345
def test_exponential(self):
294346
vals = self.rs.exponential(2.0, 10)

0 commit comments

Comments
 (0)