Skip to content

Commit 56df03d

Browse files
committed
Merge pull request #32 from bashtage/mrg32k3a-perf
Mrg32k3a perf
2 parents 8a8eeea + 10950b4 commit 56df03d

File tree

5 files changed: +79 / -62 lines changed

randomstate/interface.pyx

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -499,7 +499,7 @@ cdef class RandomState:
499499

500500
if state['name'] != rng_name:
501501
raise ValueError('Not a ' + rng_name + ' RNG state')
502-
print(state['state'])
502+
503503
_set_state(&self.rng_state, state['state'])
504504
self.rng_state.has_gauss = state['gauss']['has_gauss']
505505
self.rng_state.gauss = state['gauss']['gauss']

randomstate/shims/mrg32k3a/mrg32k3a-shim.c

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,11 +32,12 @@ void set_seed(aug_state* state, uint64_t val)
3232

3333
void init_state(aug_state* state, int64_t vals[6])
3434
{
35-
state->rng->s10 = vals[0];
36-
state->rng->s11 = vals[1];
37-
state->rng->s12 = vals[2];
38-
state->rng->s20 = vals[3];
39-
state->rng->s21 = vals[4];
40-
state->rng->s22 = vals[5];
35+
state->rng->s1[0] = vals[0];
36+
state->rng->s1[1] = vals[1];
37+
state->rng->s1[2] = vals[2];
38+
state->rng->s2[0] = vals[3];
39+
state->rng->s2[1] = vals[4];
40+
state->rng->s2[2] = vals[5];
41+
state->rng->loc = 2;
4142
}
4243

randomstate/shims/mrg32k3a/mrg32k3a.pxi

Lines changed: 39 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,9 @@ DEF RS_RNG_JUMPABLE = 1
44
cdef extern from "distributions.h":
55

66
cdef struct s_mrg32k3a_state:
7-
int64_t s10
8-
int64_t s11
9-
int64_t s12
10-
int64_t s20
11-
int64_t s21
12-
int64_t s22
7+
int64_t s1[3]
8+
int64_t s2[3]
9+
int loc
1310

1411
ctypedef s_mrg32k3a_state mrg32k3a_state
1512

@@ -31,16 +28,18 @@ ctypedef mrg32k3a_state rng_t
3128
ctypedef uint64_t rng_state_t
3229

3330
cdef object _get_state(aug_state state):
34-
return (state.rng.s10, state.rng.s11, state.rng.s12,
35-
state.rng.s20, state.rng.s21, state.rng.s22)
31+
return (state.rng.s1[0], state.rng.s1[1], state.rng.s1[2],
32+
state.rng.s2[0], state.rng.s2[1], state.rng.s2[2],
33+
state.rng.loc)
3634

3735
cdef object _set_state(aug_state *state, object state_info):
38-
state.rng.s10 = state_info[0]
39-
state.rng.s11 = state_info[1]
40-
state.rng.s12 = state_info[2]
41-
state.rng.s20 = state_info[3]
42-
state.rng.s21 = state_info[4]
43-
state.rng.s22 = state_info[5]
36+
state.rng.s1[0] = state_info[0]
37+
state.rng.s1[1] = state_info[1]
38+
state.rng.s1[2] = state_info[2]
39+
state.rng.s2[0] = state_info[3]
40+
state.rng.s2[1] = state_info[4]
41+
state.rng.s2[2] = state_info[5]
42+
state.rng.loc = state_info[6]
4443

4544
cdef object matrix_power_127(x, m):
4645
n = x.shape[0]
@@ -68,21 +67,39 @@ A2_127 = matrix_power_127(A2p, m2)
6867

6968
cdef void jump_state(aug_state* state):
7069
# vectors s1 and s2
71-
s1 = np.array([state.rng.s10,state.rng.s11,state.rng.s12], dtype=np.uint64)
72-
s2 = np.array([state.rng.s20,state.rng.s21,state.rng.s22], dtype=np.uint64)
70+
loc = state.rng.loc
71+
72+
if loc == 0:
73+
loc_m1 = 2
74+
loc_m2 = 1
75+
elif loc == 1:
76+
loc_m1 = 0
77+
loc_m2 = 2
78+
else:
79+
loc_m1 = 1
80+
loc_m2 = 0
81+
82+
s1 = np.array([state.rng.s1[loc_m2],
83+
state.rng.s1[loc_m1],
84+
state.rng.s1[loc]], dtype=np.uint64)
85+
s2 = np.array([state.rng.s2[loc_m2],
86+
state.rng.s2[loc_m1],
87+
state.rng.s2[loc]], dtype=np.uint64)
7388

