Skip to content

Commit 3a196d4

Browse files
authored
Merge pull request #3199 from stan-dev/fix/3198-cholesky-decompose-templating
Update templating inside cholesky_decompose
2 parents d959c81 + 2157520 commit 3a196d4

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

stan/math/prim/fun/cholesky_decompose.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,11 @@ namespace math {
2828
*/
2929
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
3030
require_not_eigen_vt<is_var, EigMat>* = nullptr>
31-
inline Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
32-
EigMat::ColsAtCompileTime>
33-
cholesky_decompose(const EigMat& m) {
34-
const eval_return_type_t<EigMat>& m_eval = m.eval();
31+
inline plain_type_t<EigMat> cholesky_decompose(const EigMat& m) {
32+
auto&& m_eval = to_ref(m);
3533
check_symmetric("cholesky_decompose", "m", m_eval);
3634
check_not_nan("cholesky_decompose", "m", m_eval);
37-
Eigen::LLT<Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
38-
EigMat::ColsAtCompileTime>>
39-
llt = m_eval.llt();
35+
Eigen::LLT<plain_type_t<EigMat>> llt = m_eval.llt();
4036
check_pos_definite("cholesky_decompose", "m", llt);
4137
return llt.matrixL();
4238
}

test/unit/math/prim/fun/cholesky_decompose_test.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,18 @@ TEST(MathMatrixPrimMat, cholesky_decompose_exception) {
3535
EXPECT_THROW_MSG(stan::math::cholesky_decompose(m), std::domain_error,
3636
"is not symmetric");
3737
}
38+
39+
TEST(MathMatrixPrimMat, cholesky_decompose_expressions) {
40+
// Test for https://github.com/stan-dev/math/issues/3198
41+
stan::math::matrix_d A(2, 3);
42+
A << 1, 2, 3, 4, 5, 6;
43+
44+
stan::math::vector_d L_u(3);
45+
L_u << 1, 0, 0.5;
46+
47+
auto L = stan::math::cholesky_corr_constrain(L_u, 3);
48+
49+
EXPECT_NO_THROW(stan::math::cholesky_decompose(stan::math::multiply(
50+
stan::math::multiply(A, stan::math::multiply_lower_tri_self_transpose(L)),
51+
stan::math::transpose(A))));
52+
}

0 commit comments

Comments
 (0)