Skip to content

Commit 8958b20

Browse files
committed
ENH: Complete complex normal refactor
Refactor broadcasted generation
1 parent b4df905 commit 8958b20

File tree

1 file changed

+44
-20
lines changed

1 file changed

+44
-20
lines changed

randomstate/randomstate.pyx

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ cdef double kahan_sum(double *darr, np.npy_intp n):
164164
sum = t
165165
return sum
166166

167+
cdef inline void compute_complex(double *rv_r, double *rv_i, double loc_r,
168+
double loc_i, double var_r, double var_i, double rho) nogil:
169+
cdef double scale_c, scale_i, scale_r
170+
171+
scale_c = sqrt(1 - rho * rho)
172+
scale_r = sqrt(var_r)
173+
scale_i = sqrt(var_i)
174+
175+
rv_i[0] = loc_i + scale_i * (rho * rv_r[0] + scale_c * rv_i[0])
176+
rv_r[0] = loc_r + scale_r * rv_r[0]
177+
178+
179+
167180
cdef object _ensure_string(object s):
168181
try:
169182
return ''.join(map(chr, s))
@@ -1782,11 +1795,13 @@ cdef class RandomState:
17821795
17831796
>>> s = np.random.complex_normal(size=1000)
17841797
"""
1785-
cdef np.ndarray ogamma, orelation, oloc, randoms
1798+
cdef np.ndarray ogamma, orelation, oloc, randoms, v_real, v_imag, rho
17861799
cdef double *randoms_data
17871800
cdef double fgamma_r, fgamma_i, frelation_r, frelation_i, frho, f_v_real , f_v_imag, \
1788-
floc_r, floc_i, f_real, f_imag, i_r_scale, r_scale, i_scale
1789-
cdef np.npy_intp i, j, n,
1801+
floc_r, floc_i, f_real, f_imag, i_r_scale, r_scale, i_scale, f_rho
1802+
cdef complex cloc
1803+
cdef np.npy_intp i, j, n
1804+
cdef np.broadcast it
17901805

17911806
oloc = <np.ndarray>np.PyArray_FROM_OTF(loc, np.NPY_COMPLEX128, np.NPY_ALIGNED)
17921807
ogamma = <np.ndarray>np.PyArray_FROM_OTF(gamma, np.NPY_COMPLEX128, np.NPY_ALIGNED)
@@ -1832,7 +1847,7 @@ cdef class RandomState:
18321847
r_scale = sqrt(0.5 * f_v_real)
18331848
i_scale = sqrt(0.5 * f_v_imag)
18341849
j = 0
1835-
with self.lock, nogil:
1850+
with self.lock: # , nogil:
18361851
for i in range(n):
18371852
random_gauss_zig_double_fill(&self.rng_state, 1, &f_real)
18381853
random_gauss_zig_double_fill(&self.rng_state, 1, &f_imag)
@@ -1844,10 +1859,10 @@ cdef class RandomState:
18441859

18451860
gpc = ogamma + orelation
18461861
gmc = ogamma - orelation
1847-
v_real = 0.5 * np.real(gpc)
1862+
v_real = <np.ndarray>(0.5 * np.real(gpc))
18481863
if np.any(np.less(v_real, 0)):
18491864
raise ValueError('Re(gamma + relation) < 0')
1850-
v_imag = 0.5 * np.real(gmc)
1865+
v_imag = <np.ndarray>(0.5 * np.real(gmc))
18511866
if np.any(np.less(v_imag, 0)):
18521867
raise ValueError('Re(gamma - relation) < 0')
18531868
if np.any(np.not_equal(np.imag(ogamma), 0)):
@@ -1860,23 +1875,32 @@ cdef class RandomState:
18601875
if np.any(cov.flat[~idx] != 0) or np.any(np.abs(rho) > 1):
18611876
raise ValueError('Im(relation) ** 2 > Re(gamma ** 2 - relation ** 2)')
18621877

1863-
if size is None:
1864-
size = np.broadcast(loc, gpc).shape
1865-
elif np.PyArray_IsAnyScalar(size):
1866-
size = (size,)
1878+
if size is not None:
1879+
randoms = <np.ndarray>np.empty(size, np.complex128)
18671880
else:
1868-
size = tuple(size)
1881+
it = np.PyArray_MultiIterNew4(oloc, v_real, v_imag, rho)
1882+
randoms = <np.ndarray>np.empty(it.shape, np.complex128)
18691883

1870-
norms = self.standard_normal(size + (2,), method=method)
1871-
real = norms[...,0]
1872-
imag = norms[...,1]
1884+
randoms_data = <double *>np.PyArray_DATA(randoms)
1885+
n = np.PyArray_SIZE(randoms)
18731886

1874-
imag *= np.sqrt(1-rho ** 2)
1875-
imag += rho * real
1876-
real *= np.sqrt(v_real)
1877-
imag *= np.sqrt(v_imag)
1878-
1879-
return oloc + real + (0+1.0j) * imag
1887+
it = np.PyArray_MultiIterNew5(randoms, oloc, v_real, v_imag, rho)
1888+
# TODO: Box-Muller
1889+
with self.lock, nogil:
1890+
random_gauss_zig_double_fill(&self.rng_state, 2 * n, randoms_data)
1891+
with nogil:
1892+
j = 0
1893+
for i in range(n):
1894+
floc_r= (<double*>np.PyArray_MultiIter_DATA(it, 1))[0]
1895+
floc_i= (<double*>np.PyArray_MultiIter_DATA(it, 1))[1]
1896+
f_v_real = (<double*>np.PyArray_MultiIter_DATA(it, 2))[0]
1897+
f_v_imag = (<double*>np.PyArray_MultiIter_DATA(it, 3))[0]
1898+
f_rho = (<double*>np.PyArray_MultiIter_DATA(it, 4))[0]
1899+
compute_complex(&randoms_data[j], &randoms_data[j+1], floc_r, floc_i, f_v_real, f_v_imag, f_rho)
1900+
j += 2
1901+
np.PyArray_MultiIter_NEXT(it)
1902+
1903+
return randoms
18801904

18811905
def beta(self, a, b, size=None):
18821906
"""

0 commit comments

Comments
 (0)