Skip to content

Commit d63103e

Browse files
committed
Merge pull request #23 from bashtage/add-aligned-malloc
ENH: Add aligned malloc
2 parents dcffce9 + 48e2e8b commit d63103e

File tree

16 files changed

+139
-44
lines changed

16 files changed

+139
-44
lines changed

randomstate/aligned_malloc.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#include "aligned_malloc.h"
2+
3+
static NPY_INLINE void *PyArray_realloc_aligned(void *p, size_t n);
4+
5+
static NPY_INLINE void *PyArray_malloc_aligned(size_t n);
6+
7+
static NPY_INLINE void *PyArray_calloc_aligned(size_t n, size_t s);
8+
9+
static NPY_INLINE void PyArray_free_aligned(void *p);

randomstate/aligned_malloc.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "Python.h"
2+
#include "numpy/npy_common.h"
3+
4+
#define NPY_MEMALIGN 16 /* 16 for SSE2, 32 for AVX, 64 for Xeon Phi */
5+
6+
static NPY_INLINE
7+
void *PyArray_realloc_aligned(void *p, size_t n)
8+
{
9+
void *p1, **p2, *base;
10+
size_t old_offs, offs = NPY_MEMALIGN - 1 + sizeof(void*);
11+
if (NPY_UNLIKELY(p != NULL)) {
12+
base = *(((void**)p)-1);
13+
if (NPY_UNLIKELY((p1 = PyMem_Realloc(base,n+offs)) == NULL)) return NULL;
14+
if (NPY_LIKELY(p1 == base)) return p;
15+
p2 = (void**)(((Py_uintptr_t)(p1)+offs) & ~(NPY_MEMALIGN-1));
16+
old_offs = (size_t)((Py_uintptr_t)p - (Py_uintptr_t)base);
17+
/* TODO: This isn't right, removed void* to aloow msvc to do pointer math */
18+
memmove((void*)p2,((char*)p1)+old_offs,n);
19+
} else {
20+
if (NPY_UNLIKELY((p1 = PyMem_Malloc(n + offs)) == NULL)) return NULL;
21+
p2 = (void**)(((Py_uintptr_t)(p1)+offs) & ~(NPY_MEMALIGN-1));
22+
}
23+
*(p2-1) = p1;
24+
return (void*)p2;
25+
}
26+
27+
static NPY_INLINE
28+
void *PyArray_malloc_aligned(size_t n)
29+
{
30+
return PyArray_realloc_aligned(NULL, n);
31+
}
32+
33+
static NPY_INLINE
34+
void *PyArray_calloc_aligned(size_t n, size_t s)
35+
{
36+
void *p;
37+
if (NPY_UNLIKELY((p = PyArray_realloc_aligned(NULL,n*s)) == NULL)) return NULL;
38+
memset(p, 0, n*s);
39+
return p;
40+
}
41+
42+
static NPY_INLINE
43+
void PyArray_free_aligned(void *p)
44+
{
45+
void *base = *(((void**)p)-1);
46+
PyMem_Free(base);
47+
}

randomstate/aligned_malloc.pxi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
cdef extern from "aligned_malloc.h":
2+
cdef void *PyArray_realloc_aligned(void *p, size_t n);
3+
cdef void *PyArray_malloc_aligned(size_t n);
4+
cdef void *PyArray_calloc_aligned(size_t n, size_t s);
5+
cdef void PyArray_free_aligned(void *p);

randomstate/interface.pyx

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ cdef extern from "distributions.h":
112112
cdef void random_gauss_fill(aug_state* state, int count, double *out) nogil
113113
cdef void random_gauss_zig_julia_fill(aug_state* state, int count, double *out) nogil
114114

115-
116115
include "array_utilities.pxi"
117116
include "bounded_integers.pxi"
117+
include "aligned_malloc.pxi"
118118

119119
cdef double kahan_sum(double *darr, np.npy_intp n):
120120
cdef double c, y, t, sum
@@ -141,38 +141,25 @@ cdef class RandomState:
141141

