Skip to content

Commit 86a3e83

Browse files
authored
Merge pull request #3048 from stan-dev/fix/csr-matrix-times-vector
Fix csr_matrix_times_vector linker error
2 parents 11663a2 + b0815c4 commit 86a3e83

File tree

12 files changed

+362
-86
lines changed

12 files changed

+362
-86
lines changed

stan/math/prim/fun/value_of.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ inline auto value_of(const T& x) {
6767
* @param[in] M Matrix to be converted
6868
* @return Matrix of values
6969
**/
70-
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
70+
template <typename EigMat, require_eigen_dense_base_t<EigMat>* = nullptr,
7171
require_not_st_arithmetic<EigMat>* = nullptr>
7272
inline auto value_of(EigMat&& M) {
7373
return make_holder(
@@ -77,6 +77,28 @@ inline auto value_of(EigMat&& M) {
7777
std::forward<EigMat>(M));
7878
}
7979

80+
/**
 * Return the values of a sparse Eigen matrix whose scalar type is not
 * arithmetic.
 *
 * A new matrix of type `promote_scalar_t<scalar_t, plain_type_t<EigMat>>` is
 * built with the same number of rows and columns as `M`, where `scalar_t` is
 * the type returned by `value_of` applied to an element of `M`. Each stored
 * entry of `M` is passed through `value_of` and inserted at the same
 * (row, col) position; the result is compressed before it is returned.
 *
 * @tparam EigMat type of the sparse matrix
 * @param[in] M sparse matrix to be converted
 * @return sparse matrix of values
 */
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
81+
require_not_st_arithmetic<EigMat>* = nullptr>
82+
inline auto value_of(EigMat&& M) {
83+
auto&& M_ref = to_ref(M);
84+
using scalar_t = decltype(value_of(std::declval<value_type_t<EigMat>>()));
85+
promote_scalar_t<scalar_t, plain_type_t<EigMat>> ret(M_ref.rows(),
86+
M_ref.cols());
87+
// Reserve room for as many entries as `M` currently stores.
ret.reserve(M_ref.nonZeros());
88+
// Visit every stored entry, one outer index at a time, and copy its value
// into the same (row, col) position of the result.
for (int k = 0; k < M_ref.outerSize(); ++k) {
89+
for (typename std::decay_t<EigMat>::InnerIterator it(M_ref, k); it; ++it) {
90+
ret.insert(it.row(), it.col()) = value_of(it.valueRef());
91+
}
92+
}
93+
ret.makeCompressed();
94+
return ret;
95+
}
96+
/**
 * Overload of `value_of` for sparse Eigen matrices whose scalar type is
 * arithmetic. No conversion is performed: the argument is forwarded straight
 * to the return statement.
 *
 * Note that the return type is declared as plain `auto`, so the result is
 * returned by value rather than as a reference to `M`.
 *
 * @tparam EigMat type of the sparse matrix
 * @param[in] M sparse matrix to be returned
 * @return the forwarded matrix
 */
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
97+
require_st_arithmetic<EigMat>* = nullptr>
98+
inline auto value_of(EigMat&& M) {
99+
return std::forward<EigMat>(M);
100+
}
101+
80102
} // namespace math
81103
} // namespace stan
82104

stan/math/prim/meta/is_eigen_dense_base.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ using require_eigen_dense_base_t
3333
= require_t<is_eigen_dense_base<std::decay_t<T>>>;
3434
/*! @} */
3535

36+
/*! \ingroup require_eigens_types */
37+
/*! \defgroup eigen_dense_base_types eigen_dense_base_types */
38+
/*! \addtogroup eigen_dense_base_types */
39+
/*! @{ */
40+
41+
/*! \brief Require type satisfies @ref is_eigen_dense_base */
42+
/*! and value type satisfies `TypeCheck` */
43+
/*! @tparam TypeCheck The type trait to check the value type against */
44+
/*! @tparam Check The type to test @ref is_eigen_dense_base for and whose
45+
* @ref value_type is checked with `TypeCheck` */
46+
template <template <class...> class TypeCheck, class... Check>
47+
using require_eigen_dense_base_vt
48+
= require_t<container_type_check_base<is_eigen_dense_base, value_type_t,
49+
TypeCheck, Check...>>;
50+
/*! @} */
51+
3652
} // namespace stan
3753

