Skip to content

Commit 479c04b

Browse files
committed
Merge branch 'develop' into multi-normal-derivatives-2
2 parents bac7e60 + c491e9b commit 479c04b

File tree

5 files changed

+56
-15
lines changed

5 files changed

+56
-15
lines changed

.github/workflows/header_checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727

2828
steps:
2929
- uses: actions/checkout@v4
30-
- uses: actions/setup-python@v4
30+
- uses: actions/setup-python@v5
3131
with:
3232
python-version: '3.x'
3333

.github/workflows/main.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525

2626
steps:
2727
- uses: actions/checkout@v4
28-
- uses: actions/setup-python@v4
28+
- uses: actions/setup-python@v5
2929
with:
3030
python-version: '3.x'
3131
- uses: r-lib/actions/setup-r@v2
@@ -75,7 +75,7 @@ jobs:
7575

7676
steps:
7777
- uses: actions/checkout@v4
78-
- uses: actions/setup-python@v4
78+
- uses: actions/setup-python@v5
7979
with:
8080
python-version: '3.x'
8181
- uses: r-lib/actions/setup-r@v2
@@ -129,7 +129,7 @@ jobs:
129129

130130
steps:
131131
- uses: actions/checkout@v4
132-
- uses: actions/setup-python@v4
132+
- uses: actions/setup-python@v5
133133
with:
134134
python-version: '3.x'
135135
- uses: r-lib/actions/setup-r@v2
@@ -178,7 +178,7 @@ jobs:
178178

179179
steps:
180180
- uses: actions/checkout@v4
181-
- uses: actions/setup-python@v4
181+
- uses: actions/setup-python@v5
182182
with:
183183
python-version: '3.x'
184184
- uses: r-lib/actions/setup-r@v2

stan/math/prim/fun/chol2inv.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <stan/math/prim/fun/Eigen.hpp>
66
#include <stan/math/prim/fun/dot_self.hpp>
77
#include <stan/math/prim/fun/dot_product.hpp>
8-
#include <stan/math/prim/fun/mdivide_left_tri_low.hpp>
8+
#include <stan/math/prim/fun/mdivide_left_tri.hpp>
99
#include <stan/math/prim/fun/inv_square.hpp>
1010

1111
namespace stan {
@@ -35,7 +35,7 @@ plain_type_t<T> chol2inv(const T& L) {
3535
X.coeffRef(0) = inv_square(L_ref.coeff(0, 0));
3636
return X;
3737
}
38-
T_result L_inv = mdivide_left_tri_low(L_ref, T_result::Identity(K, K));
38+
T_result L_inv = mdivide_left_tri<Eigen::Lower>(L_ref);
3939
T_result X(K, K);
4040
for (int k = 0; k < K; ++k) {
4141
X.coeffRef(k, k) = dot_self(L_inv.col(k).tail(K - k).eval());

stan/math/prim/prob/inv_wishart_cholesky_rng.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6-
#include <stan/math/prim/fun/mdivide_left_tri.hpp>
7-
#include <stan/math/prim/prob/wishart_cholesky_rng.hpp>
8-
#include <stan/math/prim/prob/wishart_rng.hpp>
6+
#include <stan/math/prim/fun/mdivide_left_tri_low.hpp>
97

108
namespace stan {
119
namespace math {
@@ -16,6 +14,9 @@ namespace math {
1614
* from the inverse Wishart distribution with the specified degrees of freedom
1715
* using the specified random number generator.
1816
*
17+
* Axen, Seth D. "Efficiently generating inverse-Wishart matrices and their
18+
* Cholesky factors." arXiv preprint arXiv:2310.15884 (2023).
19+
*
1920
* @tparam RNG Random number generator type
2021
* @param[in] nu scalar degrees of freedom
2122
* @param[in] L_S lower Cholesky factor of the scale matrix
@@ -38,8 +39,15 @@ inline Eigen::MatrixXd inv_wishart_cholesky_rng(double nu,
3839
check_positive(function, "Cholesky Scale matrix", L_S.diagonal());
3940
check_positive(function, "columns of Cholesky Scale matrix", L_S.cols());
4041

41-
MatrixXd L_Sinv = mdivide_left_tri<Eigen::Lower>(L_S);
42-
return mdivide_left_tri<Eigen::Lower>(wishart_cholesky_rng(nu, L_Sinv, rng));
42+
MatrixXd B = MatrixXd::Zero(k, k);
43+
for (int j = 0; j < k; ++j) {
44+
for (int i = 0; i < j; ++i) {
45+
B(j, i) = normal_rng(0, 1, rng);
46+
}
47+
B(j, j) = std::sqrt(chi_square_rng(nu - k + j + 1, rng));
48+
}
49+
50+
return mdivide_left_tri_low(B, L_S);
4351
}
4452

4553
} // namespace math

test/unit/math/prim/prob/inv_wishart_cholesky_rng_test.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,15 @@ TEST(ProbDistributionsInvWishartCholesky, SpecialRNGTest) {
9191
using stan::math::inv_wishart_cholesky_rng;
9292
using stan::math::multiply_lower_tri_self_transpose;
9393

94-
boost::random::mt19937 rng(1234U);
94+
boost::random::mt19937 rng(92343U);
9595
int N = 1e5;
9696
double tol = 0.1;
9797
for (int k = 1; k < 5; k++) {
98-
MatrixXd sigma = MatrixXd::Identity(k, k);
98+
MatrixXd L = MatrixXd::Identity(k, k);
9999
MatrixXd Z = MatrixXd::Zero(k, k);
100100
for (int i = 0; i < N; i++) {
101-
Z += stan::math::crossprod(inv_wishart_cholesky_rng(k + 2, sigma, rng));
101+
Z += multiply_lower_tri_self_transpose(
102+
inv_wishart_cholesky_rng(k + 2, L, rng));
102103
}
103104
Z /= N;
104105
for (int j = 0; j < k; j++) {
@@ -111,3 +112,35 @@ TEST(ProbDistributionsInvWishartCholesky, SpecialRNGTest) {
111112
}
112113
}
113114
}
115+
116+
TEST(ProbDistributionsInvWishartCholesky, compareToInvWishart) {
117+
// Compare the marginal mean
118+
119+
using Eigen::MatrixXd;
120+
using Eigen::VectorXd;
121+
using stan::math::inv_wishart_cholesky_rng;
122+
using stan::math::inv_wishart_rng;
123+
using stan::math::multiply_lower_tri_self_transpose;
124+
using stan::math::qr_thin_Q;
125+
126+
boost::random::mt19937 rng(92343U);
127+
int N = 1e5;
128+
double tol = 0.05;
129+
for (int k = 1; k < 5; k++) {
130+
MatrixXd L = qr_thin_Q(MatrixXd::Random(k, k)).transpose();
131+
L.diagonal() = stan::math::abs(L.diagonal());
132+
MatrixXd sigma = multiply_lower_tri_self_transpose(L);
133+
MatrixXd Z_mean = sigma / (k + 3);
134+
MatrixXd Z_est = MatrixXd::Zero(k, k);
135+
for (int i = 0; i < N; i++) {
136+
Z_est += multiply_lower_tri_self_transpose(
137+
inv_wishart_cholesky_rng(k + 4, L, rng));
138+
}
139+
Z_est /= N;
140+
for (int j = 0; j < k; j++) {
141+
for (int i = 0; i < j; i++) {
142+
EXPECT_NEAR(Z_est(i, j), Z_mean(i, j), tol);
143+
}
144+
}
145+
}
146+
}

0 commit comments

Comments
 (0)