@@ -176,7 +176,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
176
176
177
177
private:
178
178
template <typename T, typename Integral>
179
- auto * copy_vector (const T* ptr, Integral size) {
179
+ inline T* copy_vector (const T* ptr, Integral size) {
180
+ if (size == 0 ) return nullptr ;
180
181
T* ret = ChainableStack::instance_->memalloc_ .alloc_array <T>(size);
181
182
std::copy_n (ptr, size, ret);
182
183
return ret;
@@ -192,7 +193,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
192
193
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
193
194
copy_vector (other.outerIndexPtr(), other.outerSize() + 1),
194
195
copy_vector(other.innerIndexPtr(), other.nonZeros()),
195
- copy_vector(other.valuePtr(), other.nonZeros())) {}
196
+ copy_vector(other.valuePtr(), other.nonZeros()),
197
+ copy_vector(other.innerNonZeroPtr(), other.innerNonZeroPtr() == nullptr ? 0 : other.innerSize())) {}
196
198
197
199
/* *
198
200
* Constructs `arena_matrix` from an Eigen expression
@@ -220,7 +222,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
220
222
// Const copy constructor: maps onto other's existing index/value buffers
// (pointers are forwarded to Base::Map; no element data is copied here).
arena_matrix (const arena_matrix<MatrixType>& other)
221
223
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
222
224
other.outerIndexPtr(), other.innerIndexPtr(),
223
- other.valuePtr()) {}
225
+ other.valuePtr(), other.innerNonZeroPtr() ) {}
224
226
/* *
225
227
* Move constructor.
226
228
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -230,7 +232,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
230
232
// Move constructor: forwards other's outer/inner index, value and
// inner-non-zero pointers to Base::Map, so both objects view the same buffers.
arena_matrix (arena_matrix<MatrixType>&& other)
231
233
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
232
234
other.outerIndexPtr(), other.innerIndexPtr(),
233
- other.valuePtr()) {}
235
+ other.valuePtr(), other.innerNonZeroPtr() ) {}
234
236
/* *
235
237
* Copy constructor. No actual copy is performed
236
238
* @note Since the memory for the arena matrix sits in Stan's memory arena all
@@ -240,7 +242,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
240
242
// Non-const copy constructor: forwards other's outer/inner index, value and
// inner-non-zero pointers to Base::Map; no element data is copied here.
arena_matrix (arena_matrix<MatrixType>& other)
241
243
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
242
244
other.outerIndexPtr(), other.innerIndexPtr(),
243
- other.valuePtr()) {}
245
+ other.valuePtr(), other.innerNonZeroPtr() ) {}
244
246
245
247
// without this using, compiler prefers combination of implicit construction
246
248
// and copy assignment to the inherited operator when assigned an expression
@@ -259,7 +261,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
259
261
new (this ) Base (other.rows (), other.cols (), other.nonZeros (),
260
262
const_cast <StorageIndex*>(other.outerIndexPtr ()),
261
263
const_cast <StorageIndex*>(other.innerIndexPtr ()),
262
- const_cast <Scalar*>(other.valuePtr ()));
264
+ const_cast <Scalar*>(other.valuePtr ()),
265
+ const_cast <StorageIndex*>(other.innerNonZeroPtr ()));
263
266
return *this ;
264
267
}
265
268
@@ -289,8 +292,32 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
289
292
* @param other The right hand side of the inplace operation
290
293
*/
291
294
// Same-type overload: applies f(this_value, other_value) pairwise over the
// two raw value arrays instead of walking InnerIterators.
template <typename F, typename Expr,
292
- require_convertible_t <Expr&, MatrixType>* = nullptr >
293
- void inplace_ops_impl (F&& f, Expr&& other) {
295
+ require_convertible_t <Expr&, MatrixType>* = nullptr ,
296
+ require_same_t <Expr, arena_matrix<MatrixType>>* = nullptr >
297
+ inline void inplace_ops_impl (F&& f, Expr&& other) {
298
+ auto && x = to_ref (other);
299
+ auto * val_ptr = (*this ).valuePtr ();
300
+ auto * x_val_ptr = x.valuePtr ();
301
+ const auto non_zeros = (*this ).nonZeros ();
302
// NOTE(review): this visits only the first nonZeros() entries of each value
// array, which presumably relies on compressed storage
// (innerNonZeroPtr() == nullptr) and on both operands sharing one layout.
// The constructors now forward innerNonZeroPtr(), so confirm the behavior
// for uncompressed matrices.
+ for (Eigen::Index i = 0 ; i < non_zeros; ++i) {
303
+ f (val_ptr[i], x_val_ptr[i]);
304
+ }
305
+ }
306
+
307
+ /* *
308
+ * inplace operations functor for `Sparse (.*)= Sparse`.
309
+ * @note This assumes that each matrix is of the same size and sparsity
310
+ * pattern.
311
+ * @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)`
312
+ * method
313
+ * @tparam Expr Type derived from `Eigen::SparseMatrixBase`
314
+ * @param f Functor that performs the inplace operation
315
+ * @param other The right hand side of the inplace operation
316
+ */
317
+ template <typename F, typename Expr,
318
+ require_convertible_t <Expr&, MatrixType>* = nullptr ,
319
+ require_not_same_t <Expr, arena_matrix<MatrixType>>* = nullptr >
320
+ inline void inplace_ops_impl (F&& f, Expr&& other) {
294
321
auto && x = to_ref (other);
295
322
for (int k = 0 ; k < (*this ).outerSize (); ++k) {
296
323
typename Base::InnerIterator it (*this , k);
@@ -314,7 +341,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
314
341
template <typename F, typename Expr,
315
342
require_not_convertible_t <Expr&, MatrixType>* = nullptr ,
316
343
require_eigen_dense_base_t <Expr>* = nullptr >
317
- void inplace_ops_impl (F&& f, Expr&& other) {
344
+ inline void inplace_ops_impl (F&& f, Expr&& other) {
318
345
auto && x = to_ref (other);
319
346
for (int k = 0 ; k < (*this ).outerSize (); ++k) {
320
347
typename Base::InnerIterator it (*this , k);
@@ -336,30 +363,29 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
336
363
*/
337
364
// Scalar overload: applies f(value, other) to entries of this matrix's value
// array.
template <typename F, typename T,
338
365
require_convertible_t <T&, Scalar>* = nullptr >
339
- void inplace_ops_impl (F&& f, T&& other) {
340
- for (int k = 0 ; k < (*this ).outerSize (); ++k) {
341
- typename Base::InnerIterator it (*this , k);
342
- for (; static_cast <bool >(it); ++it) {
343
- f (it.valueRef (), other);
344
- }
366
+ inline void inplace_ops_impl (F&& f, T&& other) {
367
+ auto * val_ptr = (*this ).valuePtr ();
368
+ const auto non_zeros = (*this ).nonZeros ();
369
// NOTE(review): the removed InnerIterator loop visited each stored entry
// whatever the storage layout. Indexing valuePtr() over [0, nonZeros())
// presumably requires compressed storage (innerNonZeroPtr() == nullptr);
// confirm the behavior for uncompressed matrices.
+ for (Eigen::Index i = 0 ; i < non_zeros; ++i) {
370
+ f (val_ptr[i], other);
345
371
}
346
372
}
347
373
348
374
public:
349
375
/* *
350
376
* Inplace operator `+=`
351
377
* @note Caution! Inplace operators assume that either
352
- * 1. The right hand side sparse matrix has the same sparcity pattern
378
+ * 1. The right hand side sparse matrix has the same sparsity pattern
353
379
* 2. You only intend to add a scalar or dense matrix coefficients to the
354
- * nonzero values of `this` This is intended to be used within the reverse
380
+ * nonzero values of `this`. This is intended to be used within the reverse
355
381
* pass for accumulation of the adjoint and is built as such. Any other use
356
- * case should be be sure the above assumptions are satisfied.
382
+ * case should be sure that the above assumptions are satisfied.
357
383
* @tparam T A type derived from `Eigen::SparseMatrixBase` or
358
384
* `Eigen::DenseMatrixBase` or a `Scalar`
359
385
* @param other value to be added inplace to the matrix.
360
386
*/
361
387
template <typename T>
362
- arena_matrix& operator +=(T&& other) {
388
+ inline arena_matrix& operator +=(T&& other) {
363
389
inplace_ops_impl (
364
390
[](auto && x, auto && y) {
365
391
x += y;
@@ -371,18 +397,18 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
371
397
372
398
/* *
373
399
* Inplace operator `-=`
374
- * @note Caution!! Inplace operators assume that either
375
- * 1. The right hand side sparse matrix has the same sparcity pattern
400
+ * @note Caution! Inplace operators assume that either
401
+ * 1. The right hand side sparse matrix has the same sparsity pattern
376
402
* 2. You only intend to add a scalar or dense matrix coefficients to the
377
- * nonzero values of `this` This is intended to be used within the reverse
403
+ * nonzero values of `this`. This is intended to be used within the reverse
378
404
* pass for accumulation of the adjoint and is built as such. Any other use
379
- * case should be be sure the above assumptions are satisfied.
405
+ * case should be sure that the above assumptions are satisfied.
380
406
* @tparam T A type derived from `Eigen::SparseMatrixBase` or
381
407
* `Eigen::DenseMatrixBase` or a `Scalar`
382
408
* @param other value to be subtracted inplace from the matrix.
383
409
*/
384
410
template <typename T>
385
- arena_matrix& operator -=(T&& other) {
411
+ inline arena_matrix& operator -=(T&& other) {
386
412
inplace_ops_impl (
387
413
[](auto && x, auto && y) {
388
414
x -= y;
0 commit comments