Skip to content

Commit 768ab53

Browse files
authored
Merge pull request #100 from bashtage/ref-complex-normal
Ref complex normal
2 parents 63ae94e + e022538 commit 768ab53

File tree

4 files changed

+133
-38
lines changed

4 files changed

+133
-38
lines changed

randomstate/randomstate.pyx

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ cdef double kahan_sum(double *darr, np.npy_intp n):
164164
sum = t
165165
return sum
166166

167+
cdef inline void compute_complex(double *rv_r, double *rv_i, double loc_r,
168+
double loc_i, double var_r, double var_i, double rho) nogil:
169+
cdef double scale_c, scale_i, scale_r
170+
171+
scale_c = sqrt(1 - rho * rho)
172+
scale_r = sqrt(var_r)
173+
scale_i = sqrt(var_i)
174+
175+
rv_i[0] = loc_i + scale_i * (rho * rv_r[0] + scale_c * rv_i[0])
176+
rv_r[0] = loc_r + scale_r * rv_r[0]
177+
178+
179+
167180
cdef object _ensure_string(object s):
168181
try:
169182
return ''.join(map(chr, s))
@@ -1782,9 +1795,14 @@ cdef class RandomState:
17821795
17831796
>>> s = np.random.complex_normal(size=1000)
17841797
"""
1785-
cdef np.ndarray ogamma, orelation, oloc
1798+
if method != u'zig' or method != u'bm':
1799+
raise ValueError("method must be either 'bm' or 'zig'")
1800+
cdef np.ndarray ogamma, orelation, oloc, randoms, v_real, v_imag, rho
1801+
cdef double *randoms_data
17861802
cdef double fgamma_r, fgamma_i, frelation_r, frelation_i, frho, f_v_real , f_v_imag, \
1787-
floc_r, floc_i, f_real, f_imag
1803+
floc_r, floc_i, f_real, f_imag, i_r_scale, r_scale, i_scale, f_rho
1804+
cdef np.npy_intp i, j, n
1805+
cdef np.broadcast it
17881806

17891807
oloc = <np.ndarray>np.PyArray_FROM_OTF(loc, np.NPY_COMPLEX128, np.NPY_ALIGNED)
17901808
ogamma = <np.ndarray>np.PyArray_FROM_OTF(gamma, np.NPY_COMPLEX128, np.NPY_ALIGNED)
@@ -1813,37 +1831,53 @@ cdef class RandomState:
18131831
raise ValueError('Im(relation) ** 2 > Re(gamma ** 2 - relation** 2)')
18141832

18151833
if size is None:
1816-
f_real, f_imag = self.standard_normal(size=2, method=method)
1817-
1834+
if method == u'zig':
1835+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_real)
1836+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_imag)
1837+
else:
1838+
random_gauss_fill(&self.rng_state, 1, &f_real)
1839+
random_gauss_fill(&self.rng_state, 1, &f_imag)
1840+
18181841
f_imag *= sqrt(1 - f_rho * f_rho)
18191842
f_imag += f_rho * f_real
18201843
f_real *= sqrt(0.5 * f_v_real)
18211844
f_imag *= sqrt(0.5 * f_v_imag)
18221845

1823-
return PyComplex_FromDoubles(f_real, f_imag)
1846+
return PyComplex_FromDoubles(floc_r + f_real, floc_i + f_imag)
18241847

1825-
if np.PyArray_IsAnyScalar(size):
1826-
size = (size,)
1827-
else:
1828-
size = tuple(size)
1848+
randoms = <np.ndarray>np.empty(size, np.complex128)
1849+
randoms_data = <double *>np.PyArray_DATA(randoms)
1850+
n = np.PyArray_SIZE(randoms)
18291851

1830-
norms = self.standard_normal(size=size + (2,), method=method)
1831-
real = norms[...,0]
1832-
imag = norms[...,1]
1852+
i_r_scale = sqrt(1 - f_rho * f_rho)
1853+
r_scale = sqrt(0.5 * f_v_real)
1854+
i_scale = sqrt(0.5 * f_v_imag)
1855+
j = 0
1856+
with self.lock, nogil:
1857+
if method == u'zig':
1858+
for i in range(n):
1859+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_real)
1860+
random_gauss_zig_double_fill(&self.rng_state, 1, &f_imag)
1861+
randoms_data[j+1] = floc_i + i_scale * (f_rho * f_real + i_r_scale * f_imag)
1862+
randoms_data[j] = floc_r + r_scale * f_real
1863+
j += 2
1864+
else:
1865+
for i in range(n):
1866+
random_gauss_fill(&self.rng_state, 1, &f_real)
1867+
random_gauss_fill(&self.rng_state, 1, &f_imag)
1868+
randoms_data[j+1] = floc_i + i_scale * (f_rho * f_real + i_r_scale * f_imag)
1869+
randoms_data[j] = floc_r + r_scale * f_real
1870+
j += 2
18331871

1834-
imag *= sqrt(1 - f_rho * f_rho)
1835-
imag += f_rho * real
1836-
real *= sqrt(0.5 * f_v_real)
1837-
imag *= sqrt(0.5 * f_v_imag)
18381872

1839-
return floc_r + real + (floc_i + imag) * (0+1.0j)
1873+
return randoms
18401874

18411875
gpc = ogamma + orelation
18421876
gmc = ogamma - orelation
1843-
v_real = 0.5 * np.real(gpc)
1877+
v_real = <np.ndarray>(0.5 * np.real(gpc))
18441878
if np.any(np.less(v_real, 0)):
18451879
raise ValueError('Re(gamma + relation) < 0')
1846-
v_imag = 0.5 * np.real(gmc)
1880+
v_imag = <np.ndarray>(0.5 * np.real(gmc))
18471881
if np.any(np.less(v_imag, 0)):
18481882
raise ValueError('Re(gamma - relation) < 0')
18491883
if np.any(np.not_equal(np.imag(ogamma), 0)):
@@ -1856,23 +1890,34 @@ cdef class RandomState:
18561890
if np.any(cov.flat[~idx] != 0) or np.any(np.abs(rho) > 1):
18571891
raise ValueError('Im(relation) ** 2 > Re(gamma ** 2 - relation ** 2)')
18581892

1859-
if size is None:
1860-
size = np.broadcast(loc, gpc).shape
1861-
elif np.PyArray_IsAnyScalar(size):
1862-
size = (size,)
1893+
if size is not None:
1894+
randoms = <np.ndarray>np.empty(size, np.complex128)
18631895
else:
1864-
size = tuple(size)
1896+
it = np.PyArray_MultiIterNew4(oloc, v_real, v_imag, rho)
1897+
randoms = <np.ndarray>np.empty(it.shape, np.complex128)
18651898

1866-
norms = self.standard_normal(size + (2,), method=method)
1867-
real = norms[...,0]
1868-
imag = norms[...,1]
1899+
randoms_data = <double *>np.PyArray_DATA(randoms)
1900+
n = np.PyArray_SIZE(randoms)
18691901

1870-
imag *= np.sqrt(1-rho ** 2)
1871-
imag += rho * real
1872-
real *= np.sqrt(v_real)
1873-
imag *= np.sqrt(v_imag)
1874-
1875-
return oloc + real + (0+1.0j) * imag
1902+
it = np.PyArray_MultiIterNew5(randoms, oloc, v_real, v_imag, rho)
1903+
with self.lock, nogil:
1904+
if method == u'zig':
1905+
random_gauss_zig_double_fill(&self.rng_state, 2 * n, randoms_data)
1906+
else:
1907+
random_gauss_fill(&self.rng_state, 2 * n, randoms_data)
1908+
with nogil:
1909+
j = 0
1910+
for i in range(n):
1911+
floc_r= (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
1912+
floc_i= (<double*>np.PyArray_MultiIter_DATA(it, 1))[1]
1913+
f_v_real = (<double*>np.PyArray_MultiIter_DATA(it, 2))[0]
1914+
f_v_imag = (<double*>np.PyArray_MultiIter_DATA(it, 3))[0]
1915+
f_rho = (<double*>np.PyArray_MultiIter_DATA(it, 4))[0]
1916+
compute_complex(&randoms_data[j], &randoms_data[j+1], floc_r, floc_i, f_v_real, f_v_imag, f_rho)
1917+
j += 2
1918+
np.PyArray_MultiIter_NEXT(it)
1919+
1920+
return randoms
18761921

18771922
def beta(self, a, b, size=None):
18781923
"""

