Skip to content

Commit f7ccc01

Browse files
authored
Merge pull request #3209 from stan-dev/laplace/move-theta-to-tols
Laplace: move theta0 to tolerances
2 parents 5f85d74 + 399b713 commit f7ccc01

22 files changed

+321
-281
lines changed

doxygen/doxygen.cfg

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,7 @@ MSCGEN_TOOL =
27432743
MSCFILE_DIRS =
27442744

27452745
ALIASES += laplace_options="\
2746+
\param[in] theta_0 the initial guess for the Laplace approximation. \
27462747
\param[in] tolerance controls the convergence criterion when finding the mode in the Laplace approximation. \
27472748
\param[in] max_num_steps maximum number of steps before the Newton solver breaks and returns an error. \
27482749
\param[in] hessian_block_size Block size of the Hessian of the log likelihood w.r.t. the latent Gaussian variable theta. \
@@ -2754,7 +2755,6 @@ ALIASES += laplace_options="\
27542755
"
27552756

27562757
ALIASES += laplace_common_template_args="\
2757-
\tparam ThetaVec A type inheriting from `Eigen::EigenBase` with dynamic sized rows and 1 column. \
27582758
\tparam CovarFun A functor with an `operator()(CovarArgsElements..., {TrainTupleElements...| PredTupleElements...})` \
27592759
method. The `operator()` method should accept as arguments the \
27602760
inner elements of `CovarArgs`. The return type of the `operator()` method \
@@ -2764,7 +2764,6 @@ ALIASES += laplace_common_template_args="\
27642764
"
27652765

27662766
ALIASES += laplace_common_args="\
2767-
\param[in] theta_0 the initial guess for the Laplace approximation. \
27682767
\param[in] covariance_function a function which returns the prior covariance. \
27692768
\param[in] covar_args arguments for the covariance function. \
27702769
"

