Skip to content

Commit dfd3949

Browse files
committed
cleanup docs, add innerNonZero() calls, speedup inplace ops for sparse matrix
1 parent 7635348 commit dfd3949

File tree

3 files changed

+62
-27
lines changed

3 files changed

+62
-27
lines changed

stan/math/rev/core/arena_matrix.hpp

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
176176

177177
private:
178178
template <typename T, typename Integral>
179-
auto* copy_vector(const T* ptr, Integral size) {
179+
inline T* copy_vector(const T* ptr, Integral size) {
180+
if (size == 0) return nullptr;
180181
T* ret = ChainableStack::instance_->memalloc_.alloc_array<T>(size);
181182
std::copy_n(ptr, size, ret);
182183
return ret;
@@ -192,7 +193,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
192193
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
193194
copy_vector(other.outerIndexPtr(), other.outerSize() + 1),
194195
copy_vector(other.innerIndexPtr(), other.nonZeros()),
195-
copy_vector(other.valuePtr(), other.nonZeros())) {}
196+
copy_vector(other.valuePtr(), other.nonZeros()),
197+
copy_vector(other.innerNonZeroPtr(), other.innerNonZeroPtr() == nullptr ? 0 : other.innerSize())) {}
196198

197199
/**
198200
* Constructs `arena_matrix` from an Eigen expression
@@ -220,7 +222,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
220222
arena_matrix(const arena_matrix<MatrixType>& other)
221223
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
222224
other.outerIndexPtr(), other.innerIndexPtr(),
223-
other.valuePtr()) {}
225+
other.valuePtr(), other.innernonZeroPtr()) {}
224226
/**
225227
* Move constructor.
226228
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -230,7 +232,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
230232
arena_matrix(arena_matrix<MatrixType>&& other)
231233
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
232234
other.outerIndexPtr(), other.innerIndexPtr(),
233-
other.valuePtr()) {}
235+
other.valuePtr(), other.innerNonZeroPtr()) {}
234236
/**
235237
* Copy constructor. No actual copy is performed
236238
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -240,7 +242,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
240242
arena_matrix(arena_matrix<MatrixType>& other)
241243
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
242244
other.outerIndexPtr(), other.innerIndexPtr(),
243-
other.valuePtr()) {}
245+
other.valuePtr(), other.innerNonZeroPtr()) {}
244246

245247
// without this using, compiler prefers combination of implicit construction
246248
// and copy assignment to the inherited operator when assigned an expression
@@ -259,7 +261,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
259261
new (this) Base(other.rows(), other.cols(), other.nonZeros(),
260262
const_cast<StorageIndex*>(other.outerIndexPtr()),
261263
const_cast<StorageIndex*>(other.innerIndexPtr()),
262-
const_cast<Scalar*>(other.valuePtr()));
264+
const_cast<Scalar*>(other.valuePtr()),
265+
const_cast<StorageIndex*>(other.innerNonZeroPtr()));
263266
return *this;
264267
}
265268

@@ -289,8 +292,32 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
289292
* @param other The right hand side of the inplace operation
290293
*/
291294
template <typename F, typename Expr,
292-
require_convertible_t<Expr&, MatrixType>* = nullptr>
293-
void inplace_ops_impl(F&& f, Expr&& other) {
295+
require_convertible_t<Expr&, MatrixType>* = nullptr,
296+
require_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
297+
inline void inplace_ops_impl(F&& f, Expr&& other) {
298+
auto&& x = to_ref(other);
299+
auto* val_ptr = (*this).valuePtr();
300+
auto* x_val_ptr = x.valuePtr();
301+
const auto non_zeros = (*this).nonZeros();
302+
for (Eigen::Index i = 0; i < non_zeros; ++i) {
303+
f(val_ptr[i], x_val_ptr[i]);
304+
}
305+
}
306+
307+
/**
308+
* inplace operations functor for `Sparse (.*)= Sparse`.
309+
* @note This assumes that each matrix is of the same size and sparsity
310+
* pattern.
311+
* @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)`
312+
* method
313+
* @tparam Expr Type derived from `Eigen::SparseMatrixBase`
314+
* @param f Functor that performs the inplace operation
315+
* @param other The right hand side of the inplace operation
316+
*/
317+
template <typename F, typename Expr,
318+
require_convertible_t<Expr&, MatrixType>* = nullptr,
319+
require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
320+
inline void inplace_ops_impl(F&& f, Expr&& other) {
294321
auto&& x = to_ref(other);
295322
for (int k = 0; k < (*this).outerSize(); ++k) {
296323
typename Base::InnerIterator it(*this, k);
@@ -314,7 +341,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
314341
template <typename F, typename Expr,
315342
require_not_convertible_t<Expr&, MatrixType>* = nullptr,
316343
require_eigen_dense_base_t<Expr>* = nullptr>
317-
void inplace_ops_impl(F&& f, Expr&& other) {
344+
inline void inplace_ops_impl(F&& f, Expr&& other) {
318345
auto&& x = to_ref(other);
319346
for (int k = 0; k < (*this).outerSize(); ++k) {
320347
typename Base::InnerIterator it(*this, k);
@@ -336,30 +363,29 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
336363
*/
337364
template <typename F, typename T,
338365
require_convertible_t<T&, Scalar>* = nullptr>
339-
void inplace_ops_impl(F&& f, T&& other) {
340-
for (int k = 0; k < (*this).outerSize(); ++k) {
341-
typename Base::InnerIterator it(*this, k);
342-
for (; static_cast<bool>(it); ++it) {
343-
f(it.valueRef(), other);
344-
}
366+
inline void inplace_ops_impl(F&& f, T&& other) {
367+
auto* val_ptr = (*this).valuePtr();
368+
const auto non_zeros = (*this).nonZeros();
369+
for (Eigen::Index i = 0; i < non_zeros; ++i) {
370+
f(val_ptr[i], other);
345371
}
346372
}
347373

348374
public:
349375
/**
350376
* Inplace operator `+=`
351377
* @note Caution! Inplace operators assume that either
352-
* 1. The right hand side sparse matrix has the same sparcity pattern
378+
* 1. The right hand side sparse matrix has the same sparsity pattern
353379
* 2. You only intend to add a scalar or dense matrix coefficients to the
354-
* nonzero values of `this` This is intended to be used within the reverse
380+
* nonzero values of `this`. This is intended to be used within the reverse
355381
* pass for accumulation of the adjoint and is built as such. Any other use
356-
* case should be be sure the above assumptions are satisfied.
382+
* case should be be sure that the above assumptions are satisfied.
357383
* @tparam T A type derived from `Eigen::SparseMatrixBase` or
358384
* `Eigen::DenseMatrixBase` or a `Scalar`
359385
* @param other value to be added inplace to the matrix.
360386
*/
361387
template <typename T>
362-
arena_matrix& operator+=(T&& other) {
388+
inline arena_matrix& operator+=(T&& other) {
363389
inplace_ops_impl(
364390
[](auto&& x, auto&& y) {
365391
x += y;
@@ -371,18 +397,18 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
371397

372398
/**
373399
* Inplace operator `+=`
374-
* @note Caution!! Inplace operators assume that either
375-
* 1. The right hand side sparse matrix has the same sparcity pattern
400+
* @note Caution! Inplace operators assume that either
401+
* 1. The right hand side sparse matrix has the same sparsity pattern
376402
* 2. You only intend to add a scalar or dense matrix coefficients to the
377-
* nonzero values of `this` This is intended to be used within the reverse
403+
* nonzero values of `this`. This is intended to be used within the reverse
378404
* pass for accumulation of the adjoint and is built as such. Any other use
379-
* case should be be sure the above assumptions are satisfied.
405+
* case should be be sure that the above assumptions are satisfied.
380406
* @tparam T A type derived from `Eigen::SparseMatrixBase` or
381407
* `Eigen::DenseMatrixBase` or a `Scalar`
382408
* @param other value to be added inplace to the matrix.
383409
*/
384410
template <typename T>
385-
arena_matrix& operator-=(T&& other) {
411+
inline arena_matrix& operator-=(T&& other) {
386412
inplace_ops_impl(
387413
[](auto&& x, auto&& y) {
388414
x -= y;

stan/math/rev/core/var_value_fwd_declare.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ class var_value;
99

1010
/**
1111
* Equivalent to `Eigen::Matrix`, except that the data is stored on AD stack.
12-
* That makes these objects triviali destructible and usable in `vari`s.
12+
* That makes these objects trivially destructible and usable in `vari`s.
1313
*
14-
* @tparam MatrixType Eigen matrix type this works as (`MatrixXd`, `VectorXd`
14+
* @tparam MatrixType Eigen matrix type this works as (`MatrixXd`, `VectorXd`,
1515
* ...)
1616
*/
1717
template <typename MatrixType, typename = void>

test/unit/math/rev/core/arena_matrix_test.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ TEST_F(AgradRev, arena_sparse_matrix_constructors) {
136136
using stan::test::make_sparse_matrix_random;
137137
eig_mat A = make_sparse_matrix_random(10, 10);
138138
// Testing each constructor
139-
arena_mat empty_m();
139+
arena_mat empty_m{};
140140
arena_mat nocopy(A.rows(), A.cols(), A.nonZeros(), A.outerIndexPtr(),
141141
A.innerIndexPtr(), A.valuePtr(), A.innerNonZeroPtr());
142142
Eigen::Map<eig_mat> sparse_map(A.rows(), A.cols(), A.nonZeros(),
@@ -208,6 +208,15 @@ TEST_F(AgradRev, arena_sparse_matrix_inplace_ops) {
208208
A_m -= A;
209209
expect_sparse_matrix_equal(A_m, eig_mat(A - A));
210210

211+
arena_mat B_m = A;
212+
A_m = A;
213+
A_m += B_m;
214+
expect_sparse_matrix_equal(A_m, eig_mat(A + A));
215+
216+
A_m = A;
217+
A_m -= B_m;
218+
expect_sparse_matrix_equal(A_m, eig_mat(A - A));
219+
211220
Eigen::Matrix<double, -1, -1> B
212221
= Eigen::Matrix<double, -1, -1>::Random(10, 10);
213222
A_m = A;

0 commit comments

Comments
 (0)