142142
IF RNG_SEED==1:
143143
def __init__(self, seed=None):
144-
IF RNG_MOD_NAME == 'dsfmt':
145-
cdef int8_t *iptr
146-
cdef int8_t offset = 0
147-
cdef intptr_t alignment = 0
148-
self.rng_loc = PyMem_Malloc(sizeof(rng_t))
149-
self.rng_state.rng = <rng_t *>self.rng_loc
150-
alignment = <intptr_t>(&(self.rng_state.rng.status[0].u32[0]))
151-
if (alignment % 16) != 0:
152-
iptr = <int8_t *>self.rng_state.rng
153-
offset = 16 - (alignment % 16)
154-
if offset < 0:
155-
offset += 16
156-
self.rng_state.rng = <rng_t *>(iptr + offset)
157-
ELSE:
158-
self.rng_loc = PyMem_Malloc(sizeof(rng_t))
159-
self.rng_state.rng = <rng_t *>self.rng_loc
160-
144+
self.rng_state.rng = <rng_t *>PyArray_malloc_aligned(sizeof(rng_t))
161145
self.rng_state.binomial = &self.binomial_info
162-
self._reset_state_variables()
146+
IF RNG_MOD_NAME == 'dsfmt':
147+
self.rng_state.buffered_uniforms = <double *>PyArray_malloc_aligned(2 * DSFMT_N * sizeof(double))
163148
self.lock = Lock()
149+
self._reset_state_variables()
164150
self.seed(seed)
165-
166151
ELSE:
167152
def __init__(self, seed=None, inc=None):
168-
self.rng_state.rng = <rng_t *>PyMem_Malloc(sizeof(rng_t)) # &self.rng
153+
self.rng_state.rng = <rng_t *>PyArray_malloc_aligned(sizeof(rng_t))
169154
self.rng_state.binomial = &self.binomial_info
170-
self._reset_state_variables()
171155
self.lock = Lock()
156+
self._reset_state_variables()
172157
self.seed(seed, inc)
173158

174159
def __dealloc__(self):
175-
PyMem_Free(self.rng_loc)
160+
PyArray_free_aligned(self.rng_state.rng)
161+
IF RNG_MOD_NAME == 'dsfmt':
162+
PyArray_free_aligned(self.rng_state.buffered_uniforms)
176163

177164
# Pickling support:
178165
def __getstate__(self):
@@ -496,7 +483,7 @@ cdef class RandomState:
496483
if isinstance(state, tuple):
497484
if state[0] != 'MT19937':
498485
raise ValueError('Not a ' + rng_name + ' RNG state')
499-
_set_state(self.rng_state, (state[1], state[2]))
486+
_set_state(&self.rng_state, (state[1], state[2]))
500487
self.rng_state.has_gauss = state[3]
501488
self.rng_state.gauss = state[4]
502489
self.rng_state.has_uint32 = 0
@@ -505,7 +492,8 @@ cdef class RandomState:
505492

506493
if state['name'] != rng_name:
507494
raise ValueError('Not a ' + rng_name + ' RNG state')
508-
_set_state(self.rng_state, state['state'])
495+
print(state['state'])
496+
_set_state(&self.rng_state, state['state'])
509497
self.rng_state.has_gauss = state['gauss']['has_gauss']
510498
self.rng_state.gauss = state['gauss']['gauss']
511499
self.rng_state.has_uint32 = state['uint32']['has_uint32']

randomstate/shims/dSFMT/dSFMT-shim.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,25 @@ extern inline uint64_t random_uint64(aug_state* state);
66

77
extern inline double random_double(aug_state* state);
88

9+
void reset_buffer(aug_state* state)
10+
{
11+
int i = 0;
12+
for (i = 0; i < (2 * DSFMT_N); i++)
13+
{
14+
state->buffered_uniforms[i] = 0.0;
15+
}
16+
state->buffer_loc = 2 * DSFMT_N;
17+
}
18+
919
extern void set_seed_by_array(aug_state* state, uint32_t init_key[], int key_length)
1020
{
21+
reset_buffer(state);
1122
dsfmt_init_by_array(state->rng, init_key, key_length);
1223
}
1324

1425
void set_seed(aug_state* state, uint32_t seed)
1526
{
27+
reset_buffer(state);
1628
dsfmt_init_gen_rand(state->rng, seed);
1729
}
1830

randomstate/shims/dSFMT/dSFMT-shim.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,48 @@ typedef struct s_aug_state {
1919
double gauss;
2020
uint32_t uinteger;
2121
uint64_t zig_random_int;
22+
23+
double *buffered_uniforms;
24+
int buffer_loc;
2225
} aug_state;
2326

27+
static inline double random_double_from_buffer(aug_state *state)
28+
{
29+
double out;
30+
if (state->buffer_loc >= (2 * DSFMT_N))
31+
{
32+
state->buffer_loc = 0;
33+
dsfmt_fill_array_close1_open2(state->rng, state->buffered_uniforms, 2 * DSFMT_N);
34+
}
35+
out = state->buffered_uniforms[state->buffer_loc];
36+
state->buffer_loc++;
37+
return out;
38+
}
39+
2440
static inline uint32_t random_uint32(aug_state* state)
2541
{
26-
double d = dsfmt_genrand_close1_open2(state->rng);
42+
double d = random_double_from_buffer(state);//dsfmt_genrand_close1_open2(state->rng);
2743
uint64_t *out = (uint64_t *)&d;
2844
return (uint32_t)(*out & 0xffffffff);
2945
}
3046

