Skip to content

Commit 35d6d53

Browse files
authored
Merge pull request #2971 from stan-dev/fix/sparse-vari-impl
Add arena_matrix specialization for sparse matrices
2 parents 87d04dc + 11e414e commit 35d6d53

File tree

10 files changed

+471
-49
lines changed

10 files changed

+471
-49
lines changed

stan/math/rev/core/arena_matrix.hpp

Lines changed: 292 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,15 @@
44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/rev/core/chainable_alloc.hpp>
66
#include <stan/math/rev/core/chainablestack.hpp>
7+
#include <stan/math/rev/core/var_value_fwd_declare.hpp>
8+
#include <stan/math/prim/fun/to_ref.hpp>
79

810
namespace stan {
911
namespace math {
1012

11-
/**
12-
* Equivalent to `Eigen::Matrix`, except that the data is stored on AD stack.
13-
* That makes these objects triviali destructible and usable in `vari`s.
14-
*
15-
* @tparam MatrixType Eigen matrix type this works as (`MatrixXd`, `VectorXd`
16-
* ...)
17-
*/
1813
template <typename MatrixType>
19-
class arena_matrix : public Eigen::Map<MatrixType> {
14+
class arena_matrix<MatrixType, require_eigen_dense_base_t<MatrixType>>
15+
: public Eigen::Map<MatrixType> {
2016
public:
2117
using Scalar = value_type_t<MatrixType>;
2218
using Base = Eigen::Map<MatrixType>;
@@ -139,6 +135,294 @@ class arena_matrix : public Eigen::Map<MatrixType> {
139135
}
140136
};
141137

138+
template <typename MatrixType>
139+
class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
140+
: public Eigen::Map<MatrixType> {
141+
public:
142+
using Scalar = value_type_t<MatrixType>;
143+
using Base = Eigen::Map<MatrixType>;
144+
using PlainObject = std::decay_t<MatrixType>;
145+
using StorageIndex = typename PlainObject::StorageIndex;
146+
147+
/**
148+
* Default constructor.
149+
*/
150+
arena_matrix() : Base::Map(0, 0, 0, nullptr, nullptr, nullptr) {}
151+
/**
152+
* Constructor for CDC formatted data. Assumes compressed and does not own
153+
* memory.
154+
*
155+
* Constructs a read-write Map to a sparse matrix of size rows x cols,
156+
* containing nnz non-zero coefficients, stored as a sparse format as defined
157+
* by the pointers outerIndexPtr, innerIndexPtr, and valuePtr. If the optional
158+
* parameter innerNonZerosPtr is the null pointer, then a standard compressed
159+
* format is assumed.
160+
*
161+
* @param rows Number of rows
162+
* @param cols Number of columns
163+
* @param nnz Number of nonzero items in matrix
164+
* @param outerIndexPtr A pointer to the array of outer indices.
165+
* @param innerIndexPtr A pointer to the array of inner indices.
166+
* @param valuePtr A pointer to the array of values.
167+
* @param innerNonZerosPtr A pointer to the array of the number of non zeros
168+
* of the inner vectors.
169+
*
170+
*/
171+
arena_matrix(Eigen::Index rows, Eigen::Index cols, Eigen::Index nnz,
172+
StorageIndex* outerIndexPtr, StorageIndex* innerIndexPtr,
173+
Scalar* valuePtr, StorageIndex* innerNonZerosPtr = nullptr)
174+
: Base::Map(rows, cols, nnz, outerIndexPtr, innerIndexPtr, valuePtr,
175+
innerNonZerosPtr) {}
176+
177+
private:
178+
template <typename T, typename Integral>
179+
inline T* copy_vector(const T* ptr, Integral size) {
180+
if (size == 0)
181+
return nullptr;
182+
T* ret = ChainableStack::instance_->memalloc_.alloc_array<T>(size);
183+
std::copy_n(ptr, size, ret);
184+
return ret;
185+
}
186+
187+
public:
188+
/**
189+
* Constructs `arena_matrix` from an `Eigen::SparseMatrix`.
190+
* @param other Eigen Sparse Matrix class
191+
*/
192+
template <typename T, require_same_t<T, PlainObject>* = nullptr>
193+
arena_matrix(T&& other) // NOLINT
194+
: Base::Map(
195+
other.rows(), other.cols(), other.nonZeros(),
196+
copy_vector(other.outerIndexPtr(), other.outerSize() + 1),
197+
copy_vector(other.innerIndexPtr(), other.nonZeros()),
198+
copy_vector(other.valuePtr(), other.nonZeros()),
199+
copy_vector(
200+
other.innerNonZeroPtr(),
201+
other.innerNonZeroPtr() == nullptr ? 0 : other.innerSize())) {}
202+
203+
/**
204+
* Constructs `arena_matrix` from an Eigen expression
205+
* @param other An expression
206+
*/
207+
template <typename S, require_convertible_t<S&, PlainObject>* = nullptr,
208+
require_not_same_t<S, PlainObject>* = nullptr>
209+
arena_matrix(S&& other) // NOLINT
210+
: arena_matrix(PlainObject(std::forward<S>(other))) {}
211+
212+
/**
213+
* Constructs `arena_matrix` from an expression. This makes an assumption that
214+
* any other `Eigen::Map` also contains memory allocated in the arena.
215+
* @param other expression
216+
*/
217+
arena_matrix(const Base& other) // NOLINT
218+
: Base(other) {}
219+
220+
/**
221+
* Const copy constructor. No actual copy is performed
222+
* @note Since the memory for the arena matrix sits in Stan's memory arena all
223+
* copies/moves of arena matrices are shallow moves
224+
* @param other matrix to copy from
225+
*/
226+
arena_matrix(const arena_matrix<MatrixType>& other)
227+
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
228+
other.outerIndexPtr(), other.innerIndexPtr(),
229+
other.valuePtr(), other.innernonZeroPtr()) {}
230+
/**
231+
* Move constructor.
232+
* @note Since the memory for the arena matrix sits in Stan's memory arena all
233+
* copies/moves of arena matrices are shallow moves
234+
* @param other matrix to copy from
235+
*/
236+
arena_matrix(arena_matrix<MatrixType>&& other)
237+
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
238+
other.outerIndexPtr(), other.innerIndexPtr(),
239+
other.valuePtr(), other.innerNonZeroPtr()) {}
240+
/**
241+
* Copy constructor. No actual copy is performed
242+
* @note Since the memory for the arena matrix sits in Stan's memory arena all
243+
* copies/moves of arena matrices are shallow moves
244+
* @param other matrix to copy from
245+
*/
246+
arena_matrix(arena_matrix<MatrixType>& other)
247+
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
248+
other.outerIndexPtr(), other.innerIndexPtr(),
249+
other.valuePtr(), other.innerNonZeroPtr()) {}
250+
251+
// without this using, compiler prefers combination of implicit construction
252+
// and copy assignment to the inherited operator when assigned an expression
253+
using Base::operator=;
254+
255+
/**
256+
* Copy assignment operator.
257+
* @tparam ArenaMatrix An `arena_matrix` type
258+
* @param other matrix to copy from
259+
* @return `*this`
260+
*/
261+
template <typename ArenaMatrix,
262+
require_same_t<ArenaMatrix, arena_matrix<MatrixType>>* = nullptr>
263+
arena_matrix& operator=(ArenaMatrix&& other) {
264+
// placement new changes what data map points to - there is no allocation
265+
new (this) Base(other.rows(), other.cols(), other.nonZeros(),
266+
const_cast<StorageIndex*>(other.outerIndexPtr()),
267+
const_cast<StorageIndex*>(other.innerIndexPtr()),
268+
const_cast<Scalar*>(other.valuePtr()),
269+
const_cast<StorageIndex*>(other.innerNonZeroPtr()));
270+
return *this;
271+
}
272+
273+
/**
274+
* Assignment operator for assigning an expression.
275+
* @tparam Expr An expression that an `arena_matrix<MatrixType>` can be
276+
* constructed from
277+
* @param expr expression to evaluate into this
278+
* @return `*this`
279+
*/
280+
template <typename Expr,
281+
require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
282+
arena_matrix& operator=(Expr&& expr) {
283+
*this = arena_matrix(std::forward<Expr>(expr));
284+
return *this;
285+
}
286+
287+
private:
288+
/**
289+
* inplace operations functor for `Sparse (.*)= Sparse`.
290+
* @note This assumes that each matrix is of the same size and sparsity
291+
* pattern.
292+
* @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)`
293+
* method
294+
* @tparam Expr Type derived from `Eigen::SparseMatrixBase`
295+
* @param f Functor that performs the inplace operation
296+
* @param other The right hand side of the inplace operation
297+
*/
298+
template <typename F, typename Expr,
299+
require_convertible_t<Expr&, MatrixType>* = nullptr,
300+
require_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
301+
inline void inplace_ops_impl(F&& f, Expr&& other) {
302+
auto&& x = to_ref(other);
303+
auto* val_ptr = (*this).valuePtr();
304+
auto* x_val_ptr = x.valuePtr();
305+
const auto non_zeros = (*this).nonZeros();
306+
for (Eigen::Index i = 0; i < non_zeros; ++i) {
307+
f(val_ptr[i], x_val_ptr[i]);
308+
}
309+
}
310+
311+
/**
312+
* inplace operations functor for `Sparse (.*)= Sparse`.
313+
* @note This assumes that each matrix is of the same size and sparsity
314+
* pattern.
315+
* @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)`
316+
* method
317+
* @tparam Expr Type derived from `Eigen::SparseMatrixBase`
318+
* @param f Functor that performs the inplace operation
319+
* @param other The right hand side of the inplace operation
320+
*/
321+
template <typename F, typename Expr,
322+
require_convertible_t<Expr&, MatrixType>* = nullptr,
323+
require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
324+
inline void inplace_ops_impl(F&& f, Expr&& other) {
325+
auto&& x = to_ref(other);
326+
for (int k = 0; k < (*this).outerSize(); ++k) {
327+
typename Base::InnerIterator it(*this, k);
328+
typename std::decay_t<Expr>::InnerIterator iz(x, k);
329+
for (; static_cast<bool>(it) && static_cast<bool>(iz); ++it, ++iz) {
330+
f(it.valueRef(), iz.value());
331+
}
332+
}
333+
}
334+
335+
/**
336+
* inplace operations functor for `Sparse (.*)= Dense`.
337+
* @note This assumes the user intends to perform the inplace operation for
338+
* the nonzero parts of `this`
339+
* @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)`
340+
* method
341+
* @tparam Expr Type derived from `Eigen::DenseBase`
342+
* @param f Functor that performs the inplace operation
343+
* @param other The right hand side of the inplace operation
344+
*/
345+
template <typename F, typename Expr,
346+
require_not_convertible_t<Expr&, MatrixType>* = nullptr,
347+
require_eigen_dense_base_t<Expr>* = nullptr>
348+
inline void inplace_ops_impl(F&& f, Expr&& other) {
349+
auto&& x = to_ref(other);
350+
for (int k = 0; k < (*this).outerSize(); ++k) {
351+
typename Base::InnerIterator it(*this, k);
352+
for (; static_cast<bool>(it); ++it) {
353+
f(it.valueRef(), x(it.row(), it.col()));
354+
}
355+
}
356+
}
357+
358+
/**
359+
* inplace operations functor for `Sparse (.*)= Scalar`.
360+
* @note This assumes the user intends to perform the inplace operation for
361+
* the nonzero parts of `this`
362+
* @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)`
363+
* method
364+
* @tparam T A scalar type
365+
* @param f Functor that performs the inplace operation
366+
* @param other The right hand side of the inplace operation
367+
*/
368+
template <typename F, typename T,
369+
require_convertible_t<T&, Scalar>* = nullptr>
370+
inline void inplace_ops_impl(F&& f, T&& other) {
371+
auto* val_ptr = (*this).valuePtr();
372+
const auto non_zeros = (*this).nonZeros();
373+
for (Eigen::Index i = 0; i < non_zeros; ++i) {
374+
f(val_ptr[i], other);
375+
}
376+
}
377+
378+
public:
379+
/**
380+
* Inplace operator `+=`
381+
* @note Caution! Inplace operators assume that either
382+
* 1. The right hand side sparse matrix has the same sparsity pattern
383+
* 2. You only intend to add a scalar or dense matrix coefficients to the
384+
* nonzero values of `this`. This is intended to be used within the reverse
385+
* pass for accumulation of the adjoint and is built as such. Any other use
386+
* case should be be sure that the above assumptions are satisfied.
387+
* @tparam T A type derived from `Eigen::SparseMatrixBase` or
388+
* `Eigen::DenseMatrixBase` or a `Scalar`
389+
* @param other value to be added inplace to the matrix.
390+
*/
391+
template <typename T>
392+
inline arena_matrix& operator+=(T&& other) {
393+
inplace_ops_impl(
394+
[](auto&& x, auto&& y) {
395+
x += y;
396+
return;
397+
},
398+
std::forward<T>(other));
399+
return *this;
400+
}
401+
402+
/**
403+
* Inplace operator `+=`
404+
* @note Caution! Inplace operators assume that either
405+
* 1. The right hand side sparse matrix has the same sparsity pattern
406+
* 2. You only intend to add a scalar or dense matrix coefficients to the
407+
* nonzero values of `this`. This is intended to be used within the reverse
408+
* pass for accumulation of the adjoint and is built as such. Any other use
409+
* case should be be sure that the above assumptions are satisfied.
410+
* @tparam T A type derived from `Eigen::SparseMatrixBase` or
411+
* `Eigen::DenseMatrixBase` or a `Scalar`
412+
* @param other value to be added inplace to the matrix.
413+
*/
414+
template <typename T>
415+
inline arena_matrix& operator-=(T&& other) {
416+
inplace_ops_impl(
417+
[](auto&& x, auto&& y) {
418+
x -= y;
419+
return;
420+
},
421+
std::forward<T>(other));
422+
return *this;
423+
}
424+
};
425+
142426
} // namespace math
143427
} // namespace stan
144428

stan/math/rev/core/var_value_fwd_declare.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@ namespace math {
66
// forward declaration of var
77
template <typename T, typename = void>
88
class var_value;
9+
10+
/**
11+
* Equivalent to `Eigen::Matrix`, except that the data is stored on AD stack.
12+
* That makes these objects trivially destructible and usable in `vari`s.
13+
*
14+
* @tparam MatrixType Eigen matrix type this works as (`MatrixXd`, `VectorXd`,
15+
* ...)
16+
*/
17+
template <typename MatrixType, typename = void>
18+
class arena_matrix;
919
} // namespace math
1020
} // namespace stan
1121
#endif

0 commit comments

Comments
 (0)