Skip to content

Commit a6d3791

Browse files
committed
REF: Reorder normal draws in complex normal
Refactor complex normal so that a call with multiple values returned returns the same values as multiple calls with a single rv
1 parent 080a9a0 commit a6d3791

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

randomstate/randomstate.pyx

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,9 +1803,15 @@ cdef class RandomState:
18031803
f_imag *= sqrt(0.5 * f_v_imag)
18041804

18051805
return PyComplex_FromDoubles(f_real, f_imag)
1806-
1807-
real = self.standard_normal(size=size, method=method)
1808-
imag = self.standard_normal(size=size, method=method)
1806+
1807+
if np.PyArray_IsAnyScalar(size):
1808+
size = (size,)
1809+
else:
1810+
size = tuple(size)
1811+
1812+
norms = self.standard_normal(size=size + (2,), method=method)
1813+
real = norms[...,0]
1814+
imag = norms[...,1]
18091815

18101816
imag *= sqrt(1 - f_rho * f_rho)
18111817
imag += f_rho * real
@@ -1834,8 +1840,15 @@ cdef class RandomState:
18341840

18351841
if size is None:
18361842
size = np.broadcast(loc, gpc).shape
1837-
real = self.standard_normal(size, method=method)
1838-
imag = self.standard_normal(size, method=method)
1843+
elif np.PyArray_IsAnyScalar(size):
1844+
size = (size,)
1845+
else:
1846+
size = tuple(size)
1847+
1848+
norms = self.standard_normal(size + (2,), method=method)
1849+
real = norms[...,0]
1850+
imag = norms[...,1]
1851+
18391852
imag *= np.sqrt(1-rho ** 2)
18401853
imag += rho * real
18411854
real *= np.sqrt(v_real)

0 commit comments

Comments
 (0)