@@ -12,48 +12,48 @@ namespace stan {
12
12
namespace math {
13
13
14
14
namespace internal {
15
- template <typename Result_, typename WMat_, typename B_>
16
- struct csr_adjoint : public vari {
17
- std::decay_t <Result_> res_;
18
- std::decay_t <WMat_> w_mat_;
19
- std::decay_t <B_> b_;
20
- template <typename T1, typename T2, typename T3>
21
- csr_adjoint (T1&& res, T2&& w_mat, T3&& b)
22
- : vari(0.0 ), res_(std::forward<T1>(res)),
23
- w_mat_ (std::forward<T2>(w_mat)), b_(std::forward<T3>(b)) {}
15
+ template <typename Result_, typename WMat_, typename B_>
16
+ struct csr_adjoint : public vari {
17
+ std::decay_t <Result_> res_;
18
+ std::decay_t <WMat_> w_mat_;
19
+ std::decay_t <B_> b_;
20
+ template <typename T1, typename T2, typename T3>
21
+ csr_adjoint (T1&& res, T2&& w_mat, T3&& b)
22
+ : vari(0.0 ),
23
+ res_ (std::forward<T1>(res)),
24
+ w_mat_(std::forward<T2>(w_mat)),
25
+ b_(std::forward<T3>(b)) {}
24
26
25
- void chain () {
26
- chain_internal (res_, w_mat_, b_);
27
- }
28
- template <typename Result, typename WMat, typename B,
29
- require_rev_matrix_t <WMat>* = nullptr ,
30
- require_rev_matrix_t <B>* = nullptr >
31
- void chain_internal (Result&& res, WMat&& w_mat, B&& b) {
32
- w_mat.adj () += res.adj () * b.val ().transpose ();
33
- b.adj () += w_mat.val ().transpose () * res.adj ();
34
- }
27
+ void chain () { chain_internal (res_, w_mat_, b_); }
28
+ template <typename Result, typename WMat, typename B,
29
+ require_rev_matrix_t <WMat>* = nullptr ,
30
+ require_rev_matrix_t <B>* = nullptr >
31
+ void chain_internal (Result&& res, WMat&& w_mat, B&& b) {
32
+ w_mat.adj () += res.adj () * b.val ().transpose ();
33
+ b.adj () += w_mat.val ().transpose () * res.adj ();
34
+ }
35
35
36
- template <typename Result, typename WMat, typename B,
37
- require_rev_matrix_t <WMat>* = nullptr ,
38
- require_not_rev_matrix_t <B>* = nullptr >
39
- void chain_internal (Result&& res, WMat&& w_mat, B&& b) {
40
- w_mat.adj () += res.adj () * b.transpose ();
41
- }
36
+ template <typename Result, typename WMat, typename B,
37
+ require_rev_matrix_t <WMat>* = nullptr ,
38
+ require_not_rev_matrix_t <B>* = nullptr >
39
+ void chain_internal (Result&& res, WMat&& w_mat, B&& b) {
40
+ w_mat.adj () += res.adj () * b.transpose ();
41
+ }
42
42
43
- template <typename Result, typename WMat, typename B,
44
- require_not_rev_matrix_t <WMat>* = nullptr ,
45
- require_rev_matrix_t <B>* = nullptr >
46
- void chain_internal (Result&& res, WMat&& w_mat, B&& b) {
47
- b.adj () += w_mat.transpose () * res.adj ();
48
- }
49
- };
50
- template <typename Result_, typename WMat_, typename B_>
51
- inline vari* make_csr_adjoint (Result_&& res, WMat_&& w_mat, B_&& b) {
52
- return new csr_adjoint<Result_, WMat_, B_>(
53
- std::forward<Result_>(res), std::forward<WMat_>(w_mat),
54
- std::forward<B_>(b));
43
+ template <typename Result, typename WMat, typename B,
44
+ require_not_rev_matrix_t <WMat>* = nullptr ,
45
+ require_rev_matrix_t <B>* = nullptr >
46
+ void chain_internal (Result&& res, WMat&& w_mat, B&& b) {
47
+ b.adj () += w_mat.transpose () * res.adj ();
55
48
}
49
+ };
50
+ template <typename Result_, typename WMat_, typename B_>
51
+ inline vari* make_csr_adjoint (Result_&& res, WMat_&& w_mat, B_&& b) {
52
+ return new csr_adjoint<Result_, WMat_, B_>(std::forward<Result_>(res),
53
+ std::forward<WMat_>(w_mat),
54
+ std::forward<B_>(b));
56
55
}
56
+ } // namespace internal
57
57
58
58
/* *
59
59
* \addtogroup csr_format
0 commit comments