This repository was archived by the owner on Mar 19, 2021. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +53
-0
lines changed Expand file tree Collapse file tree 1 file changed +53
-0
lines changed Original file line number Diff line number Diff line change
1
+ from collections import OrderedDict
2
+ import gc
3
+ import os
4
+ import tempfile
5
+ import unittest
6
+
7
+ import numpy as np
8
+
9
+ import pystan
10
+ from pystan ._compat import PY2
11
+
12
+
13
+ class TestGeneratedQuantitiesSeed (unittest .TestCase ):
14
+ """Verify that the RNG in the transformed data block uses the overall seed.
15
+
16
+ See https://github.com/stan-dev/stan/issues/2241
17
+
18
+ """
19
+
20
+ @classmethod
21
+ def setUpClass (cls ):
22
+ model_code = """
23
+ data {
24
+ int<lower=0> N;
25
+ }
26
+ transformed data {
27
+ vector[N] y;
28
+ for (n in 1:N)
29
+ y[n] = normal_rng(0, 1);
30
+ }
31
+ parameters {
32
+ real mu;
33
+ real<lower = 0> sigma;
34
+ }
35
+ model {
36
+ y ~ normal(mu, sigma);
37
+ }
38
+ generated quantities {
39
+ real mean_y = mean(y);
40
+ real sd_y = sd(y);
41
+ }
42
+ """
43
+ cls .model = pystan .StanModel (model_code = model_code , verbose = True )
44
+
45
+ def test_generated_quantities_seed (self ):
46
+ fit1 = self .model .sampling (data = {'N' : 1000 }, iter = 10 , seed = 123 )
47
+ extr1 = fit1 .extract ()
48
+ fit2 = self .model .sampling (data = {'N' : 1000 }, iter = 10 , seed = 123 )
49
+ extr2 = fit2 .extract ()
50
+ self .assertTrue ((extr1 ['mean_y' ] == extr2 ['mean_y' ]).all ())
51
+ fit3 = self .model .sampling (data = {'N' : 1000 }, iter = 10 , seed = 456 )
52
+ extr3 = fit3 .extract ()
53
+ self .assertFalse ((extr1 ['mean_y' ] == extr3 ['mean_y' ]).all ())
You can’t perform that action at this time.
0 commit comments