randomstate/tests/test_against_numpy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,9 @@ def compare_2_input(f1, f2, is_np=False, is_scalar=False):
4545
inputs = inputs[:3]
4646

4747
for i in inputs:
48-
print(i[0], i[1])
4948
v1 = f1(*i[0], **i[1])
5049
v2 = f2(*i[0], **i[1])
5150
assert_allclose(v1, v2)
52-
print('OK!' * 20)
5351

5452

5553
def compare_3_input(f1, f2, is_np=False):

randomstate/tests/test_smoke.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,60 @@ def test_chisquare(self):
287287
params_1(self.rs.chisquare)
288288

289289
def test_complex_normal(self):
290-
vals = self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, size=10)
291-
assert_(len(vals) == 10)
290+
st = self.rs.get_state()
291+
vals = self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, size=10, method='zig')
292+
assert_(len(vals) == 10)
293+
294+
self.rs.set_state(st)
295+
vals2 = [self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, method='zig') for _ in range(10)]
296+
np.testing.assert_allclose(vals, vals2)
297+
298+
self.rs.set_state(st)
299+
vals3 = self.rs.complex_normal(2.0 + 7.0j * np.ones(10), 10.0 * np.ones(1), 5.0 - 5.0j, method='zig')
300+
np.testing.assert_allclose(vals, vals3)
301+
302+
self.rs.set_state(st)
303+
norms = self.rs.standard_normal(size=20, method='zig')
304+
norms = np.reshape(norms, (10, 2))
305+
cov = 0.5 * (-5.0)
306+
v_real = 7.5
307+
v_imag = 2.5
308+
rho = cov / np.sqrt(v_real * v_imag)
309+
imag = 7 + np.sqrt(v_imag) * (rho * norms[:, 0] + np.sqrt(1 - rho ** 2) * norms[:, 1])
310+
real = 2 + np.sqrt(v_real) * norms[:, 0]
311+
vals4 = [re + im * (0 + 1.0j) for re, im in zip(real, imag)]
312+
313+
np.testing.assert_allclose(vals4, vals)
314+
315+
def test_complex_normal_bm(self):
316+
st = self.rs.get_state()
317+
vals = self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, size=10, method='bm')
318+
assert_(len(vals) == 10)
319+
320+
self.rs.set_state(st)
321+
vals2 = [self.rs.complex_normal(2.0 + 7.0j, 10.0, 5.0 - 5.0j, method='bm') for _ in range(10)]
322+
np.testing.assert_allclose(vals, vals2)
323+
324+
self.rs.set_state(st)
325+
vals3 = self.rs.complex_normal(2.0 + 7.0j * np.ones(10), 10.0 * np.ones(1), 5.0 - 5.0j, method='bm')
326+
np.testing.assert_allclose(vals, vals3)
327+
328+
def test_complex_normal_zero_variance(self):
329+
st = self.rs.get_state()
330+
c = self.rs.complex_normal(0, 1.0, 1.0)
331+
assert_almost_equal(c.imag, 0.0)
332+
self.rs.set_state(st)
333+
n = self.rs.standard_normal()
334+
np.testing.assert_allclose(c, n, atol=1e-8)
335+
336+
st = self.rs.get_state()
337+
c = self.rs.complex_normal(0, 1.0, -1.0)
338+
assert_almost_equal(c.real, 0.0)
339+
self.rs.set_state(st)
340+
self.rs.standard_normal()
341+
n = self.rs.standard_normal()
342+
assert_almost_equal(c.real, 0.0)
343+
np.testing.assert_allclose(c.imag, n, atol=1e-8)
292344

293345
def test_exponential(self):
294346
vals = self.rs.exponential(2.0, 10)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def cythonize(e, *args, **kwargs):
240240
with open(output_file_name, 'w') as output_file:
241241
output_file.write(template.substitute())
242242

243-
ext_modules = cythonize(extensions, force=not DEVELOP)
243+
ext_modules = cythonize(extensions, force=not DEVELOP, annotate=True)
244244

245245
classifiers = ['Development Status :: 5 - Production/Stable',
246246
'Environment :: Console',

0 commit comments

Comments
 (0)