Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit 8330cef

Browse files
authored
Merge pull request #355 from ariddell/feature/issue-341-transformed-data-rng
Use random seed in transformed data rng
2 parents 6ed0de9 + a489c7d commit 8330cef

File tree

5 files changed

+71
-17
lines changed

5 files changed

+71
-17
lines changed

pystan/model.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,8 @@ def optimizing(self, data=None, seed=None,
468468
raise ValueError("Algorithm must be one of {}".format(algorithms))
469469
if data is None:
470470
data = {}
471-
472-
fit = self.fit_class(data)
471+
seed = pystan.misc._check_seed(seed)
472+
fit = self.fit_class(data, seed)
473473

474474
m_pars = fit._get_param_names()
475475
p_dims = fit._get_param_dims()
@@ -486,8 +486,6 @@ def optimizing(self, data=None, seed=None,
486486
not isinstance(init, string_types):
487487
raise ValueError("Wrong specification of initial values.")
488488

489-
seed = pystan.misc._check_seed(seed)
490-
491489
stan_args = dict(init=init,
492490
seed=seed,
493491
method="optim",
@@ -670,7 +668,8 @@ def sampling(self, data=None, pars=None, chains=4, iter=2000,
670668
if algorithm not in algorithms:
671669
raise ValueError("Algorithm must be one of {}".format(algorithms))
672670

673-
fit = self.fit_class(data)
671+
seed = pystan.misc._check_seed(seed)
672+
fit = self.fit_class(data, seed)
674673

675674
m_pars = fit._get_param_names()
676675
p_dims = fit._get_param_dims()
@@ -848,7 +847,8 @@ def vb(self, data=None, pars=None, iter=10000,
848847
algorithm = "meanfield" if algorithm is None else algorithm
849848
if algorithm not in algorithms:
850849
raise ValueError("Algorithm must be one of {}".format(algorithms))
851-
fit = self.fit_class(data)
850+
seed = pystan.misc._check_seed(seed)
851+
fit = self.fit_class(data, seed)
852852
m_pars = fit._get_param_names()
853853
p_dims = fit._get_param_dims()
854854

@@ -860,8 +860,6 @@ def vb(self, data=None, pars=None, iter=10000,
860860
not isinstance(init, string_types):
861861
raise ValueError("Wrong specification of initial values.")
862862

863-
seed = pystan.misc._check_seed(seed)
864-
865863
stan_args = dict(iter=iter,
866864
init=init,
867865
chain_id=1,

pystan/stan_fit.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,9 +1371,9 @@ namespace pystan {
13711371
return true;
13721372
}
13731373

1374-
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i) :
1374+
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i, unsigned int random_seed) :
13751375
data_(vars_r, vars_i),
1376-
model_(data_, &std::cout),
1376+
model_(data_, random_seed, &std::cout),
13771377
base_rng(static_cast<boost::uint32_t>(std::time(0))),
13781378
names_(get_param_names(model_)),
13791379
dims_(get_param_dims(model_)),

pystan/stan_fit.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ cdef extern from "stan_fit.hpp" namespace "pystan":
4040
VARIATIONAL = 4
4141

4242
cdef cppclass stan_fit[M, R]:
43-
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i) except +
43+
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i, uint random_seed) except +
4444
bool update_param_oi(vector[string] pars)
4545
vector[double] unconstrain_pars(vars_r_t& vars_r, vars_i_t& vars_i)
4646
vector[double] constrain_pars(vector[double]& params_r) except +

pystan/stanfit4model.pyx

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def _call_sampler(data, args, pars_oi=None):
388388
cdef stan_fit[$model_cppname, ecuyer1988] *fitptr
389389
cdef vars_r_t vars_r = _dict_to_vars_r(data_r)
390390
cdef vars_i_t vars_i = _dict_to_vars_i(data_i)
391-
fitptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i)
391+
fitptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i, argsptr.random_seed)
392392
if not fitptr:
393393
raise MemoryError("Couldn't allocate space for stan_fit.")
394394
# Implementation note: there is an extra stan_fit instance associated
@@ -434,6 +434,7 @@ cdef class StanFit4Model:
434434

435435
# attributes populated by methods of StanModel
436436
cdef public data # dict or OrderedDict
437+
cdef public random_seed
437438
cdef public dict sim
438439
cdef public model_name
439440
cdef public model_pars
@@ -448,18 +449,20 @@ cdef class StanFit4Model:
448449
# __cinit__ must be callable with no arguments for unpickling
449450
cdef vars_r_t vars_r
450451
cdef vars_i_t vars_i
451-
if len(args) == 1:
452-
data = args[0]
452+
if len(args) == 2:
453+
data, random_seed = args
453454
data_r, data_i = pystan.misc._split_data(data)
454455
# NB: dictionary keys must be byte strings
455456
vars_r = _dict_to_vars_r(data_r)
456457
vars_i = _dict_to_vars_i(data_i)
457-
self.thisptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i)
458+
# TODO: the random seed needs to be known by StanFit4Model
459+
self.thisptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i, <unsigned int> random_seed)
458460
if not self.thisptr:
459461
raise MemoryError("Couldn't allocate space for stan_fit.")
460462