7489
# Advance the state
7590
s1 = np.mod(A1_127.dot(s1), m1)
7691
s2 = np.mod(A1_127.dot(s2), m2)
7792

7893
# Restore state
79-
state.rng.s10 = s1[0]
80-
state.rng.s11 = s1[1]
81-
state.rng.s12 = s1[2]
94+
state.rng.s1[0] = s1[0]
95+
state.rng.s1[1] = s1[1]
96+
state.rng.s1[2] = s1[2]
8297

83-
state.rng.s20 = s2[0]
84-
state.rng.s21 = s2[1]
85-
state.rng.s22 = s2[2]
98+
state.rng.s2[0] = s2[0]
99+
state.rng.s2[1] = s2[1]
100+
state.rng.s2[2] = s2[2]
101+
102+
state.rng.loc = 2
86103

87104
DEF CLASS_DOCSTRING = """
88105
RandomState(seed=None)

randomstate/src/mrg32k3a/mrg32k3a.c

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -38,10 +38,11 @@ void mrg32k3a_seed(mrg32k3a_state* state, uint64_t seed)
3838
seeds[i] = draw;
3939
}
4040

41-
state->s10 = seeds[0];
42-
state->s11 = seeds[1];
43-
state->s12 = seeds[2];
44-
state->s20 = seeds[3];
45-
state->s21 = seeds[4];
46-
state->s22 = seeds[5];
41+
state->s1[0] = seeds[0];
42+
state->s1[1] = seeds[1];
43+
state->s1[2] = seeds[2];
44+
state->s2[0] = seeds[3];
45+
state->s2[1] = seeds[4];
46+
state->s2[2] = seeds[5];
47+
state->loc = 2;
4748
}

randomstate/src/mrg32k3a/mrg32k3a.h

Lines changed: 25 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -14,43 +14,41 @@
1414

1515
typedef struct s_mrg32k3a_state
1616
{
17-
int64_t s10;
18-
int64_t s11;
19-
int64_t s12;
20-
int64_t s20;
21-
int64_t s21;
22-
int64_t s22;
17+
int64_t s1[3];
18+
int64_t s2[3];
19+
int loc;
2320
} mrg32k3a_state;
2421

2522
inline uint32_t mrg32k3a_random(mrg32k3a_state* state)
2623
{
27-
int64_t k;
2824
int64_t p1, p2;
2925
/* Component 1 */
30-
p1 = a12 * state->s11 - a13n * state->s10;
31-
k = p1 / m1;
32-
p1 -= k * m1;
33-
if (p1 < 0)
34-
p1 += m1;
35-
state->s10 = state->s11;
36-
state->s11 = state->s12;
37-
state->s12 = p1;
26+
switch (state->loc) {
27+
case 0:
28+
p1 = a12 * state->s1[2] - a13n * state->s1[1];
29+
p2 = a21 * state->s2[0] - a23n * state->s2[1];
30+
state->loc = 1;
31+
break;
32+
case 1:
33+
p1 = a12 * state->s1[0] - a13n * state->s1[2];
34+
p2 = a21 * state->s2[1] - a23n * state->s2[2];
35+
state->loc = 2;
36+
break;
37+
case 2:
38+
p1 = a12 * state->s1[1] - a13n * state->s1[0];
39+
p2 = a21 * state->s2[2] - a23n * state->s2[0];
40+
state->loc = 0;
41+
break;
42+
}
3843

44+
p1 -= (p1 >= 0) ? (p1 / m1) * m1 : (p1 / m1) * m1 - m1;
45+
state->s1[state->loc] = p1;
3946
/* Component 2 */
40-
p2 = a21 * state->s22 - a23n * state->s20;
41-
k = p2 / m2;
42-
p2 -= k * m2;
43-
if (p2 < 0)
44-
p2 += m2;
45-
state->s20 = state->s21;
46-
state->s21 = state->s22;
47-
state->s22 = p2;
47+
p2 -= (p2 >= 0) ? (p2 / m2) * m2 : (p2 / m2) * m2 - m2;
48+
state->s2[state->loc] = p2;
4849

4950
/* Combination */
50-
if (p1 <= p2)
51-
return (uint32_t)(p1 - p2 + m1);
52-
else
53-
return (uint32_t)(p1 - p2);
51+
return (uint32_t)((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2));
5452
}
5553

5654
void mrg32k3a_seed(mrg32k3a_state* state, uint64_t seed);

0 commit comments

Comments
 (0)