Skip to content

Commit f748825

Browse files
committed
update docs for new vari for csr_matrix_times_vector
1 parent a3a88a5 commit f748825

File tree

1 file changed

+57
-5
lines changed

1 file changed

+57
-5
lines changed

stan/math/rev/fun/csr_matrix_times_vector.hpp

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,23 @@ namespace stan {
1212
namespace math {
1313

1414
namespace internal {
15+
/**
16+
* `vari` for csr_matrix_times_vector
17+
* @note `csr_matrix_times_vector` uses the old inheritance
18+
* style to set up the reverse pass because of a linking
19+
* issue on windows when using flto.
20+
*
21+
* @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
22+
* @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
23+
* @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
24+
*
25+
*/
1526
template <typename Result_, typename WMat_, typename B_>
1627
struct csr_adjoint : public vari {
1728
std::decay_t<Result_> res_;
1829
std::decay_t<WMat_> w_mat_;
1930
std::decay_t<B_> b_;
31+
2032
template <typename T1, typename T2, typename T3>
2133
csr_adjoint(T1&& res, T2&& w_mat, T3&& b)
2234
: vari(0.0),
@@ -25,33 +37,73 @@ struct csr_adjoint : public vari {
2537
b_(std::forward<T3>(b)) {}
2638

2739
void chain() { chain_internal(res_, w_mat_, b_); }
40+
41+
/**
42+
* Overload for calculating adjoints of `w_mat` and `b`
43+
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
44+
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
45+
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
46+
* @param res The vector result of the forward pass calculation
47+
* @param w_mat A sparse matrix
48+
* @param b A vector
49+
*/
2850
template <typename Result, typename WMat, typename B,
2951
require_rev_matrix_t<WMat>* = nullptr,
3052
require_rev_matrix_t<B>* = nullptr>
31-
void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
53+
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
3254
w_mat.adj() += res.adj() * b.val().transpose();
3355
b.adj() += w_mat.val().transpose() * res.adj();
3456
}
3557

58+
/**
59+
* Overload for calculating adjoints of `w_mat`
60+
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
61+
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
62+
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `double`
63+
* @param res The vector result of the forward pass calculation
64+
* @param w_mat A sparse matrix
65+
* @param b A vector
66+
*/
3667
template <typename Result, typename WMat, typename B,
3768
require_rev_matrix_t<WMat>* = nullptr,
3869
require_not_rev_matrix_t<B>* = nullptr>
39-
void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
70+
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
4071
w_mat.adj() += res.adj() * b.transpose();
4172
}
4273

74+
/**
75+
* Overload for calculating adjoints of `b`
76+
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
77+
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar type `double`
78+
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
79+
* @param res The vector result of the forward pass calculation
80+
* @param w_mat A sparse matrix
81+
* @param b A vector
82+
*/
4383
template <typename Result, typename WMat, typename B,
4484
require_not_rev_matrix_t<WMat>* = nullptr,
4585
require_rev_matrix_t<B>* = nullptr>
46-
void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
86+
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
4787
b.adj() += w_mat.transpose() * res.adj();
4888
}
4989
};
90+
91+
/**
92+
* Helper function to construct the csr_adjoint struct.
93+
* @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
94+
* @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
95+
* @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
96+
*
97+
* @param res The vector result of the forward pass calculation
98+
* @param w_mat A sparse matrix
99+
* @param b A vector
100+
*/
50101
template <typename Result_, typename WMat_, typename B_>
51-
inline vari* make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) {
52-
return new csr_adjoint<Result_, WMat_, B_>(std::forward<Result_>(res),
102+
inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) {
103+
new csr_adjoint<std::decay_t<Result_>, std::decay_t<WMat_>, std::decay_t<B_>>(std::forward<Result_>(res),
53104
std::forward<WMat_>(w_mat),
54105
std::forward<B_>(b));
106+
return;
55107
}
56108
} // namespace internal
57109

0 commit comments

Comments
 (0)