Skip to content

Commit 5c5dfc6

Browse files
committed
update csr_matrix_times_vector w adjoint update. Uncomment tests for to_soa_sparse_matrix
1 parent 33f0825 commit 5c5dfc6

File tree

2 files changed

+4
-17
lines changed

2 files changed

+4
-17
lines changed

stan/math/rev/fun/csr_matrix_times_vector.hpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,7 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
7575
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
7676
arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
7777
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
78-
for (int k = 0; k < w_mat_arena.adj().outerSize(); ++k) {
79-
for (typename sparse_var_value_t::vari_type::InnerIterator it(
80-
w_mat_arena.adj(), k);
81-
it; ++it) {
82-
it.valueRef()
83-
+= res.adj().coeff(it.row()) * b_arena.val().coeff(it.col());
84-
}
85-
}
78+
w_mat_arena.adj() += res.adj() * b_arena.val().transpose();
8679
b_arena.adj() += w_mat_arena.val().transpose() * res.adj();
8780
});
8881
return return_t(res);
@@ -102,13 +95,7 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
10295
auto b_arena = to_arena(value_of(b));
10396
arena_t<return_t> res = w_mat_arena.val() * b_arena;
10497
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
105-
for (int k = 0; k < w_mat_arena.adj().outerSize(); ++k) {
106-
for (typename sparse_var_value_t::vari_type::InnerIterator it(
107-
w_mat_arena.adj(), k);
108-
it; ++it) {
109-
it.valueRef() += res.adj().coeff(it.row()) * b_arena.coeff(it.col());
110-
}
111-
}
98+
w_mat_arena.adj() += res.adj() * b_arena.transpose();
11299
});
113100
return return_t(res);
114101
}

test/unit/math/rev/fun/to_soa_sparse_matrix_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <stan/math/rev/core.hpp>
55
#include <test/unit/math/rev/util.hpp>
66
#include <vector>
7-
/*
7+
88
TEST_F(AgradRev, to_soa_sparse_matrix_matrix_double) {
99
using stan::math::var;
1010
using stan::math::var_value;
@@ -40,7 +40,7 @@ TEST_F(AgradRev, to_soa_sparse_matrix_matrix_var) {
4040
EXPECT_EQ(w_mat_arena.val().valuePtr()[i], w.val()(i));
4141
}
4242
}
43-
*/
43+
4444

4545
TEST_F(AgradRev, to_soa_sparse_matrix_var_matrix) {
4646
using stan::math::to_soa_sparse_matrix;

0 commit comments

Comments
 (0)