stan/math/mix/functor/laplace_base_rng.hpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,15 @@ namespace math {
3131
* \rng_arg
3232
* \msg_arg
3333
*/
34-
template <
35-
typename LLFunc, typename LLArgs, typename ThetaVec, typename CovarFun,
36-
typename CovarArgs, typename RNG, require_all_eigen_t<ThetaVec>* = nullptr,
37-
require_t<is_all_arithmetic_scalar<CovarArgs, LLArgs, ThetaVec>>* = nullptr>
38-
inline Eigen::VectorXd laplace_base_rng(LLFunc&& ll_fun, LLArgs&& ll_args,
39-
ThetaVec&& theta_0,
40-
CovarFun&& covariance_function,
41-
CovarArgs&& covar_args,
42-
const laplace_options& options,
43-
RNG& rng, std::ostream* msgs) {
34+
template <typename LLFunc, typename LLArgs, typename CovarFun,
35+
typename CovarArgs, bool InitTheta, typename RNG,
36+
require_t<is_all_arithmetic_scalar<CovarArgs, LLArgs>>* = nullptr>
37+
inline Eigen::VectorXd laplace_base_rng(
38+
LLFunc&& ll_fun, LLArgs&& ll_args, CovarFun&& covariance_function,
39+
CovarArgs&& covar_args, const laplace_options<InitTheta>& options, RNG& rng,
40+
std::ostream* msgs) {
4441
auto md_est = internal::laplace_marginal_density_est(
45-
ll_fun, std::forward<LLArgs>(ll_args), std::forward<ThetaVec>(theta_0),
42+
ll_fun, std::forward<LLArgs>(ll_args),
4643
std::forward<CovarFun>(covariance_function),
4744
to_ref(std::forward<CovarArgs>(covar_args)), options, msgs);
4845
// Modified R&W method

stan/math/mix/functor/laplace_marginal_density.hpp

Lines changed: 72 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <stan/math/prim/functor/iter_tuple_nested.hpp>
1313
#include <unsupported/Eigen/MatrixFunctions>
1414
#include <cmath>
15+
#include <optional>
1516

1617
/**
1718
* @file
@@ -26,7 +27,7 @@ namespace math {
2627
/**
2728
* Options for the laplace sampler
2829
*/
29-
struct laplace_options {
30+
struct laplace_options_base {
3031
/* Size of the blocks in the block-diagonal Hessian */
3132
int hessian_block_size{1};
3233
/**
@@ -46,6 +47,20 @@ struct laplace_options {
4647
int max_num_steps{100};
4748
};
4849

50+
template <bool HasInitTheta>
51+
struct laplace_options;
52+
53+
template <>
54+
struct laplace_options<false> : public laplace_options_base {};
55+
56+
template <>
57+
struct laplace_options<true> : public laplace_options_base {
58+
/* Value for user supplied initial theta */
59+
Eigen::VectorXd theta_0{0};
60+
};
61+
62+
using laplace_options_default = laplace_options<false>;
63+
using laplace_options_user_supplied = laplace_options<true>;
4964
namespace internal {
5065

5166
template <typename Covar, typename ThetaVec, typename WR, typename L_t,
@@ -448,37 +463,46 @@ inline STAN_COLD_PATH void throw_nan(NameStr&& name_str, ParamStr&& param_str,
448463
*
449464
*/
450465
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
451-
typename ThetaVec, typename CovarArgs,
452-
require_t<is_all_arithmetic_scalar<ThetaVec, CovarArgs>>* = nullptr,
453-
require_eigen_vector_t<ThetaVec>* = nullptr>
454-
inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
455-
ThetaVec&& theta_0,
456-
CovarFun&& covariance_function,
457-
CovarArgs&& covar_args,
458-
const laplace_options& options,
459-
std::ostream* msgs) {
466+
typename CovarArgs, bool InitTheta,
467+
require_t<is_all_arithmetic_scalar<CovarArgs>>* = nullptr>
468+
inline auto laplace_marginal_density_est(
469+
LLFun&& ll_fun, LLTupleArgs&& ll_args, CovarFun&& covariance_function,
470+
CovarArgs&& covar_args, const laplace_options<InitTheta>& options,
471+
std::ostream* msgs) {
460472
using Eigen::MatrixXd;
461473
using Eigen::SparseMatrix;
462474
using Eigen::VectorXd;
463-
check_nonzero_size("laplace_marginal", "initial guess", theta_0);
464-
check_finite("laplace_marginal", "initial guess", theta_0);
475+
if constexpr (InitTheta) {
476+
check_nonzero_size("laplace_marginal", "initial guess", options.theta_0);
477+
check_finite("laplace_marginal", "initial guess", options.theta_0);
478+
}
465479
check_nonnegative("laplace_marginal", "tolerance", options.tolerance);
466480
check_positive("laplace_marginal", "max_num_steps", options.max_num_steps);
467481
check_positive("laplace_marginal", "hessian_block_size",
468482
options.hessian_block_size);
469483
check_nonnegative("laplace_marginal", "max_steps_line_search",
470484
options.max_steps_line_search);
471-
if (unlikely(theta_0.size() % options.hessian_block_size != 0)) {
485+
486+
Eigen::MatrixXd covariance = stan::math::apply(
487+
[msgs, &covariance_function](auto&&... args) {
488+
return covariance_function(args..., msgs);
489+
},
490+
covar_args);
491+
check_square("laplace_marginal", "covariance", covariance);
492+
493+
const Eigen::Index theta_size = covariance.rows();
494+
495+
if (unlikely(theta_size % options.hessian_block_size != 0)) {
472496
[&]() STAN_COLD_PATH {
473497
std::stringstream msg;
474-
msg << "laplace_marginal_density: The hessian size (" << theta_0.size()
475-
<< ", " << theta_0.size()
498+
msg << "laplace_marginal_density: The hessian size (" << theta_size
499+
<< ", " << theta_size
476500
<< ") is not divisible by the hessian block size ("
477501
<< options.hessian_block_size
478502
<< ")"
479503
". Try a hessian block size such as [1, ";
480504
for (int i = 2; i < 12; ++i) {
481-
if (theta_0.size() % i == 0) {
505+
if (theta_size % i == 0) {
482506
msg << i << ", ";
483507
}
484508
}
@@ -488,19 +512,20 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
488512
throw std::domain_error(msg.str());
489513
}();
490514
}
491-
Eigen::MatrixXd covariance = stan::math::apply(
492-
[msgs, &covariance_function](auto&&... args) {
493-
return covariance_function(args..., msgs);
494-
},
495-
covar_args);
515+
496516
auto throw_overstep = [](const auto max_num_steps) STAN_COLD_PATH {
497517
throw std::domain_error(
498518
std::string("laplace_marginal_density: max number of iterations: ")
499519
+ std::to_string(max_num_steps) + " exceeded.");
500520
};
501521
auto ll_args_vals = value_of(ll_args);
502-
const Eigen::Index theta_size = theta_0.size();
503-
Eigen::VectorXd theta = std::forward<ThetaVec>(theta_0);
522+
Eigen::VectorXd theta = [theta_size, &options]() {
523+
if constexpr (InitTheta) {
524+
return options.theta_0;
525+
} else {
526+
return Eigen::VectorXd::Zero(theta_size);
527+
}
528+
}();
504529
double objective_old = std::numeric_limits<double>::lowest();
505530
double objective_new = std::numeric_limits<double>::lowest() + 1;
506531
Eigen::VectorXd a_prev = Eigen::VectorXd::Zero(theta_size);
@@ -572,7 +597,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
572597
}
573598
}
574599
} else {
575-
Eigen::SparseMatrix<double> W_r(theta.rows(), theta.rows());
600+
Eigen::SparseMatrix<double> W_r(theta_size, theta_size);
576601
Eigen::Index block_size = options.hessian_block_size;
577602
W_r.reserve(Eigen::VectorXi::Constant(W_r.cols(), block_size));
578603
const Eigen::Index n_block = W_r.cols() / block_size;
@@ -768,20 +793,16 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
768793
* \msg_arg
769794
* @return the log marginal density, p(y | phi)
770795
*/
771-
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
772-
typename ThetaVec, typename CovarArgs,
773-
require_t<is_all_arithmetic_scalar<ThetaVec, CovarArgs,
774-
LLTupleArgs>>* = nullptr,
775-
require_eigen_vector_t<ThetaVec>* = nullptr>
776-
inline double laplace_marginal_density(LLFun&& ll_fun, LLTupleArgs&& ll_args,
777-
ThetaVec&& theta_0,
778-
CovarFun&& covariance_function,
779-
CovarArgs&& covar_args,
780-
const laplace_options& options,
781-
std::ostream* msgs) {
796+
template <
797+
typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs,
798+
bool InitTheta,
799+
require_t<is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr>
800+
inline double laplace_marginal_density(
801+
LLFun&& ll_fun, LLTupleArgs&& ll_args, CovarFun&& covariance_function,
802+
CovarArgs&& covar_args, const laplace_options<InitTheta>& options,
803+
std::ostream* msgs) {
782804
return internal::laplace_marginal_density_est(
783805
std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
784-
std::forward<ThetaVec>(theta_0),
785806
std::forward<CovarFun>(covariance_function),
786807
std::forward<CovarArgs>(covar_args), options, msgs)
787808
.lmd;
@@ -1014,16 +1035,13 @@ inline void reverse_pass_collect_adjoints(var ret, Output&& output,
10141035
* \msg_arg
10151036
* @return the log marginal density, p(y | phi)
10161037
*/
1017-
template <
1018-
typename LLFun, typename LLTupleArgs, typename CovarFun, typename ThetaVec,
1019-
typename CovarArgs,
1020-
require_t<is_any_var_scalar<ThetaVec, LLTupleArgs, CovarArgs>>* = nullptr,
1021-
require_eigen_vector_t<ThetaVec>* = nullptr>
1038+
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
1039+
typename CovarArgs, bool InitTheta,
1040+
require_t<is_any_var_scalar<LLTupleArgs, CovarArgs>>* = nullptr>
10221041
inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
1023-
ThetaVec&& theta_0,
10241042
CovarFun&& covariance_function,
10251043
CovarArgs&& covar_args,
1026-
const laplace_options& options,
1044+
const laplace_options<InitTheta>& options,
10271045
std::ostream* msgs) {
10281046
auto covar_args_refs = to_ref(std::forward<CovarArgs>(covar_args));
10291047
auto ll_args_refs = to_ref(std::forward<LLTupleArgs>(ll_args));
@@ -1034,13 +1052,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10341052
double lmd = 0.0;
10351053
{
10361054
nested_rev_autodiff nested;
1037-
// Solver 1, 2
1038-
arena_t<Eigen::MatrixXd> R(theta_0.size(), theta_0.size());
1039-
// Solver 3
1040-
arena_t<Eigen::MatrixXd> LU_solve_covariance;
1041-
// Solver 1, 2, 3
1042-
arena_t<promote_scalar_t<double, plain_type_t<std::decay_t<ThetaVec>>>> s2(
1043-
theta_0.size());
1055+
10441056
// Make one hard copy here
10451057
using laplace_likelihood::internal::conditional_copy_and_promote;
10461058
using laplace_likelihood::internal::COPY_TYPE;
@@ -1049,8 +1061,16 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10491061
ll_args_refs);
10501062

10511063
auto md_est = internal::laplace_marginal_density_est(
1052-
ll_fun, ll_args_copy, value_of(theta_0), covariance_function,
1053-
value_of(covar_args_refs), options, msgs);
1064+
ll_fun, ll_args_copy, covariance_function, value_of(covar_args_refs),
1065+
options, msgs);
1066+
1067+
// Solver 1, 2
1068+
arena_t<Eigen::MatrixXd> R(md_est.theta.size(), md_est.theta.size());
1069+
// Solver 3
1070+
arena_t<Eigen::MatrixXd> LU_solve_covariance;
1071+
// Solver 1, 2, 3
1072+
arena_t<Eigen::VectorXd> s2(md_est.theta.size());
1073+
10541074
// Return references to var types
10551075
auto ll_args_filter = internal::filter_var_scalar_types(ll_args_copy);
10561076
stan::math::for_each(

stan/math/mix/prob/laplace_latent_bernoulli_logit_rng.hpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ namespace math {
1717
* return a multivariate normal random variate sampled
1818
* from the gaussian approximation of p(theta | y, phi),
1919
* where the likelihood is a Bernoulli with logit link.
20+
* @tparam ThetaVec A type inheriting from `Eigen::EigenBase`
21+
* with dynamic sized rows and 1 column.
2022
* \laplace_common_template_args
2123
* @tparam RNG A valid boost rng type
2224
* @param[in] y Vector of total number of trials with a positive outcome.
@@ -30,15 +32,15 @@ template <typename ThetaVec, typename CovarFun, typename CovarArgs,
3032
typename RNG, require_eigen_t<ThetaVec>* = nullptr>
3133
inline Eigen::VectorXd laplace_latent_tol_bernoulli_logit_rng(
3234
const std::vector<int>& y, const std::vector<int>& n_samples,
33-
ThetaVec&& theta_0, CovarFun&& covariance_function, CovarArgs&& covar_args,
35+
CovarFun&& covariance_function, CovarArgs&& covar_args, ThetaVec&& theta_0,
3436
const double tolerance, const int max_num_steps,
3537
const int hessian_block_size, const int solver,
3638
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
37-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
38-
tolerance, max_num_steps};
39+
laplace_options_user_supplied ops{hessian_block_size, solver,
40+
max_steps_line_search, tolerance,
41+
max_num_steps, value_of(theta_0)};
3942
return laplace_base_rng(bernoulli_logit_likelihood{},
4043
std::forward_as_tuple(to_vector(y), n_samples),
41-
std::forward<ThetaVec>(theta_0),
4244
std::forward<CovarFun>(covariance_function),
4345
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
4446
}
@@ -60,18 +62,16 @@ inline Eigen::VectorXd laplace_latent_tol_bernoulli_logit_rng(
6062
* \rng_arg
6163
* \msg_arg
6264
*/
63-
template <typename CovarFun, typename ThetaVec, typename CovarArgs,
64-
typename RNG, require_eigen_t<ThetaVec>* = nullptr>
65+
template <typename CovarFun, typename CovarArgs, typename RNG>
6566
inline Eigen::VectorXd laplace_latent_bernoulli_logit_rng(
6667
const std::vector<int>& y, const std::vector<int>& n_samples,
67-
ThetaVec&& theta_0, CovarFun&& covariance_function, CovarArgs&& covar_args,
68-
RNG& rng, std::ostream* msgs) {
69-
constexpr laplace_options ops{1, 1, 0, 1e-6, 100};
68+
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
69+
std::ostream* msgs) {
7070
return laplace_base_rng(bernoulli_logit_likelihood{},
7171
std::forward_as_tuple(to_vector(y), n_samples),
72-
std::forward<ThetaVec>(theta_0),
7372
std::forward<CovarFun>(covariance_function),
74-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
73+
std::forward<CovarArgs>(covar_args),
74+
laplace_options_default{}, rng, msgs);
7575
}
7676

7777
} // namespace math

stan/math/mix/prob/laplace_latent_neg_binomial_2_log_rng.hpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace math {
2121
* parameterization of the Negative Binomial.
2222
*
2323
* @tparam Eta A type for the overdispersion parameter.
24+
* @tparam ThetaVec A type inheriting from `Eigen::EigenBase`
25+
* with dynamic sized rows and 1 column.
2426
* \laplace_common_template_args
2527
* @tparam RNG A valid boost rng type
2628
* @param[in] y Observed counts.
@@ -36,16 +38,16 @@ template <typename Eta, typename ThetaVec, typename CovarFun,
3638
require_eigen_t<ThetaVec>* = nullptr>
3739
inline Eigen::VectorXd laplace_latent_tol_neg_binomial_2_log_rng(
3840
const std::vector<int>& y, const std::vector<int>& y_index, Eta&& eta,
39-
ThetaVec&& theta_0, CovarFun&& covariance_function, CovarArgs&& covar_args,
41+
CovarFun&& covariance_function, CovarArgs&& covar_args, ThetaVec&& theta_0,
4042
const double tolerance, const int max_num_steps,
4143
const int hessian_block_size, const int solver,
4244
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
43-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
44-
tolerance, max_num_steps};
45+
laplace_options_user_supplied ops{hessian_block_size, solver,
46+
max_steps_line_search, tolerance,
47+
max_num_steps, value_of(theta_0)};
4548
return laplace_base_rng(
4649
neg_binomial_2_log_likelihood{},
4750
std::forward_as_tuple(std::forward<Eta>(eta), y, y_index),
48-
std::forward<ThetaVec>(theta_0),
4951
std::forward<CovarFun>(covariance_function),
5052
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
5153
}
@@ -72,20 +74,17 @@ inline Eigen::VectorXd laplace_latent_tol_neg_binomial_2_log_rng(
7274
* \rng_arg
7375
* \msg_arg
7476
*/
75-
template <typename Eta, typename ThetaVec, typename CovarFun,
76-
typename CovarArgs, typename RNG,
77-
require_eigen_t<ThetaVec>* = nullptr>
77+
template <typename Eta, typename CovarFun, typename CovarArgs, typename RNG>
7878
inline Eigen::VectorXd laplace_latent_neg_binomial_2_log_rng(
7979
const std::vector<int>& y, const std::vector<int>& y_index, Eta&& eta,
80-
ThetaVec&& theta_0, CovarFun&& covariance_function, CovarArgs&& covar_args,
81-
RNG& rng, std::ostream* msgs) {
82-
constexpr laplace_options ops{1, 1, 0, 1e-6, 100};
80+
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
81+
std::ostream* msgs) {
8382
return laplace_base_rng(
8483
neg_binomial_2_log_likelihood{},
8584
std::forward_as_tuple(std::forward<Eta>(eta), y, y_index),
86-
std::forward<ThetaVec>(theta_0),
8785
std::forward<CovarFun>(covariance_function),
88-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
86+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng,
87+
msgs);
8988
}
9089

9190
} // namespace math

0 commit comments

Comments
 (0)