461-
def __init__(self, data):
463+
def __init__(self, data, random_seed):
462464
self.data = data
465+
self.random_seed = random_seed
463466

464467
def __dealloc__(self):
465468
del self.thisptr
@@ -480,7 +483,7 @@ cdef class StanFit4Model:
480483
"The relevant StanModel instance must be pickled along with this fit object.\n"
481484
"When unpickling the StanModel must be unpickled first.")
482485
warnings.warn(msg)
483-
return (StanFit4Model, (self.data,), self.__getstate__(), None, None)
486+
return (StanFit4Model, (self.data, self.random_seed), self.__getstate__(), None, None)
484487

485488
# public methods
486489

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from collections import OrderedDict
2+
import gc
3+
import os
4+
import tempfile
5+
import unittest
6+
7+
import numpy as np
8+
9+
import pystan
10+
from pystan._compat import PY2
11+
12+
13+
# NOTE(review): despite the class name, the RNG call under test (normal_rng)
# sits in the Stan "transformed data" block; the "generated quantities" block
# only reports mean(y) and sd(y) so the test can observe the simulated data.
class TestGeneratedQuantitiesSeed(unittest.TestCase):
14+
"""Verify that the RNG in the transformed data block uses the overall seed.
15+
16+
See https://github.com/stan-dev/stan/issues/2241
17+
18+
"""
19+
# The StanModel is built once in setUpClass and shared through cls.model.
20+
@classmethod
21+
def setUpClass(cls):
22+
model_code = """
23+
data {
24+
int<lower=0> N;
25+
}
26+
transformed data {
27+
vector[N] y;
28+
for (n in 1:N)
29+
y[n] = normal_rng(0, 1);
30+
}
31+
parameters {
32+
real mu;
33+
real<lower = 0> sigma;
34+
}
35+
model {
36+
y ~ normal(mu, sigma);
37+
}
38+
generated quantities {
39+
real mean_y = mean(y);
40+
real sd_y = sd(y);
41+
}
42+
"""
43+
cls.model = pystan.StanModel(model_code=model_code, verbose=True)
44+
45+
def test_generated_quantities_seed(self):
# Two runs with identical data and seed=123: every extracted mean_y value
# must match exactly between the runs.
46+
fit1 = self.model.sampling(data={'N': 1000}, iter=10, seed=123)
47+
extr1 = fit1.extract()
48+
fit2 = self.model.sampling(data={'N': 1000}, iter=10, seed=123)
49+
extr2 = fit2.extract()
50+
self.assertTrue((extr1['mean_y'] == extr2['mean_y']).all())
# A third run with a different seed (456): assertFalse on .all() passes as
# soon as at least one mean_y value differs from the seed=123 run.
51+
fit3 = self.model.sampling(data={'N': 1000}, iter=10, seed=456)
52+
extr3 = fit3.extract()
53+
self.assertFalse((extr1['mean_y'] == extr3['mean_y']).all())

0 commit comments

Comments
 (0)