Skip to content

Commit 94ab922

Browse files
authored
Merge pull request #3213 from stan-dev/laplace/compile-time-theta-default
Compile Time Theta Default
2 parents 699fb35 + 84f46d6 commit 94ab922

12 files changed

+56
-52
lines changed

stan/math/mix/functor/laplace_base_rng.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ namespace math {
3232
* \msg_arg
3333
*/
3434
template <typename LLFunc, typename LLArgs, typename CovarFun,
35-
typename CovarArgs, typename RNG,
35+
typename CovarArgs, bool InitTheta, typename RNG,
3636
require_t<is_all_arithmetic_scalar<CovarArgs, LLArgs>>* = nullptr>
3737
inline Eigen::VectorXd laplace_base_rng(LLFunc&& ll_fun, LLArgs&& ll_args,
3838
CovarFun&& covariance_function,
3939
CovarArgs&& covar_args,
40-
const laplace_options& options,
40+
const laplace_options<InitTheta>& options,
4141
RNG& rng, std::ostream* msgs) {
4242
auto md_est = internal::laplace_marginal_density_est(
4343
ll_fun, std::forward<LLArgs>(ll_args),

stan/math/mix/functor/laplace_marginal_density.hpp

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace math {
2727
/**
2828
* Options for the laplace sampler
2929
*/
30-
struct laplace_options {
30+
struct laplace_options_base {
3131
/* Size of the blocks in block diagonal hessian*/
3232
int hessian_block_size{1};
3333
/**
@@ -45,11 +45,23 @@ struct laplace_options {
4545
double tolerance{1e-6};
4646
/* Maximum number of steps*/
4747
int max_num_steps{100};
48+
};
49+
50+
template <bool HasInitTheta>
51+
struct laplace_options;
4852

49-
/* Initial value for theta. Defaults to 0s of the correct size if nullopt */
50-
std::optional<Eigen::VectorXd> theta_0{std::nullopt};
53+
template <>
54+
struct laplace_options<false> : public laplace_options_base {};
55+
56+
template <>
57+
struct laplace_options<true> : public laplace_options_base {
58+
/* Value for user supplied initial theta */
59+
Eigen::VectorXd theta_0{0};
5160
};
5261

62+
63+
using laplace_options_default = laplace_options<false>;
64+
using laplace_options_user_supplied = laplace_options<true>;
5365
namespace internal {
5466

5567
template <typename Covar, typename ThetaVec, typename WR, typename L_t,
@@ -452,21 +464,20 @@ inline STAN_COLD_PATH void throw_nan(NameStr&& name_str, ParamStr&& param_str,
452464
*
453465
*/
454466
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
455-
typename CovarArgs,
467+
typename CovarArgs, bool InitTheta,
456468
require_t<is_all_arithmetic_scalar<CovarArgs>>* = nullptr>
457469
inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
458470
CovarFun&& covariance_function,
459471
CovarArgs&& covar_args,
460-
const laplace_options& options,
472+
const laplace_options<InitTheta>& options,
461473
std::ostream* msgs) {
462474
using Eigen::MatrixXd;
463475
using Eigen::SparseMatrix;
464476
using Eigen::VectorXd;
465-
if (options.theta_0.has_value()) {
466-
check_nonzero_size("laplace_marginal", "initial guess", *options.theta_0);
467-
check_finite("laplace_marginal", "initial guess", *options.theta_0);
477+
if constexpr (InitTheta) {
478+
check_nonzero_size("laplace_marginal", "initial guess", options.theta_0);
479+
check_finite("laplace_marginal", "initial guess", options.theta_0);
468480
}
469-
470481
check_nonnegative("laplace_marginal", "tolerance", options.tolerance);
471482
check_positive("laplace_marginal", "max_num_steps", options.max_num_steps);
472483
check_positive("laplace_marginal", "hessian_block_size",
@@ -510,9 +521,13 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
510521
+ std::to_string(max_num_steps) + " exceeded.");
511522
};
512523
auto ll_args_vals = value_of(ll_args);
513-
Eigen::VectorXd theta = options.theta_0.has_value()
514-
? *options.theta_0
515-
: Eigen::VectorXd::Zero(theta_size);
524+
Eigen::VectorXd theta = [theta_size, &options]() {
525+
if constexpr (InitTheta) {
526+
return options.theta_0;
527+
} else {
528+
return Eigen::VectorXd::Zero(theta_size);
529+
}
530+
}();
516531
double objective_old = std::numeric_limits<double>::lowest();
517532
double objective_new = std::numeric_limits<double>::lowest() + 1;
518533
Eigen::VectorXd a_prev = Eigen::VectorXd::Zero(theta_size);
@@ -584,7 +599,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
584599
}
585600
}
586601
} else {
587-
Eigen::SparseMatrix<double> W_r(theta.rows(), theta.rows());
602+
Eigen::SparseMatrix<double> W_r(theta_size, theta_size);
588603
Eigen::Index block_size = options.hessian_block_size;
589604
W_r.reserve(Eigen::VectorXi::Constant(W_r.cols(), block_size));
590605
const Eigen::Index n_block = W_r.cols() / block_size;
@@ -781,12 +796,12 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
781796
* @return the log maginal density, p(y | phi)
782797
*/
783798
template <
784-
typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs,
799+
typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs, bool InitTheta,
785800
require_t<is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr>
786801
inline double laplace_marginal_density(LLFun&& ll_fun, LLTupleArgs&& ll_args,
787802
CovarFun&& covariance_function,
788803
CovarArgs&& covar_args,
789-
const laplace_options& options,
804+
const laplace_options<InitTheta>& options,
790805
std::ostream* msgs) {
791806
return internal::laplace_marginal_density_est(
792807
std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
@@ -1023,12 +1038,12 @@ inline void reverse_pass_collect_adjoints(var ret, Output&& output,
10231038
* @return the log maginal density, p(y | phi)
10241039
*/
10251040
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
1026-
typename CovarArgs,
1041+
typename CovarArgs, bool InitTheta,
10271042
require_t<is_any_var_scalar<LLTupleArgs, CovarArgs>>* = nullptr>
10281043
inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10291044
CovarFun&& covariance_function,
10301045
CovarArgs&& covar_args,
1031-
const laplace_options& options,
1046+
const laplace_options<InitTheta>& options,
10321047
std::ostream* msgs) {
10331048
auto covar_args_refs = to_ref(std::forward<CovarArgs>(covar_args));
10341049
auto ll_args_refs = to_ref(std::forward<LLTupleArgs>(ll_args));

stan/math/mix/prob/laplace_latent_bernoulli_logit_rng.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline Eigen::VectorXd laplace_latent_tol_bernoulli_logit_rng(
3636
const double tolerance, const int max_num_steps,
3737
const int hessian_block_size, const int solver,
3838
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
39-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
39+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4040
tolerance, max_num_steps, value_of(theta_0)};
4141
return laplace_base_rng(bernoulli_logit_likelihood{},
4242
std::forward_as_tuple(to_vector(y), n_samples),
@@ -66,11 +66,10 @@ inline Eigen::VectorXd laplace_latent_bernoulli_logit_rng(
6666
const std::vector<int>& y, const std::vector<int>& n_samples,
6767
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
6868
std::ostream* msgs) {
69-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
7069
return laplace_base_rng(bernoulli_logit_likelihood{},
7170
std::forward_as_tuple(to_vector(y), n_samples),
7271
std::forward<CovarFun>(covariance_function),
73-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
72+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng, msgs);
7473
}
7574

7675
} // namespace math

stan/math/mix/prob/laplace_latent_neg_binomial_2_log_rng.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ inline Eigen::VectorXd laplace_latent_tol_neg_binomial_2_log_rng(
4242
const double tolerance, const int max_num_steps,
4343
const int hessian_block_size, const int solver,
4444
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
45-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
45+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4646
tolerance, max_num_steps, value_of(theta_0)};
4747
return laplace_base_rng(
4848
neg_binomial_2_log_likelihood{},
@@ -78,12 +78,11 @@ inline Eigen::VectorXd laplace_latent_neg_binomial_2_log_rng(
7878
const std::vector<int>& y, const std::vector<int>& y_index, Eta&& eta,
7979
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
8080
std::ostream* msgs) {
81-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
8281
return laplace_base_rng(
8382
neg_binomial_2_log_likelihood{},
8483
std::forward_as_tuple(std::forward<Eta>(eta), y, y_index),
8584
std::forward<CovarFun>(covariance_function),
86-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
85+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng, msgs);
8786
}
8887

8988
} // namespace math

stan/math/mix/prob/laplace_latent_poisson_log_2_rng.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ inline auto laplace_latent_tol_poisson_2_log_rng(
3838
const double tolerance, const int max_num_steps,
3939
const int hessian_block_size, const int solver,
4040
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
41-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
41+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4242
tolerance, max_num_steps, value_of(theta_0)};
4343
return laplace_base_rng(poisson_log_2_likelihood{},
4444
std::forward_as_tuple(y, y_index, ye),
@@ -71,11 +71,10 @@ inline auto laplace_latent_poisson_2_log_rng(const std::vector<int>& y,
7171
CovarFun&& covariance_function,
7272
CovarArgs&& covar_args, RNG& rng,
7373
std::ostream* msgs) {
74-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
7574
return laplace_base_rng(poisson_log_2_likelihood{},
7675
std::forward_as_tuple(y, y_index, ye),
7776
std::forward<CovarFun>(covariance_function),
78-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
77+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng, msgs);
7978
}
8079

8180
} // namespace math

stan/math/mix/prob/laplace_latent_poisson_log_rng.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline Eigen::VectorXd laplace_latent_tol_poisson_log_rng(
3636
const double tolerance, const int max_num_steps,
3737
const int hessian_block_size, const int solver,
3838
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
39-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
39+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
4040
tolerance, max_num_steps, value_of(theta_0)};
4141
return laplace_base_rng(poisson_log_likelihood{},
4242
std::forward_as_tuple(y, y_index),
@@ -67,11 +67,10 @@ inline Eigen::VectorXd laplace_latent_poisson_log_rng(
6767
const std::vector<int>& y, const std::vector<int>& y_index,
6868
CovarFun&& covariance_function, CovarArgs&& covar_args, RNG& rng,
6969
std::ostream* msgs) {
70-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
7170
return laplace_base_rng(poisson_log_likelihood{},
7271
std::forward_as_tuple(y, y_index),
7372
std::forward<CovarFun>(covariance_function),
74-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
73+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng, msgs);
7574
}
7675

7776
} // namespace math

stan/math/mix/prob/laplace_latent_rng.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ inline auto laplace_latent_tol_rng(
3636
CovarArgs&& covar_args, ThetaVec&& theta_0, const double tolerance,
3737
const int max_num_steps, const int hessian_block_size, const int solver,
3838
const int max_steps_line_search, RNG& rng, std::ostream* msgs) {
39-
const laplace_options ops{hessian_block_size, solver,
39+
const laplace_options_user_supplied ops{hessian_block_size, solver,
4040
max_steps_line_search, tolerance,
4141
max_num_steps, value_of(theta_0)};
4242
return laplace_base_rng(std::forward<LLFunc>(L_f),
@@ -69,11 +69,10 @@ inline auto laplace_latent_rng(LLFunc&& L_f, LLArgs&& ll_args,
6969
CovarFun&& covariance_function,
7070
CovarArgs&& covar_args, RNG& rng,
7171
std::ostream* msgs) {
72-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
7372
return laplace_base_rng(std::forward<LLFunc>(L_f),
7473
std::forward<LLArgs>(ll_args),
7574
std::forward<CovarFun>(covariance_function),
76-
std::forward<CovarArgs>(covar_args), ops, rng, msgs);
75+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, rng, msgs);
7776
}
7877

7978
} // namespace math

stan/math/mix/prob/laplace_marginal.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ inline auto laplace_marginal_tol(
3333
CovarArgs&& covar_args, const ThetaVec& theta_0, double tolerance,
3434
int max_num_steps, const int hessian_block_size, const int solver,
3535
const int max_steps_line_search, std::ostream* msgs) {
36-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
36+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
3737
tolerance, max_num_steps, value_of(theta_0)};
3838
return laplace_marginal_density(
3939
std::forward<LFun>(L_f), std::forward<LArgs>(l_args),
@@ -63,11 +63,10 @@ template <bool propto = false, typename LFun, typename LArgs, typename CovarFun,
6363
inline auto laplace_marginal(LFun&& L_f, LArgs&& l_args,
6464
CovarFun&& covariance_function,
6565
CovarArgs&& covar_args, std::ostream* msgs) {
66-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
6766
return laplace_marginal_density(
6867
std::forward<LFun>(L_f), std::forward<LArgs>(l_args),
6968
std::forward<CovarFun>(covariance_function),
70-
std::forward<CovarArgs>(covar_args), ops, msgs);
69+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, msgs);
7170
}
7271

7372
} // namespace math

stan/math/mix/prob/laplace_marginal_bernoulli_logit_lpmf.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ inline auto laplace_marginal_tol_bernoulli_logit_lpmf(
5555
const ThetaVec& theta_0, double tolerance, int max_num_steps,
5656
const int hessian_block_size, const int solver,
5757
const int max_steps_line_search, std::ostream* msgs) {
58-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
58+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
5959
tolerance, max_num_steps, value_of(theta_0)};
6060
return laplace_marginal_density(
6161
bernoulli_logit_likelihood{},
@@ -84,12 +84,11 @@ inline auto laplace_marginal_bernoulli_logit_lpmf(
8484
const std::vector<int>& y, const std::vector<int>& n_samples,
8585
CovarFun&& covariance_function, CovarArgs&& covar_args,
8686
std::ostream* msgs) {
87-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
8887
return laplace_marginal_density(
8988
bernoulli_logit_likelihood{},
9089
std::forward_as_tuple(to_vector(y), n_samples),
9190
std::forward<CovarFun>(covariance_function),
92-
std::forward<CovarArgs>(covar_args), ops, msgs);
91+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, msgs);
9392
}
9493

9594
} // namespace math

stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ inline auto laplace_marginal_tol_neg_binomial_2_log_lpmf(
7373
const ThetaVec& theta_0, double tolerance, int max_num_steps,
7474
const int hessian_block_size, const int solver,
7575
const int max_steps_line_search, std::ostream* msgs) {
76-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
76+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
7777
tolerance, max_num_steps, value_of(theta_0)};
7878
return laplace_marginal_density(
7979
neg_binomial_2_log_likelihood{}, std::forward_as_tuple(eta, y, y_index),
@@ -103,11 +103,10 @@ inline auto laplace_marginal_neg_binomial_2_log_lpmf(
103103
const std::vector<int>& y, const std::vector<int>& y_index, const Eta& eta,
104104
CovarFun&& covariance_function, CovarArgs&& covar_args,
105105
std::ostream* msgs) {
106-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
107106
return laplace_marginal_density(
108107
neg_binomial_2_log_likelihood{}, std::forward_as_tuple(eta, y, y_index),
109108
std::forward<CovarFun>(covariance_function),
110-
std::forward<CovarArgs>(covar_args), ops, msgs);
109+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, msgs);
111110
}
112111

113112
struct neg_binomial_2_log_likelihood_summary {
@@ -159,7 +158,7 @@ inline auto laplace_marginal_tol_neg_binomial_2_log_summary_lpmf(
159158
const ThetaVec& theta_0, double tolerance, int max_num_steps,
160159
const int hessian_block_size, const int solver,
161160
const int max_steps_line_search, std::ostream* msgs) {
162-
laplace_options ops{hessian_block_size, solver, max_steps_line_search,
161+
laplace_options_user_supplied ops{hessian_block_size, solver, max_steps_line_search,
163162
tolerance, max_num_steps, value_of(theta_0)};
164163
return laplace_marginal_density(
165164
neg_binomial_2_log_likelihood_summary{},
@@ -191,12 +190,11 @@ inline auto laplace_marginal_neg_binomial_2_log_summary_lpmf(
191190
const std::vector<int>& counts_per_group, const Eta& eta,
192191
CovarFun&& covariance_function, CovarArgs&& covar_args,
193192
std::ostream* msgs) {
194-
const laplace_options ops{1, 1, 0, 1e-6, 100, std::nullopt};
195193
return laplace_marginal_density(
196194
neg_binomial_2_log_likelihood_summary{},
197195
std::forward_as_tuple(eta, y, n_per_group, counts_per_group),
198196
std::forward<CovarFun>(covariance_function),
199-
std::forward<CovarArgs>(covar_args), ops, msgs);
197+
std::forward<CovarArgs>(covar_args), laplace_options_default{}, msgs);
200198
}
201199

202200
} // namespace math

0 commit comments

Comments
 (0)