|
4 | 4 | #include <stan/math/prim/fun/Eigen.hpp>
|
5 | 5 | #include <stan/math/rev/core/chainable_alloc.hpp>
|
6 | 6 | #include <stan/math/rev/core/chainablestack.hpp>
|
| 7 | +#include <stan/math/rev/core/var_value_fwd_declare.hpp> |
| 8 | +#include <stan/math/prim/fun/to_ref.hpp> |
7 | 9 |
|
8 | 10 | namespace stan {
|
9 | 11 | namespace math {
|
10 | 12 |
|
11 |
| -/** |
12 |
| - * Equivalent to `Eigen::Matrix`, except that the data is stored on AD stack. |
13 |
| - * That makes these objects triviali destructible and usable in `vari`s. |
14 |
| - * |
15 |
| - * @tparam MatrixType Eigen matrix type this works as (`MatrixXd`, `VectorXd` |
16 |
| - * ...) |
17 |
| - */ |
18 | 13 | template <typename MatrixType>
|
19 |
| -class arena_matrix : public Eigen::Map<MatrixType> { |
| 14 | +class arena_matrix<MatrixType, require_eigen_dense_base_t<MatrixType>> |
| 15 | + : public Eigen::Map<MatrixType> { |
20 | 16 | public:
|
21 | 17 | using Scalar = value_type_t<MatrixType>;
|
22 | 18 | using Base = Eigen::Map<MatrixType>;
|
@@ -139,6 +135,294 @@ class arena_matrix : public Eigen::Map<MatrixType> {
|
139 | 135 | }
|
140 | 136 | };
|
141 | 137 |
|
| 138 | +template <typename MatrixType> |
| 139 | +class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>> |
| 140 | + : public Eigen::Map<MatrixType> { |
| 141 | + public: |
| 142 | + using Scalar = value_type_t<MatrixType>; |
| 143 | + using Base = Eigen::Map<MatrixType>; |
| 144 | + using PlainObject = std::decay_t<MatrixType>; |
| 145 | + using StorageIndex = typename PlainObject::StorageIndex; |
| 146 | + |
| 147 | + /** |
| 148 | + * Default constructor. |
| 149 | + */ |
| 150 | + arena_matrix() : Base::Map(0, 0, 0, nullptr, nullptr, nullptr) {} |
| 151 | + /** |
| 152 | + * Constructor for CDC formatted data. Assumes compressed and does not own |
| 153 | + * memory. |
| 154 | + * |
| 155 | + * Constructs a read-write Map to a sparse matrix of size rows x cols, |
| 156 | + * containing nnz non-zero coefficients, stored as a sparse format as defined |
| 157 | + * by the pointers outerIndexPtr, innerIndexPtr, and valuePtr. If the optional |
| 158 | + * parameter innerNonZerosPtr is the null pointer, then a standard compressed |
| 159 | + * format is assumed. |
| 160 | + * |
| 161 | + * @param rows Number of rows |
| 162 | + * @param cols Number of columns |
| 163 | + * @param nnz Number of nonzero items in matrix |
| 164 | + * @param outerIndexPtr A pointer to the array of outer indices. |
| 165 | + * @param innerIndexPtr A pointer to the array of inner indices. |
| 166 | + * @param valuePtr A pointer to the array of values. |
| 167 | + * @param innerNonZerosPtr A pointer to the array of the number of non zeros |
| 168 | + * of the inner vectors. |
| 169 | + * |
| 170 | + */ |
| 171 | + arena_matrix(Eigen::Index rows, Eigen::Index cols, Eigen::Index nnz, |
| 172 | + StorageIndex* outerIndexPtr, StorageIndex* innerIndexPtr, |
| 173 | + Scalar* valuePtr, StorageIndex* innerNonZerosPtr = nullptr) |
| 174 | + : Base::Map(rows, cols, nnz, outerIndexPtr, innerIndexPtr, valuePtr, |
| 175 | + innerNonZerosPtr) {} |
| 176 | + |
| 177 | + private: |
| 178 | + template <typename T, typename Integral> |
| 179 | + inline T* copy_vector(const T* ptr, Integral size) { |
| 180 | + if (size == 0) |
| 181 | + return nullptr; |
| 182 | + T* ret = ChainableStack::instance_->memalloc_.alloc_array<T>(size); |
| 183 | + std::copy_n(ptr, size, ret); |
| 184 | + return ret; |
| 185 | + } |
| 186 | + |
| 187 | + public: |
| 188 | + /** |
| 189 | + * Constructs `arena_matrix` from an `Eigen::SparseMatrix`. |
| 190 | + * @param other Eigen Sparse Matrix class |
| 191 | + */ |
| 192 | + template <typename T, require_same_t<T, PlainObject>* = nullptr> |
| 193 | + arena_matrix(T&& other) // NOLINT |
| 194 | + : Base::Map( |
| 195 | + other.rows(), other.cols(), other.nonZeros(), |
| 196 | + copy_vector(other.outerIndexPtr(), other.outerSize() + 1), |
| 197 | + copy_vector(other.innerIndexPtr(), other.nonZeros()), |
| 198 | + copy_vector(other.valuePtr(), other.nonZeros()), |
| 199 | + copy_vector( |
| 200 | + other.innerNonZeroPtr(), |
| 201 | + other.innerNonZeroPtr() == nullptr ? 0 : other.innerSize())) {} |
| 202 | + |
| 203 | + /** |
| 204 | + * Constructs `arena_matrix` from an Eigen expression |
| 205 | + * @param other An expression |
| 206 | + */ |
| 207 | + template <typename S, require_convertible_t<S&, PlainObject>* = nullptr, |
| 208 | + require_not_same_t<S, PlainObject>* = nullptr> |
| 209 | + arena_matrix(S&& other) // NOLINT |
| 210 | + : arena_matrix(PlainObject(std::forward<S>(other))) {} |
| 211 | + |
| 212 | + /** |
| 213 | + * Constructs `arena_matrix` from an expression. This makes an assumption that |
| 214 | + * any other `Eigen::Map` also contains memory allocated in the arena. |
| 215 | + * @param other expression |
| 216 | + */ |
| 217 | + arena_matrix(const Base& other) // NOLINT |
| 218 | + : Base(other) {} |
| 219 | + |
| 220 | + /** |
| 221 | + * Const copy constructor. No actual copy is performed |
| 222 | + * @note Since the memory for the arena matrix sits in Stan's memory arena all |
| 223 | + * copies/moves of arena matrices are shallow moves |
| 224 | + * @param other matrix to copy from |
| 225 | + */ |
| 226 | + arena_matrix(const arena_matrix<MatrixType>& other) |
| 227 | + : Base::Map(other.rows(), other.cols(), other.nonZeros(), |
| 228 | + other.outerIndexPtr(), other.innerIndexPtr(), |
| 229 | + other.valuePtr(), other.innernonZeroPtr()) {} |
| 230 | + /** |
| 231 | + * Move constructor. |
| 232 | + * @note Since the memory for the arena matrix sits in Stan's memory arena all |
| 233 | + * copies/moves of arena matrices are shallow moves |
| 234 | + * @param other matrix to copy from |
| 235 | + */ |
| 236 | + arena_matrix(arena_matrix<MatrixType>&& other) |
| 237 | + : Base::Map(other.rows(), other.cols(), other.nonZeros(), |
| 238 | + other.outerIndexPtr(), other.innerIndexPtr(), |
| 239 | + other.valuePtr(), other.innerNonZeroPtr()) {} |
| 240 | + /** |
| 241 | + * Copy constructor. No actual copy is performed |
| 242 | + * @note Since the memory for the arena matrix sits in Stan's memory arena all |
| 243 | + * copies/moves of arena matrices are shallow moves |
| 244 | + * @param other matrix to copy from |
| 245 | + */ |
| 246 | + arena_matrix(arena_matrix<MatrixType>& other) |
| 247 | + : Base::Map(other.rows(), other.cols(), other.nonZeros(), |
| 248 | + other.outerIndexPtr(), other.innerIndexPtr(), |
| 249 | + other.valuePtr(), other.innerNonZeroPtr()) {} |
| 250 | + |
| 251 | + // without this using, compiler prefers combination of implicit construction |
| 252 | + // and copy assignment to the inherited operator when assigned an expression |
| 253 | + using Base::operator=; |
| 254 | + |
| 255 | + /** |
| 256 | + * Copy assignment operator. |
| 257 | + * @tparam ArenaMatrix An `arena_matrix` type |
| 258 | + * @param other matrix to copy from |
| 259 | + * @return `*this` |
| 260 | + */ |
| 261 | + template <typename ArenaMatrix, |
| 262 | + require_same_t<ArenaMatrix, arena_matrix<MatrixType>>* = nullptr> |
| 263 | + arena_matrix& operator=(ArenaMatrix&& other) { |
| 264 | + // placement new changes what data map points to - there is no allocation |
| 265 | + new (this) Base(other.rows(), other.cols(), other.nonZeros(), |
| 266 | + const_cast<StorageIndex*>(other.outerIndexPtr()), |
| 267 | + const_cast<StorageIndex*>(other.innerIndexPtr()), |
| 268 | + const_cast<Scalar*>(other.valuePtr()), |
| 269 | + const_cast<StorageIndex*>(other.innerNonZeroPtr())); |
| 270 | + return *this; |
| 271 | + } |
| 272 | + |
| 273 | + /** |
| 274 | + * Assignment operator for assigning an expression. |
| 275 | + * @tparam Expr An expression that an `arena_matrix<MatrixType>` can be |
| 276 | + * constructed from |
| 277 | + * @param expr expression to evaluate into this |
| 278 | + * @return `*this` |
| 279 | + */ |
| 280 | + template <typename Expr, |
| 281 | + require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr> |
| 282 | + arena_matrix& operator=(Expr&& expr) { |
| 283 | + *this = arena_matrix(std::forward<Expr>(expr)); |
| 284 | + return *this; |
| 285 | + } |
| 286 | + |
| 287 | + private: |
| 288 | + /** |
| 289 | + * inplace operations functor for `Sparse (.*)= Sparse`. |
| 290 | + * @note This assumes that each matrix is of the same size and sparsity |
| 291 | + * pattern. |
| 292 | + * @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)` |
| 293 | + * method |
| 294 | + * @tparam Expr Type derived from `Eigen::SparseMatrixBase` |
| 295 | + * @param f Functor that performs the inplace operation |
| 296 | + * @param other The right hand side of the inplace operation |
| 297 | + */ |
| 298 | + template <typename F, typename Expr, |
| 299 | + require_convertible_t<Expr&, MatrixType>* = nullptr, |
| 300 | + require_same_t<Expr, arena_matrix<MatrixType>>* = nullptr> |
| 301 | + inline void inplace_ops_impl(F&& f, Expr&& other) { |
| 302 | + auto&& x = to_ref(other); |
| 303 | + auto* val_ptr = (*this).valuePtr(); |
| 304 | + auto* x_val_ptr = x.valuePtr(); |
| 305 | + const auto non_zeros = (*this).nonZeros(); |
| 306 | + for (Eigen::Index i = 0; i < non_zeros; ++i) { |
| 307 | + f(val_ptr[i], x_val_ptr[i]); |
| 308 | + } |
| 309 | + } |
| 310 | + |
| 311 | + /** |
| 312 | + * inplace operations functor for `Sparse (.*)= Sparse`. |
| 313 | + * @note This assumes that each matrix is of the same size and sparsity |
| 314 | + * pattern. |
| 315 | + * @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)` |
| 316 | + * method |
| 317 | + * @tparam Expr Type derived from `Eigen::SparseMatrixBase` |
| 318 | + * @param f Functor that performs the inplace operation |
| 319 | + * @param other The right hand side of the inplace operation |
| 320 | + */ |
| 321 | + template <typename F, typename Expr, |
| 322 | + require_convertible_t<Expr&, MatrixType>* = nullptr, |
| 323 | + require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr> |
| 324 | + inline void inplace_ops_impl(F&& f, Expr&& other) { |
| 325 | + auto&& x = to_ref(other); |
| 326 | + for (int k = 0; k < (*this).outerSize(); ++k) { |
| 327 | + typename Base::InnerIterator it(*this, k); |
| 328 | + typename std::decay_t<Expr>::InnerIterator iz(x, k); |
| 329 | + for (; static_cast<bool>(it) && static_cast<bool>(iz); ++it, ++iz) { |
| 330 | + f(it.valueRef(), iz.value()); |
| 331 | + } |
| 332 | + } |
| 333 | + } |
| 334 | + |
| 335 | + /** |
| 336 | + * inplace operations functor for `Sparse (.*)= Dense`. |
| 337 | + * @note This assumes the user intends to perform the inplace operation for |
| 338 | + * the nonzero parts of `this` |
| 339 | + * @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)` |
| 340 | + * method |
| 341 | + * @tparam Expr Type derived from `Eigen::DenseBase` |
| 342 | + * @param f Functor that performs the inplace operation |
| 343 | + * @param other The right hand side of the inplace operation |
| 344 | + */ |
| 345 | + template <typename F, typename Expr, |
| 346 | + require_not_convertible_t<Expr&, MatrixType>* = nullptr, |
| 347 | + require_eigen_dense_base_t<Expr>* = nullptr> |
| 348 | + inline void inplace_ops_impl(F&& f, Expr&& other) { |
| 349 | + auto&& x = to_ref(other); |
| 350 | + for (int k = 0; k < (*this).outerSize(); ++k) { |
| 351 | + typename Base::InnerIterator it(*this, k); |
| 352 | + for (; static_cast<bool>(it); ++it) { |
| 353 | + f(it.valueRef(), x(it.row(), it.col())); |
| 354 | + } |
| 355 | + } |
| 356 | + } |
| 357 | + |
| 358 | + /** |
| 359 | + * inplace operations functor for `Sparse (.*)= Scalar`. |
| 360 | + * @note This assumes the user intends to perform the inplace operation for |
| 361 | + * the nonzero parts of `this` |
| 362 | + * @tparam F A type with a valid `operator()(Scalar& x, const Scalar& y)` |
| 363 | + * method |
| 364 | + * @tparam T A scalar type |
| 365 | + * @param f Functor that performs the inplace operation |
| 366 | + * @param other The right hand side of the inplace operation |
| 367 | + */ |
| 368 | + template <typename F, typename T, |
| 369 | + require_convertible_t<T&, Scalar>* = nullptr> |
| 370 | + inline void inplace_ops_impl(F&& f, T&& other) { |
| 371 | + auto* val_ptr = (*this).valuePtr(); |
| 372 | + const auto non_zeros = (*this).nonZeros(); |
| 373 | + for (Eigen::Index i = 0; i < non_zeros; ++i) { |
| 374 | + f(val_ptr[i], other); |
| 375 | + } |
| 376 | + } |
| 377 | + |
| 378 | + public: |
| 379 | + /** |
| 380 | + * Inplace operator `+=` |
| 381 | + * @note Caution! Inplace operators assume that either |
| 382 | + * 1. The right hand side sparse matrix has the same sparsity pattern |
| 383 | + * 2. You only intend to add a scalar or dense matrix coefficients to the |
| 384 | + * nonzero values of `this`. This is intended to be used within the reverse |
| 385 | + * pass for accumulation of the adjoint and is built as such. Any other use |
| 386 | + * case should be be sure that the above assumptions are satisfied. |
| 387 | + * @tparam T A type derived from `Eigen::SparseMatrixBase` or |
| 388 | + * `Eigen::DenseMatrixBase` or a `Scalar` |
| 389 | + * @param other value to be added inplace to the matrix. |
| 390 | + */ |
| 391 | + template <typename T> |
| 392 | + inline arena_matrix& operator+=(T&& other) { |
| 393 | + inplace_ops_impl( |
| 394 | + [](auto&& x, auto&& y) { |
| 395 | + x += y; |
| 396 | + return; |
| 397 | + }, |
| 398 | + std::forward<T>(other)); |
| 399 | + return *this; |
| 400 | + } |
| 401 | + |
| 402 | + /** |
| 403 | + * Inplace operator `+=` |
| 404 | + * @note Caution! Inplace operators assume that either |
| 405 | + * 1. The right hand side sparse matrix has the same sparsity pattern |
| 406 | + * 2. You only intend to add a scalar or dense matrix coefficients to the |
| 407 | + * nonzero values of `this`. This is intended to be used within the reverse |
| 408 | + * pass for accumulation of the adjoint and is built as such. Any other use |
| 409 | + * case should be be sure that the above assumptions are satisfied. |
| 410 | + * @tparam T A type derived from `Eigen::SparseMatrixBase` or |
| 411 | + * `Eigen::DenseMatrixBase` or a `Scalar` |
| 412 | + * @param other value to be added inplace to the matrix. |
| 413 | + */ |
| 414 | + template <typename T> |
| 415 | + inline arena_matrix& operator-=(T&& other) { |
| 416 | + inplace_ops_impl( |
| 417 | + [](auto&& x, auto&& y) { |
| 418 | + x -= y; |
| 419 | + return; |
| 420 | + }, |
| 421 | + std::forward<T>(other)); |
| 422 | + return *this; |
| 423 | + } |
| 424 | +}; |
| 425 | + |
142 | 426 | } // namespace math
|
143 | 427 | } // namespace stan
|
144 | 428 |
|
|
0 commit comments