Skip to content

Commit 9f759e1

Browse files
authored
Merge pull request #3053 from stan-dev/fix/csr-matrix-seperate-vari
use a seperate class for csr_matrix adjoint
2 parents 08d8a22 + 04124da commit 9f759e1

File tree

1 file changed

+116
-10
lines changed

1 file changed

+116
-10
lines changed

stan/math/rev/fun/csr_matrix_times_vector.hpp

Lines changed: 116 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,119 @@
1111
namespace stan {
1212
namespace math {
1313

14+
namespace internal {
15+
/**
16+
* `vari` for csr_matrix_times_vector
17+
* @note `csr_matrix_times_vector` uses the old inheritance
18+
* style to set up the reverse pass because of a linking
19+
* issue on windows when using flto.
20+
*
21+
* @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar
22+
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
23+
* @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar
24+
* type `var` or `double`. Or a `var<T>` where `T` inherits from
25+
* `Eigen::SparseBase`
26+
* @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type
27+
* `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
28+
*
29+
*/
30+
template <typename Result_, typename WMat_, typename B_>
31+
struct csr_adjoint : public vari {
32+
std::decay_t<Result_> res_;
33+
std::decay_t<WMat_> w_mat_;
34+
std::decay_t<B_> b_;
35+
36+
template <typename T1, typename T2, typename T3>
37+
csr_adjoint(T1&& res, T2&& w_mat, T3&& b)
38+
: vari(0.0),
39+
res_(std::forward<T1>(res)),
40+
w_mat_(std::forward<T2>(w_mat)),
41+
b_(std::forward<T3>(b)) {}
42+
43+
void chain() { chain_internal(res_, w_mat_, b_); }
44+
45+
/**
46+
* Overload for calculating adjoints of `w_mat` and `b`
47+
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar
48+
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
49+
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar
50+
* type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
51+
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type
52+
* `var`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
53+
* @param res The vector result of the forward pass calculation
54+
* @param w_mat A sparse matrix
55+
* @param b A vector
56+
*/
57+
template <typename Result, typename WMat, typename B,
58+
require_rev_matrix_t<WMat>* = nullptr,
59+
require_rev_matrix_t<B>* = nullptr>
60+
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
61+
w_mat.adj() += res.adj() * b.val().transpose();
62+
b.adj() += w_mat.val().transpose() * res.adj();
63+
}
64+
65+
/**
66+
* Overload for calculating adjoints of `w_mat`
67+
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar
68+
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
69+
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar
70+
* type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase`
71+
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type
72+
* `double`
73+
* @param res The vector result of the forward pass calculation
74+
* @param w_mat A sparse matrix
75+
* @param b A vector
76+
*/
77+
template <typename Result, typename WMat, typename B,
78+
require_rev_matrix_t<WMat>* = nullptr,
79+
require_not_rev_matrix_t<B>* = nullptr>
80+
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
81+
w_mat.adj() += res.adj() * b.transpose();
82+
}
83+
84+
/**
85+
* Overload for calculating adjoints of `b`
86+
* @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar
87+
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
88+
* @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar
89+
* type `double`
90+
* @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type
91+
* `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
92+
* @param res The vector result of the forward pass calculation
93+
* @param w_mat A sparse matrix
94+
* @param b A vector
95+
*/
96+
template <typename Result, typename WMat, typename B,
97+
require_not_rev_matrix_t<WMat>* = nullptr,
98+
require_rev_matrix_t<B>* = nullptr>
99+
inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
100+
b.adj() += w_mat.transpose() * res.adj();
101+
}
102+
};
103+
104+
/**
105+
* Helper function to construct the csr_adjoint struct.
106+
* @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar
107+
* type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase`
108+
* @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar
109+
* type `var` or `double`. Or a `var<T>` where `T` inherits from
110+
* `Eigen::SparseBase`
111+
* @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type
112+
* `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase`
113+
*
114+
* @param res The vector result of the forward pass calculation
115+
* @param w_mat A sparse matrix
116+
* @param b A vector
117+
*/
118+
template <typename Result_, typename WMat_, typename B_>
119+
inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) {
120+
new csr_adjoint<std::decay_t<Result_>, std::decay_t<WMat_>, std::decay_t<B_>>(
121+
std::forward<Result_>(res), std::forward<WMat_>(w_mat),
122+
std::forward<B_>(b));
123+
return;
124+
}
125+
} // namespace internal
126+
14127
/**
15128
* \addtogroup csr_format
16129
* Return the multiplication of the sparse matrix (specified by
@@ -74,29 +187,22 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
74187
sparse_var_value_t w_mat_arena
75188
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
76189
arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
77-
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
78-
w_mat_arena.adj() += res.adj() * b_arena.val().transpose();
79-
b_arena.adj() += w_mat_arena.val().transpose() * res.adj();
80-
});
190+
stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
81191
return return_t(res);
82192
} else if (!is_constant<T2>::value) {
83193
arena_t<promote_scalar_t<var, T2>> b_arena = b;
84194
auto w_val_arena = to_arena(value_of(w));
85195
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
86196
v_arena.data(), w_val_arena.data());
87197
arena_t<return_t> res = w_val_mat * value_of(b_arena);
88-
reverse_pass_callback([w_val_mat, res, b_arena]() mutable {
89-
b_arena.adj() += w_val_mat.transpose() * res.adj();
90-
});
198+
stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena);
91199
return return_t(res);
92200
} else {
93201
sparse_var_value_t w_mat_arena
94202
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
95203
auto b_arena = to_arena(value_of(b));
96204
arena_t<return_t> res = w_mat_arena.val() * b_arena;
97-
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
98-
w_mat_arena.adj() += res.adj() * b_arena.transpose();
99-
});
205+
stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
100206
return return_t(res);
101207
}
102208
}

0 commit comments

Comments
 (0)