3147
static inline uint64_t random_uint64(aug_state* state)
3248
{
33-
double d = dsfmt_genrand_close1_open2(state->rng);
49+
double d = random_double_from_buffer(state);//dsfmt_genrand_close1_open2(state->rng);
3450
uint64_t out;
3551
uint64_t *tmp;
3652
tmp = (uint64_t *)&d;
3753
out = *tmp << 32;
38-
d = dsfmt_genrand_close1_open2(state->rng);
54+
d = random_double_from_buffer(state);//dsfmt_genrand_close1_open2(state->rng);
3955
tmp = (uint64_t *)&d;
4056
out |= *tmp & 0xffffffff;
4157
return out;
4258
}
4359

4460
static inline double random_double(aug_state* state)
4561
{
46-
return dsfmt_genrand_close1_open2(state->rng) - 1.0;
62+
return random_double_from_buffer(state) - 1.0;
63+
// return dsfmt_genrand_close1_open2(state->rng) - 1.0;
4764
}
4865

4966
extern void entropy_init(aug_state* state);

randomstate/shims/dSFMT/dSFMT.pxi

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ cdef extern from "distributions.h":
3434
uint64_t zig_random_int
3535
uint32_t uinteger
3636

37+
double *buffered_uniforms
38+
int buffer_loc
39+
3740
ctypedef s_aug_state aug_state
3841

3942
cdef void set_seed(aug_state* state, uint32_t seed)
@@ -43,25 +46,38 @@ cdef extern from "distributions.h":
4346
ctypedef dsfmt_t rng_t
4447

4548
cdef object _get_state(aug_state state):
46-
cdef uint32_t [:] key = np.zeros(4 * DSFMT_N_PLUS_1, dtype=np.uint32)
49+
cdef uint32_t [::1] key = np.zeros(4 * DSFMT_N_PLUS_1, dtype=np.uint32)
50+
cdef double [::1] buf = np.zeros(2 * DSFMT_N, dtype=np.double)
4751
cdef Py_ssize_t i, j, key_loc = 0
4852
cdef w128_t state_val
4953
for i in range(DSFMT_N_PLUS_1):
5054
state_val = state.rng.status[i]
5155
for j in range(4):
5256
key[key_loc] = state_val.u32[j]
5357
key_loc += 1
54-
return (np.asarray(key), state.rng.idx)
58+
for i in range(2 * DSFMT_N):
59+
buf[i] = state.buffered_uniforms[i]
60+
61+
return (np.asarray(key), state.rng.idx,
62+
np.asarray(buf), state.buffer_loc)
5563

56-
cdef object _set_state(aug_state state, object state_info):
57-
cdef uint32_t [:] key = state_info[0]
64+
cdef object _set_state(aug_state *state, object state_info):
5865
cdef Py_ssize_t i, j, key_loc = 0
66+
cdef uint32_t [::1] key = state_info[0]
67+
state.rng.idx = state_info[1]
68+
69+
5970
for i in range(DSFMT_N_PLUS_1):
6071
for j in range(4):
6172
state.rng.status[i].u32[j] = key[key_loc]
6273
key_loc += 1
6374

64-
state.rng.idx = state_info[1]
75+
state.buffer_loc = <int>state_info[3]
76+
for i in range(2 * DSFMT_N):
77+
state.buffered_uniforms[i] = state_info[2][i]
78+
79+
80+
6581

6682
DEF CLASS_DOCSTRING = """
6783
RandomState(seed=None)

randomstate/shims/mlfg-1279-861/mlfg-1279-861.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ cdef object _get_state(aug_state state):
3939
key[i] = state.rng.lags[i]
4040
return (np.asanyarray(key), state.rng.pos, state.rng.lag_pos)
4141

42-
cdef object _set_state(aug_state state, object state_info):
42+
cdef object _set_state(aug_state *state, object state_info):
4343
cdef uint64_t [:] key = state_info[0]
4444
cdef Py_ssize_t i
4545
for i in range(MLFG_STATE_LEN):

randomstate/shims/mrg32k3a/mrg32k3a.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ cdef object _get_state(aug_state state):
3838
return (state.rng.s10, state.rng.s11, state.rng.s12,
3939
state.rng.s20, state.rng.s21, state.rng.s22)
4040

41-
cdef object _set_state(aug_state state, object state_info):
41+
cdef object _set_state(aug_state *state, object state_info):
4242
state.rng.s10 = state_info[0]
4343
state.rng.s11 = state_info[1]
4444
state.rng.s12 = state_info[2]

randomstate/shims/pcg-32/pcg-32.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ ctypedef pcg32_random_t rng_t
3434
cdef object _get_state(aug_state state):
3535
return (state.rng.state, state.rng.inc)
3636

37-
cdef object _set_state(aug_state state, object state_info):
37+
cdef object _set_state(aug_state *state, object state_info):
3838
state.rng.state = state_info[0]
3939
state.rng.inc = state_info[1]
4040

0 commit comments

Comments
 (0)