|
11 | 11 | namespace stan {
|
12 | 12 | namespace math {
|
13 | 13 |
|
| 14 | +namespace internal { |
| 15 | +/** |
| 16 | + * `vari` for csr_matrix_times_vector |
| 17 | + * @note `csr_matrix_times_vector` uses the old inheritance |
| 18 | + * style to set up the reverse pass because of a linking |
| 19 | + * issue on windows when using flto. |
| 20 | + * |
| 21 | + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar |
| 22 | + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 23 | + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar |
| 24 | + * type `var` or `double`. Or a `var<T>` where `T` inherits from |
| 25 | + * `Eigen::SparseBase` |
| 26 | + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type |
| 27 | + * `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 28 | + * |
| 29 | + */ |
| 30 | +template <typename Result_, typename WMat_, typename B_> |
| 31 | +struct csr_adjoint : public vari { |
| 32 | + std::decay_t<Result_> res_; |
| 33 | + std::decay_t<WMat_> w_mat_; |
| 34 | + std::decay_t<B_> b_; |
| 35 | + |
| 36 | + template <typename T1, typename T2, typename T3> |
| 37 | + csr_adjoint(T1&& res, T2&& w_mat, T3&& b) |
| 38 | + : vari(0.0), |
| 39 | + res_(std::forward<T1>(res)), |
| 40 | + w_mat_(std::forward<T2>(w_mat)), |
| 41 | + b_(std::forward<T3>(b)) {} |
| 42 | + |
| 43 | + void chain() { chain_internal(res_, w_mat_, b_); } |
| 44 | + |
| 45 | + /** |
| 46 | + * Overload for calculating adjoints of `w_mat` and `b` |
| 47 | + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar |
| 48 | + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 49 | + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar |
| 50 | + * type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase` |
| 51 | + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type |
| 52 | + * `var`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 53 | + * @param res The vector result of the forward pass calculation |
| 54 | + * @param w_mat A sparse matrix |
| 55 | + * @param b A vector |
| 56 | + */ |
| 57 | + template <typename Result, typename WMat, typename B, |
| 58 | + require_rev_matrix_t<WMat>* = nullptr, |
| 59 | + require_rev_matrix_t<B>* = nullptr> |
| 60 | + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { |
| 61 | + w_mat.adj() += res.adj() * b.val().transpose(); |
| 62 | + b.adj() += w_mat.val().transpose() * res.adj(); |
| 63 | + } |
| 64 | + |
| 65 | + /** |
| 66 | + * Overload for calculating adjoints of `w_mat` |
| 67 | + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar |
| 68 | + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 69 | + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar |
| 70 | + * type `var`. Or a `var<T>` where `T` inherits from `Eigen::SparseBase` |
| 71 | + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type |
| 72 | + * `double` |
| 73 | + * @param res The vector result of the forward pass calculation |
| 74 | + * @param w_mat A sparse matrix |
| 75 | + * @param b A vector |
| 76 | + */ |
| 77 | + template <typename Result, typename WMat, typename B, |
| 78 | + require_rev_matrix_t<WMat>* = nullptr, |
| 79 | + require_not_rev_matrix_t<B>* = nullptr> |
| 80 | + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { |
| 81 | + w_mat.adj() += res.adj() * b.transpose(); |
| 82 | + } |
| 83 | + |
| 84 | + /** |
| 85 | + * Overload for calculating adjoints of `b` |
| 86 | + * @tparam Result Either a type inheriting from `Eigen::DenseBase` with scalar |
| 87 | + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 88 | + * @tparam WMat Either a type inheriting from `Eigen::DenseBase` with scalar |
| 89 | + * type `double` |
| 90 | + * @tparam B Either a type inheriting from `Eigen::DenseBase` with scalar type |
| 91 | + * `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 92 | + * @param res The vector result of the forward pass calculation |
| 93 | + * @param w_mat A sparse matrix |
| 94 | + * @param b A vector |
| 95 | + */ |
| 96 | + template <typename Result, typename WMat, typename B, |
| 97 | + require_not_rev_matrix_t<WMat>* = nullptr, |
| 98 | + require_rev_matrix_t<B>* = nullptr> |
| 99 | + inline void chain_internal(Result&& res, WMat&& w_mat, B&& b) { |
| 100 | + b.adj() += w_mat.transpose() * res.adj(); |
| 101 | + } |
| 102 | +}; |
| 103 | + |
| 104 | +/** |
| 105 | + * Helper function to construct the csr_adjoint struct. |
| 106 | + * @tparam Result_ Either a type inheriting from `Eigen::DenseBase` with scalar |
| 107 | + * type `var` or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 108 | + * @tparam WMat_ Either a type inheriting from `Eigen::DenseBase` with scalar |
| 109 | + * type `var` or `double`. Or a `var<T>` where `T` inherits from |
| 110 | + * `Eigen::SparseBase` |
| 111 | + * @tparam B_ Either a type inheriting from `Eigen::DenseBase` with scalar type |
| 112 | + * `var` or `double`. Or a `var<T>` where `T` inherits from `Eigen::DenseBase` |
| 113 | + * |
| 114 | + * @param res The vector result of the forward pass calculation |
| 115 | + * @param w_mat A sparse matrix |
| 116 | + * @param b A vector |
| 117 | + */ |
| 118 | +template <typename Result_, typename WMat_, typename B_> |
| 119 | +inline void make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) { |
| 120 | + new csr_adjoint<std::decay_t<Result_>, std::decay_t<WMat_>, std::decay_t<B_>>( |
| 121 | + std::forward<Result_>(res), std::forward<WMat_>(w_mat), |
| 122 | + std::forward<B_>(b)); |
| 123 | + return; |
| 124 | +} |
| 125 | +} // namespace internal |
| 126 | + |
14 | 127 | /**
|
15 | 128 | * \addtogroup csr_format
|
16 | 129 | * Return the multiplication of the sparse matrix (specified by
|
@@ -74,29 +187,22 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
|
74 | 187 | sparse_var_value_t w_mat_arena
|
75 | 188 | = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
|
76 | 189 | arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
|
77 |
| - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { |
78 |
| - w_mat_arena.adj() += res.adj() * b_arena.val().transpose(); |
79 |
| - b_arena.adj() += w_mat_arena.val().transpose() * res.adj(); |
80 |
| - }); |
| 190 | + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); |
81 | 191 | return return_t(res);
|
82 | 192 | } else if (!is_constant<T2>::value) {
|
83 | 193 | arena_t<promote_scalar_t<var, T2>> b_arena = b;
|
84 | 194 | auto w_val_arena = to_arena(value_of(w));
|
85 | 195 | sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
|
86 | 196 | v_arena.data(), w_val_arena.data());
|
87 | 197 | arena_t<return_t> res = w_val_mat * value_of(b_arena);
|
88 |
| - reverse_pass_callback([w_val_mat, res, b_arena]() mutable { |
89 |
| - b_arena.adj() += w_val_mat.transpose() * res.adj(); |
90 |
| - }); |
| 198 | + stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena); |
91 | 199 | return return_t(res);
|
92 | 200 | } else {
|
93 | 201 | sparse_var_value_t w_mat_arena
|
94 | 202 | = to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
|
95 | 203 | auto b_arena = to_arena(value_of(b));
|
96 | 204 | arena_t<return_t> res = w_mat_arena.val() * b_arena;
|
97 |
| - reverse_pass_callback([res, w_mat_arena, b_arena]() mutable { |
98 |
| - w_mat_arena.adj() += res.adj() * b_arena.transpose(); |
99 |
| - }); |
| 205 | + stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena); |
100 | 206 | return return_t(res);
|
101 | 207 | }
|
102 | 208 | }
|
|
0 commit comments