Skip to content

Commit 34cf554

Browse files
committed
use a seperate class for csr_matrix adjoint
1 parent 08d8a22 commit 34cf554

File tree

1 file changed

+47
-10
lines changed

1 file changed

+47
-10
lines changed

stan/math/rev/fun/csr_matrix_times_vector.hpp

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,50 @@
1111
namespace stan {
1212
namespace math {
1313

14+
namespace internal {
15+
template <typename Result_, typename WMat_, typename B_>
16+
struct csr_adjoint : public vari {
17+
std::decay_t<Result_> res_;
18+
std::decay_t<WMat_> w_mat_;
19+
std::decay_t<B_> b_;
20+
template <typename T1, typename T2, typename T3>
21+
csr_adjoint(T1&& res, T2&& w_mat, T3&& b)
22+
: vari(0.0), res_(std::forward<T1>(res)),
23+
w_mat_(std::forward<T2>(w_mat)), b_(std::forward<T3>(b)) {}
24+
25+
void chain() {
26+
chain_internal(res_, w_mat_, b_);
27+
}
28+
template <typename Result, typename WMat, typename B,
29+
require_rev_matrix_t<WMat>* = nullptr,
30+
require_rev_matrix_t<B>* = nullptr>
31+
void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
32+
w_mat.adj() += res.adj() * b.val().transpose();
33+
b.adj() += w_mat.val().transpose() * res.adj();
34+
}
35+
36+
template <typename Result, typename WMat, typename B,
37+
require_rev_matrix_t<WMat>* = nullptr,
38+
require_not_rev_matrix_t<B>* = nullptr>
39+
void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
40+
w_mat.adj() += res.adj() * b.transpose();
41+
}
42+
43+
template <typename Result, typename WMat, typename B,
44+
require_not_rev_matrix_t<WMat>* = nullptr,
45+
require_rev_matrix_t<B>* = nullptr>
46+
void chain_internal(Result&& res, WMat&& w_mat, B&& b) {
47+
b.adj() += w_mat.transpose() * res.adj();
48+
}
49+
};
50+
template <typename Result_, typename WMat_, typename B_>
51+
inline vari* make_csr_adjoint(Result_&& res, WMat_&& w_mat, B_&& b) {
52+
return new csr_adjoint<Result_, WMat_, B_>(
53+
std::forward<Result_>(res), std::forward<WMat_>(w_mat),
54+
std::forward<B_>(b));
55+
}
56+
}
57+
1458
/**
1559
* \addtogroup csr_format
1660
* Return the multiplication of the sparse matrix (specified by
@@ -74,29 +118,22 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
74118
sparse_var_value_t w_mat_arena
75119
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
76120
arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
77-
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
78-
w_mat_arena.adj() += res.adj() * b_arena.val().transpose();
79-
b_arena.adj() += w_mat_arena.val().transpose() * res.adj();
80-
});
121+
stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
81122
return return_t(res);
82123
} else if (!is_constant<T2>::value) {
83124
arena_t<promote_scalar_t<var, T2>> b_arena = b;
84125
auto w_val_arena = to_arena(value_of(w));
85126
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
86127
v_arena.data(), w_val_arena.data());
87128
arena_t<return_t> res = w_val_mat * value_of(b_arena);
88-
reverse_pass_callback([w_val_mat, res, b_arena]() mutable {
89-
b_arena.adj() += w_val_mat.transpose() * res.adj();
90-
});
129+
stan::math::internal::make_csr_adjoint(res, w_val_mat, b_arena);
91130
return return_t(res);
92131
} else {
93132
sparse_var_value_t w_mat_arena
94133
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
95134
auto b_arena = to_arena(value_of(b));
96135
arena_t<return_t> res = w_mat_arena.val() * b_arena;
97-
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
98-
w_mat_arena.adj() += res.adj() * b_arena.transpose();
99-
});
136+
stan::math::internal::make_csr_adjoint(res, w_mat_arena, b_arena);
100137
return return_t(res);
101138
}
102139
}

0 commit comments

Comments
 (0)