|
11 | 11 | namespace stan {
|
12 | 12 | namespace math {
|
13 | 13 |
|
| 14 | +namespace internal { |
| 15 | + template <typename Result_, typename WMat_, typename B_> |
| 16 | + struct csr_adjoint : public vari { |
| 17 | + std::decay_t<Result_> res_; |
| 18 | + std::decay_t<WMat_> w_mat_; |
| 19 | + std::decay_t<B_> b_; |
| 20 | + template <typename T1, typename T2, typename T3> |
| 21 | + csr_adjoint(T1&& res, T2&& w_mat, T3&& b) |
| 22 | + : vari(0.0), res_(std::forward<T1>(res)), |
| 23 | + w_mat_(std::forward<T2>(w_mat)), b_(std::forward<T3>(b)) {} |
| 24 | + |
| 25 | + void chain() { |
| 26 | + chain_internal(res_, w_mat_, b_); |
| 27 | + } |
| 28 | + template <typename Result, typename WMat, typename B, |
| 29 | + require_rev_matrix_t<WMat>* = nullptr, |
| 30 | + require_rev_matrix_t<B>* = nullptr> |
| 31 | + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { |
| 32 | + w_mat.adj() += res.adj() * b.val().transpose(); |
| 33 | + b.adj() += w_mat.val().transpose() * res.adj(); |
| 34 | + } |
| 35 | + |
| 36 | + template <typename Result, typename WMat, typename B, |
| 37 | + require_rev_matrix_t<WMat>* = nullptr, |
| 38 | + require_not_rev_matrix_t<B>* = nullptr> |
| 39 | + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { |
| 40 | + w_mat.adj() += res.adj() * b.transpose(); |
| 41 | + } |
| 42 | + |
| 43 | + template <typename Result, typename WMat, typename B, |
| 44 | + require_not_rev_matrix_t<WMat>* = nullptr, |
| 45 | + require_rev_matrix_t<B>* = nullptr> |
| 46 | + void chain_internal(Result&& res, WMat&& w_mat, B&& b) { |
| 47 | + b.adj() += w_mat.transpose() * res.adj(); |
| 48 | + } |
| 49 | + }; |
| 50 | + template <typename Result_, typename WMat_, typename B_> |
| 51 | + inline vari* make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { |
| 52 | + return new csr_adjoint<Result_, WMat_, B_>( |
| 53 | + std::forward<Result_>(res), std::forward<WMat_>(w_mat), |
| 54 | + std::forward<B_>(b)); |
| 55 | + } |
| 56 | +} |
| 57 | + |
14 | 58 | /**
|
15 | 59 | * \addtogroup csr_format
|
16 | 60 | * Return the multiplication of the sparse matrix (specified by
|
@@ -74,29 +118,22 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
|
74 | 118 | sparse_var_value_t w_mat_arena
|
75 | 119 | = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
|
76 | 120 | arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
|
77 |
| - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { |
78 |
| - w_mat_arena.adj() += res.adj() * b_arena.val().transpose(); |
79 |
| - b_arena.adj() += w_mat_arena.val().transpose() * res.adj(); |
80 |
| - }); |
| 121 | + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); |
81 | 122 | return return_t(res);
|
82 | 123 | } else if (!is_constant<T2>::value) {
|
83 | 124 | arena_t<promote_scalar_t<var, T2>> b_arena = b;
|
84 | 125 | auto w_val_arena = to_arena(value_of(w));
|
85 | 126 | sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
|
86 | 127 | v_arena.data(), w_val_arena.data());
|
87 | 128 | arena_t<return_t> res = w_val_mat * value_of(b_arena);
|
88 |
| - reverse_pass_callback([w_val_mat, res, b_arena]() mutable { |
89 |
| - b_arena.adj() += w_val_mat.transpose() * res.adj(); |
90 |
| - }); |
| 129 | + stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena); |
91 | 130 | return return_t(res);
|
92 | 131 | } else {
|
93 | 132 | sparse_var_value_t w_mat_arena
|
94 | 133 | = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
|
95 | 134 | auto b_arena = to_arena(value_of(b));
|
96 | 135 | arena_t<return_t> res = w_mat_arena.val() * b_arena;
|
97 |
| - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { |
98 |
| - w_mat_arena.adj() += res.adj() * b_arena.transpose(); |
99 |
| - }); |
| 136 | + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); |
100 | 137 | return return_t(res);
|
101 | 138 | }
|
102 | 139 | }
|
|
0 commit comments