12
12
#include < stan/math/prim/functor/iter_tuple_nested.hpp>
13
13
#include < unsupported/Eigen/MatrixFunctions>
14
14
#include < cmath>
15
+ #include < optional>
15
16
16
17
/* *
17
18
* @file
@@ -26,7 +27,7 @@ namespace math {
26
27
/* *
27
28
* Options for the laplace sampler
28
29
*/
29
- struct laplace_options {
30
+ struct laplace_options_base {
30
31
/* Size of the blocks in block diagonal hessian*/
31
32
int hessian_block_size{1 };
32
33
/* *
@@ -46,6 +47,20 @@ struct laplace_options {
46
47
int max_num_steps{100 };
47
48
};
48
49
50
+ template <bool HasInitTheta>
51
+ struct laplace_options ;
52
+
53
+ template <>
54
+ struct laplace_options <false > : public laplace_options_base {};
55
+
56
+ template <>
57
+ struct laplace_options <true > : public laplace_options_base {
58
+ /* Value for user supplied initial theta */
59
+ Eigen::VectorXd theta_0{0 };
60
+ };
61
+
62
+ using laplace_options_default = laplace_options<false >;
63
+ using laplace_options_user_supplied = laplace_options<true >;
49
64
namespace internal {
50
65
51
66
template <typename Covar, typename ThetaVec, typename WR, typename L_t,
@@ -448,37 +463,46 @@ inline STAN_COLD_PATH void throw_nan(NameStr&& name_str, ParamStr&& param_str,
448
463
*
449
464
*/
450
465
template <typename LLFun, typename LLTupleArgs, typename CovarFun,
451
- typename ThetaVec, typename CovarArgs,
452
- require_t <is_all_arithmetic_scalar<ThetaVec, CovarArgs>>* = nullptr ,
453
- require_eigen_vector_t <ThetaVec>* = nullptr >
454
- inline auto laplace_marginal_density_est (LLFun&& ll_fun, LLTupleArgs&& ll_args,
455
- ThetaVec&& theta_0,
456
- CovarFun&& covariance_function,
457
- CovarArgs&& covar_args,
458
- const laplace_options& options,
459
- std::ostream* msgs) {
466
+ typename CovarArgs, bool InitTheta,
467
+ require_t <is_all_arithmetic_scalar<CovarArgs>>* = nullptr >
468
+ inline auto laplace_marginal_density_est (
469
+ LLFun&& ll_fun, LLTupleArgs&& ll_args, CovarFun&& covariance_function,
470
+ CovarArgs&& covar_args, const laplace_options<InitTheta>& options,
471
+ std::ostream* msgs) {
460
472
using Eigen::MatrixXd;
461
473
using Eigen::SparseMatrix;
462
474
using Eigen::VectorXd;
463
- check_nonzero_size (" laplace_marginal" , " initial guess" , theta_0);
464
- check_finite (" laplace_marginal" , " initial guess" , theta_0);
475
+ if constexpr (InitTheta) {
476
+ check_nonzero_size (" laplace_marginal" , " initial guess" , options.theta_0 );
477
+ check_finite (" laplace_marginal" , " initial guess" , options.theta_0 );
478
+ }
465
479
check_nonnegative (" laplace_marginal" , " tolerance" , options.tolerance );
466
480
check_positive (" laplace_marginal" , " max_num_steps" , options.max_num_steps );
467
481
check_positive (" laplace_marginal" , " hessian_block_size" ,
468
482
options.hessian_block_size );
469
483
check_nonnegative (" laplace_marginal" , " max_steps_line_search" ,
470
484
options.max_steps_line_search );
471
- if (unlikely (theta_0.size () % options.hessian_block_size != 0 )) {
485
+
486
+ Eigen::MatrixXd covariance = stan::math::apply (
487
+ [msgs, &covariance_function](auto &&... args) {
488
+ return covariance_function (args..., msgs);
489
+ },
490
+ covar_args);
491
+ check_square (" laplace_marginal" , " covariance" , covariance);
492
+
493
+ const Eigen::Index theta_size = covariance.rows ();
494
+
495
+ if (unlikely (theta_size % options.hessian_block_size != 0 )) {
472
496
[&]() STAN_COLD_PATH {
473
497
std::stringstream msg;
474
- msg << " laplace_marginal_density: The hessian size (" << theta_0. size ()
475
- << " , " << theta_0. size ()
498
+ msg << " laplace_marginal_density: The hessian size (" << theta_size
499
+ << " , " << theta_size
476
500
<< " ) is not divisible by the hessian block size ("
477
501
<< options.hessian_block_size
478
502
<< " )"
479
503
" . Try a hessian block size such as [1, " ;
480
504
for (int i = 2 ; i < 12 ; ++i) {
481
- if (theta_0. size () % i == 0 ) {
505
+ if (theta_size % i == 0 ) {
482
506
msg << i << " , " ;
483
507
}
484
508
}
@@ -488,19 +512,20 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
488
512
throw std::domain_error (msg.str ());
489
513
}();
490
514
}
491
- Eigen::MatrixXd covariance = stan::math::apply (
492
- [msgs, &covariance_function](auto &&... args) {
493
- return covariance_function (args..., msgs);
494
- },
495
- covar_args);
515
+
496
516
auto throw_overstep = [](const auto max_num_steps) STAN_COLD_PATH {
497
517
throw std::domain_error (
498
518
std::string (" laplace_marginal_density: max number of iterations: " )
499
519
+ std::to_string (max_num_steps) + " exceeded." );
500
520
};
501
521
auto ll_args_vals = value_of (ll_args);
502
- const Eigen::Index theta_size = theta_0.size ();
503
- Eigen::VectorXd theta = std::forward<ThetaVec>(theta_0);
522
+ Eigen::VectorXd theta = [theta_size, &options]() {
523
+ if constexpr (InitTheta) {
524
+ return options.theta_0 ;
525
+ } else {
526
+ return Eigen::VectorXd::Zero (theta_size);
527
+ }
528
+ }();
504
529
double objective_old = std::numeric_limits<double >::lowest ();
505
530
double objective_new = std::numeric_limits<double >::lowest () + 1 ;
506
531
Eigen::VectorXd a_prev = Eigen::VectorXd::Zero (theta_size);
@@ -572,7 +597,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
572
597
}
573
598
}
574
599
} else {
575
- Eigen::SparseMatrix<double > W_r (theta. rows (), theta. rows () );
600
+ Eigen::SparseMatrix<double > W_r (theta_size, theta_size );
576
601
Eigen::Index block_size = options.hessian_block_size ;
577
602
W_r.reserve (Eigen::VectorXi::Constant (W_r.cols (), block_size));
578
603
const Eigen::Index n_block = W_r.cols () / block_size;
@@ -768,20 +793,16 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
768
793
* \msg_arg
769
794
* @return the log maginal density, p(y | phi)
770
795
*/
771
- template <typename LLFun, typename LLTupleArgs, typename CovarFun,
772
- typename ThetaVec, typename CovarArgs,
773
- require_t <is_all_arithmetic_scalar<ThetaVec, CovarArgs,
774
- LLTupleArgs>>* = nullptr ,
775
- require_eigen_vector_t <ThetaVec>* = nullptr >
776
- inline double laplace_marginal_density (LLFun&& ll_fun, LLTupleArgs&& ll_args,
777
- ThetaVec&& theta_0,
778
- CovarFun&& covariance_function,
779
- CovarArgs&& covar_args,
780
- const laplace_options& options,
781
- std::ostream* msgs) {
796
+ template <
797
+ typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs,
798
+ bool InitTheta,
799
+ require_t <is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr >
800
+ inline double laplace_marginal_density (
801
+ LLFun&& ll_fun, LLTupleArgs&& ll_args, CovarFun&& covariance_function,
802
+ CovarArgs&& covar_args, const laplace_options<InitTheta>& options,
803
+ std::ostream* msgs) {
782
804
return internal::laplace_marginal_density_est (
783
805
std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
784
- std::forward<ThetaVec>(theta_0),
785
806
std::forward<CovarFun>(covariance_function),
786
807
std::forward<CovarArgs>(covar_args), options, msgs)
787
808
.lmd ;
@@ -1014,16 +1035,13 @@ inline void reverse_pass_collect_adjoints(var ret, Output&& output,
1014
1035
* \msg_arg
1015
1036
* @return the log maginal density, p(y | phi)
1016
1037
*/
1017
- template <
1018
- typename LLFun, typename LLTupleArgs, typename CovarFun, typename ThetaVec,
1019
- typename CovarArgs,
1020
- require_t <is_any_var_scalar<ThetaVec, LLTupleArgs, CovarArgs>>* = nullptr ,
1021
- require_eigen_vector_t <ThetaVec>* = nullptr >
1038
+ template <typename LLFun, typename LLTupleArgs, typename CovarFun,
1039
+ typename CovarArgs, bool InitTheta,
1040
+ require_t <is_any_var_scalar<LLTupleArgs, CovarArgs>>* = nullptr >
1022
1041
inline auto laplace_marginal_density (const LLFun& ll_fun, LLTupleArgs&& ll_args,
1023
- ThetaVec&& theta_0,
1024
1042
CovarFun&& covariance_function,
1025
1043
CovarArgs&& covar_args,
1026
- const laplace_options& options,
1044
+ const laplace_options<InitTheta> & options,
1027
1045
std::ostream* msgs) {
1028
1046
auto covar_args_refs = to_ref (std::forward<CovarArgs>(covar_args));
1029
1047
auto ll_args_refs = to_ref (std::forward<LLTupleArgs>(ll_args));
@@ -1034,13 +1052,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
1034
1052
double lmd = 0.0 ;
1035
1053
{
1036
1054
nested_rev_autodiff nested;
1037
- // Solver 1, 2
1038
- arena_t <Eigen::MatrixXd> R (theta_0.size (), theta_0.size ());
1039
- // Solver 3
1040
- arena_t <Eigen::MatrixXd> LU_solve_covariance;
1041
- // Solver 1, 2, 3
1042
- arena_t <promote_scalar_t <double , plain_type_t <std::decay_t <ThetaVec>>>> s2 (
1043
- theta_0.size ());
1055
+
1044
1056
// Make one hard copy here
1045
1057
using laplace_likelihood::internal::conditional_copy_and_promote;
1046
1058
using laplace_likelihood::internal::COPY_TYPE;
@@ -1049,8 +1061,16 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
1049
1061
ll_args_refs);
1050
1062
1051
1063
auto md_est = internal::laplace_marginal_density_est (
1052
- ll_fun, ll_args_copy, value_of (theta_0), covariance_function,
1053
- value_of (covar_args_refs), options, msgs);
1064
+ ll_fun, ll_args_copy, covariance_function, value_of (covar_args_refs),
1065
+ options, msgs);
1066
+
1067
+ // Solver 1, 2
1068
+ arena_t <Eigen::MatrixXd> R (md_est.theta .size (), md_est.theta .size ());
1069
+ // Solver 3
1070
+ arena_t <Eigen::MatrixXd> LU_solve_covariance;
1071
+ // Solver 1, 2, 3
1072
+ arena_t <Eigen::VectorXd> s2 (md_est.theta .size ());
1073
+
1054
1074
// Return references to var types
1055
1075
auto ll_args_filter = internal::filter_var_scalar_types (ll_args_copy);
1056
1076
stan::math::for_each (
0 commit comments