Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit f1b6f68

Browse files
author
ariddell
committed
TST: add test for transformed data rng consistency
1 parent 1f96f14 commit f1b6f68

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from collections import OrderedDict
2+
import gc
3+
import os
4+
import tempfile
5+
import unittest
6+
7+
import numpy as np
8+
9+
import pystan
10+
from pystan._compat import PY2
11+
12+
13+
class TestGeneratedQuantitiesSeed(unittest.TestCase):
14+
"""Verify that the RNG in the transformed data block uses the overall seed.
15+
16+
See https://github.com/stan-dev/stan/issues/2241
17+
18+
"""
19+
20+
@classmethod
21+
def setUpClass(cls):
22+
model_code = """
23+
data {
24+
int<lower=0> N;
25+
}
26+
transformed data {
27+
vector[N] y;
28+
for (n in 1:N)
29+
y[n] = normal_rng(0, 1);
30+
}
31+
parameters {
32+
real mu;
33+
real<lower = 0> sigma;
34+
}
35+
model {
36+
y ~ normal(mu, sigma);
37+
}
38+
generated quantities {
39+
real mean_y = mean(y);
40+
real sd_y = sd(y);
41+
}
42+
"""
43+
cls.model = pystan.StanModel(model_code=model_code, verbose=True)
44+
45+
def test_generated_quantities_seed(self):
46+
fit1 = self.model.sampling(data={'N': 1000}, iter=10, seed=123)
47+
extr1 = fit1.extract()
48+
fit2 = self.model.sampling(data={'N': 1000}, iter=10, seed=123)
49+
extr2 = fit2.extract()
50+
self.assertTrue((extr1['mean_y'] == extr2['mean_y']).all())
51+
fit3 = self.model.sampling(data={'N': 1000}, iter=10, seed=456)
52+
extr3 = fit3.extract()
53+
self.assertFalse((extr1['mean_y'] == extr3['mean_y']).all())

0 commit comments

Comments
 (0)