Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit a489c7d

Browse files
author
ariddell
committed
BUG: Use 3-arg constructor, transformed data rng bug
Randomly generated data in the `transformed_data` block now uses the global random seed.
1 parent f1b6f68 commit a489c7d

File tree

4 files changed

+18
-17
lines changed

4 files changed

+18
-17
lines changed

pystan/model.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,8 @@ def optimizing(self, data=None, seed=None,
468468
raise ValueError("Algorithm must be one of {}".format(algorithms))
469469
if data is None:
470470
data = {}
471-
472-
fit = self.fit_class(data)
471+
seed = pystan.misc._check_seed(seed)
472+
fit = self.fit_class(data, seed)
473473

474474
m_pars = fit._get_param_names()
475475
p_dims = fit._get_param_dims()
@@ -486,8 +486,6 @@ def optimizing(self, data=None, seed=None,
486486
not isinstance(init, string_types):
487487
raise ValueError("Wrong specification of initial values.")
488488

489-
seed = pystan.misc._check_seed(seed)
490-
491489
stan_args = dict(init=init,
492490
seed=seed,
493491
method="optim",
@@ -670,7 +668,8 @@ def sampling(self, data=None, pars=None, chains=4, iter=2000,
670668
if algorithm not in algorithms:
671669
raise ValueError("Algorithm must be one of {}".format(algorithms))
672670

673-
fit = self.fit_class(data)
671+
seed = pystan.misc._check_seed(seed)
672+
fit = self.fit_class(data, seed)
674673

675674
m_pars = fit._get_param_names()
676675
p_dims = fit._get_param_dims()
@@ -848,7 +847,8 @@ def vb(self, data=None, pars=None, iter=10000,
848847
algorithm = "meanfield" if algorithm is None else algorithm
849848
if algorithm not in algorithms:
850849
raise ValueError("Algorithm must be one of {}".format(algorithms))
851-
fit = self.fit_class(data)
850+
seed = pystan.misc._check_seed(seed)
851+
fit = self.fit_class(data, seed)
852852
m_pars = fit._get_param_names()
853853
p_dims = fit._get_param_dims()
854854

@@ -860,8 +860,6 @@ def vb(self, data=None, pars=None, iter=10000,
860860
not isinstance(init, string_types):
861861
raise ValueError("Wrong specification of initial values.")
862862

863-
seed = pystan.misc._check_seed(seed)
864-
865863
stan_args = dict(iter=iter,
866864
init=init,
867865
chain_id=1,

pystan/stan_fit.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,9 +1371,9 @@ namespace pystan {
13711371
return true;
13721372
}
13731373

1374-
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i) :
1374+
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i, unsigned int random_seed) :
13751375
data_(vars_r, vars_i),
1376-
model_(data_, &std::cout),
1376+
model_(data_, random_seed, &std::cout),
13771377
base_rng(static_cast<boost::uint32_t>(std::time(0))),
13781378
names_(get_param_names(model_)),
13791379
dims_(get_param_dims(model_)),

pystan/stan_fit.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ cdef extern from "stan_fit.hpp" namespace "pystan":
4040
VARIATIONAL = 4
4141

4242
cdef cppclass stan_fit[M, R]:
43-
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i) except +
43+
stan_fit(vars_r_t& vars_r, vars_i_t& vars_i, uint random_seed) except +
4444
bool update_param_oi(vector[string] pars)
4545
vector[double] unconstrain_pars(vars_r_t& vars_r, vars_i_t& vars_i)
4646
vector[double] constrain_pars(vector[double]& params_r) except +

pystan/stanfit4model.pyx

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def _call_sampler(data, args, pars_oi=None):
388388
cdef stan_fit[$model_cppname, ecuyer1988] *fitptr
389389
cdef vars_r_t vars_r = _dict_to_vars_r(data_r)
390390
cdef vars_i_t vars_i = _dict_to_vars_i(data_i)
391-
fitptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i)
391+
fitptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i, argsptr.random_seed)
392392
if not fitptr:
393393
raise MemoryError("Couldn't allocate space for stan_fit.")
394394
# Implementation note: there is an extra stan_fit instance associated
@@ -434,6 +434,7 @@ cdef class StanFit4Model:
434434

435435
# attributes populated by methods of StanModel
436436
cdef public data # dict or OrderedDict
437+
cdef public random_seed
437438
cdef public dict sim
438439
cdef public model_name
439440
cdef public model_pars
@@ -448,18 +449,20 @@ cdef class StanFit4Model:
448449
# __cinit__ must be callable with no arguments for unpickling
449450
cdef vars_r_t vars_r
450451
cdef vars_i_t vars_i
451-
if len(args) == 1:
452-
data = args[0]
452+
if len(args) == 2:
453+
data, random_seed = args
453454
data_r, data_i = pystan.misc._split_data(data)
454455
# NB: dictionary keys must be byte strings
455456
vars_r = _dict_to_vars_r(data_r)
456457
vars_i = _dict_to_vars_i(data_i)
457-
self.thisptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i)
458+
# TODO: the random seed needs to be known by StanFit4Model
459+
self.thisptr = new stan_fit[$model_cppname, ecuyer1988](vars_r, vars_i, <unsigned int> random_seed)
458460
if not self.thisptr:
459461
raise MemoryError("Couldn't allocate space for stan_fit.")
460462

461-
def __init__(self, data):
463+
def __init__(self, data, random_seed):
462464
self.data = data
465+
self.random_seed = random_seed
463466

464467
def __dealloc__(self):
465468
del self.thisptr
@@ -480,7 +483,7 @@ cdef class StanFit4Model:
480483
"The relevant StanModel instance must be pickled along with this fit object.\n"
481484
"When unpickling the StanModel must be unpickled first.")
482485
warnings.warn(msg)
483-
return (StanFit4Model, (self.data,), self.__getstate__(), None, None)
486+
return (StanFit4Model, (self.data, self.random_seed), self.__getstate__(), None, None)
484487

485488
# public methods
486489

0 commit comments

Comments
 (0)