3854
#endif

stan/math/prim/meta/promote_scalar_type.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta/is_eigen.hpp>
66
#include <stan/math/prim/meta/is_var.hpp>
7+
#include <stan/math/prim/meta/is_eigen_dense_base.hpp>
8+
#include <stan/math/prim/meta/is_eigen_sparse_base.hpp>
79
#include <vector>
810

911
namespace stan {
@@ -80,7 +82,7 @@ struct promote_scalar_type<T, S,
8082
* @tparam S input matrix type
8183
*/
8284
template <typename T, typename S>
83-
struct promote_scalar_type<T, S, require_eigen_t<S>> {
85+
struct promote_scalar_type<T, S, require_eigen_dense_base_t<S>> {
8486
/**
8587
* The promoted type.
8688
*/
@@ -93,6 +95,16 @@ struct promote_scalar_type<T, S, require_eigen_t<S>> {
9395
S::RowsAtCompileTime, S::ColsAtCompileTime>>::type;
9496
};
9597

98+
/**
 * Specialization of `promote_scalar_type` selected when
 * `require_eigen_sparse_base_t<S>` is satisfied.
 *
 * The promoted type is an `Eigen::SparseMatrix` whose scalar is
 * `promote_scalar_type<T, typename S::Scalar>::type` and whose `Options` and
 * `StorageIndex` are taken from `S`.
 *
 * @tparam T scalar type to promote to
 * @tparam S input sparse matrix type
 */
template <typename T, typename S>
99+
struct promote_scalar_type<T, S, require_eigen_sparse_base_t<S>> {
100+
/**
101+
* The promoted type.
102+
*/
103+
using type = Eigen::SparseMatrix<
104+
typename promote_scalar_type<T, typename S::Scalar>::type, S::Options,
105+
typename S::StorageIndex>;
106+
};
107+
96108
template <typename... PromotionScalars, typename... UnPromotedTypes>
97109
struct promote_scalar_type<std::tuple<PromotionScalars...>,
98110
std::tuple<UnPromotedTypes...>> {

stan/math/rev/core/arena_matrix.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <stan/math/rev/core/chainablestack.hpp>
77
#include <stan/math/rev/core/var_value_fwd_declare.hpp>
88
#include <stan/math/prim/fun/to_ref.hpp>
9-
109
namespace stan {
1110
namespace math {
1211

@@ -225,8 +224,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
225224
*/
226225
arena_matrix(const arena_matrix<MatrixType>& other)
227226
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
228-
other.outerIndexPtr(), other.innerIndexPtr(),
229-
other.valuePtr(), other.innernonZeroPtr()) {}
227+
const_cast<StorageIndex*>(other.outerIndexPtr()),
228+
const_cast<StorageIndex*>(other.innerIndexPtr()),
229+
const_cast<Scalar*>(other.valuePtr()),
230+
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
230231
/**
231232
* Move constructor.
232233
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -235,8 +236,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
235236
*/
236237
arena_matrix(arena_matrix<MatrixType>&& other)
237238
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
238-
other.outerIndexPtr(), other.innerIndexPtr(),
239-
other.valuePtr(), other.innerNonZeroPtr()) {}
239+
const_cast<StorageIndex*>(other.outerIndexPtr()),
240+
const_cast<StorageIndex*>(other.innerIndexPtr()),
241+
const_cast<Scalar*>(other.valuePtr()),
242+
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
240243
/**
241244
* Copy constructor. No actual copy is performed
242245
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -245,8 +248,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
245248
*/
246249
arena_matrix(arena_matrix<MatrixType>& other)
247250
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
248-
other.outerIndexPtr(), other.innerIndexPtr(),
249-
other.valuePtr(), other.innerNonZeroPtr()) {}
251+
const_cast<StorageIndex*>(other.outerIndexPtr()),
252+
const_cast<StorageIndex*>(other.innerIndexPtr()),
253+
const_cast<Scalar*>(other.valuePtr()),
254+
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
250255

251256
// without this using, compiler prefers combination of implicit construction
252257
// and copy assignment to the inherited operator when assigned an expression
@@ -259,7 +264,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
259264
* @return `*this`
260265
*/
261266
template <typename ArenaMatrix,
262-
require_same_t<ArenaMatrix, arena_matrix<MatrixType>>* = nullptr>
267+
require_same_t<std::decay_t<ArenaMatrix>,
268+
arena_matrix<MatrixType>>* = nullptr>
263269
arena_matrix& operator=(ArenaMatrix&& other) {
264270
// placement new changes what data map points to - there is no allocation
265271
new (this) Base(other.rows(), other.cols(), other.nonZeros(),
@@ -280,7 +286,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
280286
template <typename Expr,
281287
require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
282288
arena_matrix& operator=(Expr&& expr) {
283-
*this = arena_matrix(std::forward<Expr>(expr));
289+
new (this) arena_matrix(std::forward<Expr>(expr));
284290
return *this;
285291
}
286292

stan/math/rev/core/var.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,18 @@ class var_value<T, internal::require_matrix_var_value<T>> {
418418
});
419419
}
420420

421+
/**
422+
* Construct a `var_value` with premade @ref arena_matrix types.
423+
* The values and adjoint matrices passed here will be shallow copied.
424+
* @tparam S type of the value in the `var_value` to assign
* @tparam T_ defaults to `T`; not otherwise used in this declaration
425+
* @param val The value matrix to go into the vari
426+
* @param adj the adjoint matrix to go into the vari
427+
*/
428+
template <typename S, typename T_ = T,
429+
require_assignable_t<value_type, S>* = nullptr,
430+
require_arena_matrix_t<S>* = nullptr>
431+
var_value(const S& val, const S& adj) : vi_(new vari_type(val, adj)) {}
432+
421433
/**
422434
* Construct a variable from a pointer to a variable implementation.
423435
* @param vi A vari_value pointer.

stan/math/rev/core/vari.hpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -821,17 +821,16 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
821821
*/
822822
static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;
823823

824+
/**
825+
* The value of this variable.
826+
*/
827+
arena_matrix<PlainObject> val_;
824828
/**
825829
* The adjoint of this variable, which is the partial derivative
826830
* of this variable with respect to the root variable.
827831
*/
828832
arena_matrix<PlainObject> adj_;
829833

830-
/**
831-
* The value of this variable.
832-
*/
833-
arena_matrix<PlainObject> val_;
834-
835834
/**
836835
* Construct a variable implementation from a value. The
837836
* adjoint is initialized to zero.
@@ -847,10 +846,21 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
847846
* @param x Value of the constructed variable.
848847
*/
849848
template <typename S, require_convertible_t<S&, T>* = nullptr>
850-
explicit vari_value(S&& x) : adj_(x), val_(std::forward<S>(x)) {
851-
this->set_zero_adjoint();
849+
explicit vari_value(S&& x)
850+
: val_(std::forward<S>(x)),
851+
adj_(val_.rows(), val_.cols(), val_.nonZeros(), val_.outerIndexPtr(),
852+
val_.innerIndexPtr(),
853+
arena_matrix<Eigen::VectorXd>(val_.nonZeros()).setZero().data(),
854+
val_.innerNonZeroPtr()) {
855+
ChainableStack::instance_->var_stack_.push_back(this);
856+
}
857+
858+
/**
 * Construct a sparse variable implementation from premade value and adjoint
 * arena matrices, then push this vari onto
 * `ChainableStack::instance_->var_stack_`.
 *
 * The adjoint is stored as given; this constructor does not set it to zero.
 *
 * @param val value matrix used to initialize `val_`
 * @param adj adjoint matrix used to initialize `adj_`
 */
vari_value(const arena_matrix<PlainObject>& val,
859+
const arena_matrix<PlainObject>& adj)
860+
: val_(val), adj_(adj) {
852861
ChainableStack::instance_->var_stack_.push_back(this);
853862
}
863+
854864
/**
855865
* Construct a sparse Eigen variable implementation from a value. The
856866
* adjoint is initialized to zero and if `stacked` is `false` this vari
@@ -869,8 +879,12 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
869879
* that its `chain()` method is not called.
870880
*/
871881
template <typename S, require_convertible_t<S&, T>* = nullptr>
872-
vari_value(S&& x, bool stacked) : adj_(x), val_(std::forward<S>(x)) {
873-
this->set_zero_adjoint();
882+
vari_value(S&& x, bool stacked)
883+
: val_(std::forward<S>(x)),
884+
adj_(val_.rows(), val_.cols(), val_.nonZeros(), val_.outerIndexPtr(),
885+
val_.innerIndexPtr(),
886+
arena_matrix<Eigen::VectorXd>(val_.nonZeros()).setZero().data(),
887+
val_.innerNonZeroPtr()) {
874888
if (stacked) {
875889
ChainableStack::instance_->var_stack_.push_back(this);
876890
} else {

stan/math/rev/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@
183183
#include <stan/math/rev/fun/tgamma.hpp>
184184
#include <stan/math/rev/fun/to_var.hpp>
185185
#include <stan/math/rev/fun/to_arena.hpp>
186+
#include <stan/math/rev/fun/to_soa_sparse_matrix.hpp>
186187
#include <stan/math/rev/fun/to_var_value.hpp>
187188
#include <stan/math/rev/fun/to_vector.hpp>
188189
#include <stan/math/rev/fun/trace.hpp>

stan/math/rev/fun/csr_matrix_times_vector.hpp

Lines changed: 19 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP
33

44
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/rev/fun/to_soa_sparse_matrix.hpp>
56
#include <stan/math/rev/core.hpp>
67
#include <stan/math/prim/err.hpp>
78
#include <stan/math/prim/fun/csr_u_to_z.hpp>
@@ -10,40 +11,6 @@
1011
namespace stan {
1112
namespace math {
1213

13-
namespace internal {
14-
template <typename T1, typename T2, typename Res,
15-
require_eigen_t<T1>* = nullptr>
16-
void update_w(T1& w, int m, int n, std::vector<int, arena_allocator<int>>& u,
17-
std::vector<int, arena_allocator<int>>& v, T2&& b, Res&& res) {
18-
Eigen::Map<Eigen::SparseMatrix<var, Eigen::RowMajor>> w_mat(
19-
m, n, w.size(), u.data(), v.data(), w.data());
20-
for (int k = 0; k < w_mat.outerSize(); ++k) {
21-
for (Eigen::Map<Eigen::SparseMatrix<var, Eigen::RowMajor>>::InnerIterator
22-
it(w_mat, k);
23-
it; ++it) {
24-
it.valueRef().adj()
25-
+= res.adj().coeff(it.row()) * value_of(b).coeff(it.col());
26-
}
27-
}
28-
}
29-
30-
template <typename T1, typename T2, typename Res,
31-
require_var_matrix_t<T1>* = nullptr>
32-
void update_w(T1& w, int m, int n, std::vector<int, arena_allocator<int>>& u,
33-
std::vector<int, arena_allocator<int>>& v, T2&& b, Res&& res) {
34-
Eigen::Map<Eigen::SparseMatrix<double, Eigen::RowMajor>> w_mat(
35-
m, n, w.size(), u.data(), v.data(), w.adj().data());
36-
for (int k = 0; k < w_mat.outerSize(); ++k) {
37-
for (Eigen::Map<Eigen::SparseMatrix<double, Eigen::RowMajor>>::InnerIterator
38-
it(w_mat, k);
39-
it; ++it) {
40-
it.valueRef() += res.adj().coeff(it.row()) * value_of(b).coeff(it.col());
41-
}
42-
}
43-
}
44-
45-
} // namespace internal
46-
4714
/**
4815
* \addtogroup csr_format
4916
* Return the multiplication of the sparse matrix (specified by
@@ -100,46 +67,36 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
10067
std::vector<int, arena_allocator<int>> u_arena(u.size());
10168
std::transform(u.begin(), u.end(), u_arena.begin(),
10269
[](auto&& x) { return x - 1; });
70+
using sparse_var_value_t
71+
= var_value<Eigen::SparseMatrix<double, Eigen::RowMajor>>;
10372
if (!is_constant<T2>::value && !is_constant<T1>::value) {
10473
arena_t<promote_scalar_t<var, T2>> b_arena = b;
105-
arena_t<promote_scalar_t<var, T1>> w_arena = to_arena(w);
106-
auto w_val_arena = to_arena(value_of(w_arena));
107-
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
108-
v_arena.data(), w_val_arena.data());
109-
arena_t<return_t> res = w_val_mat * value_of(b_arena);
110-
reverse_pass_callback(
111-
[m, n, w_arena, w_val_arena, v_arena, u_arena, res, b_arena]() mutable {
112-
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
113-
v_arena.data(), w_val_arena.data());
114-
internal::update_w(w_arena, m, n, u_arena, v_arena, b_arena, res);
115-
b_arena.adj() += w_val_mat.transpose() * res.adj();
116-
});
74+
sparse_var_value_t w_mat_arena
75+
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
76+
arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
77+
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
78+
w_mat_arena.adj() += res.adj() * b_arena.val().transpose();
79+
b_arena.adj() += w_mat_arena.val().transpose() * res.adj();
80+
});
11781
return return_t(res);
11882
} else if (!is_constant<T2>::value) {
11983
arena_t<promote_scalar_t<var, T2>> b_arena = b;
12084
auto w_val_arena = to_arena(value_of(w));
12185
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
12286
v_arena.data(), w_val_arena.data());
123-
12487
arena_t<return_t> res = w_val_mat * value_of(b_arena);
125-
reverse_pass_callback(
126-
[m, n, w_val_arena, v_arena, u_arena, res, b_arena]() mutable {
127-
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
128-
v_arena.data(), w_val_arena.data());
129-
b_arena.adj() += w_val_mat.transpose() * res.adj();
130-
});
88+
reverse_pass_callback([w_val_mat, res, b_arena]() mutable {
89+
b_arena.adj() += w_val_mat.transpose() * res.adj();
90+
});
13191
return return_t(res);
13292
} else {
133-
arena_t<promote_scalar_t<var, T1>> w_arena = to_arena(w);
134-
auto&& w_val = eval(value_of(w_arena));
135-
sparse_val_mat w_val_mat(m, n, w_val.size(), u_arena.data(), v_arena.data(),
136-
w_val.data());
93+
sparse_var_value_t w_mat_arena
94+
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
13795
auto b_arena = to_arena(value_of(b));
138-
arena_t<return_t> res = w_val_mat * b_arena;
139-
reverse_pass_callback(
140-
[m, n, w_arena, v_arena, u_arena, res, b_arena]() mutable {
141-
internal::update_w(w_arena, m, n, u_arena, v_arena, b_arena, res);
142-
});
96+
arena_t<return_t> res = w_mat_arena.val() * b_arena;
97+
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
98+
w_mat_arena.adj() += res.adj() * b_arena.transpose();
99+
});
143100
return return_t(res);
144101
}
145102
}

0 commit comments

Comments
 (0)