diff --git a/src/polysolve/linear/CMakeLists.txt b/src/polysolve/linear/CMakeLists.txt index a742a357..43d2905e 100644 --- a/src/polysolve/linear/CMakeLists.txt +++ b/src/polysolve/linear/CMakeLists.txt @@ -15,6 +15,8 @@ set(SOURCES Pardiso.hpp SaddlePointSolver.cpp SaddlePointSolver.hpp + SPQR.cpp + SPQR.hpp ) source_group(TREE "${CMAKE_CURRENT_SOURCE_DIR}" PREFIX "Source Files" FILES ${SOURCES}) diff --git a/src/polysolve/linear/SPQR.cpp b/src/polysolve/linear/SPQR.cpp new file mode 100644 index 00000000..f94b123e --- /dev/null +++ b/src/polysolve/linear/SPQR.cpp @@ -0,0 +1,20 @@ +#include "SPQR.hpp" + +namespace polysolve::linear +{ + template <> + void EigenDirect>::analyze_pattern(const StiffnessMatrix &A, const int precond_num) + { + m_Solver.compute(A); + } + template <> + void EigenDirect>::factorize(const StiffnessMatrix &A) + { + m_Solver.compute(A); + if (m_Solver.info() == Eigen::NumericalIssue) + { + throw std::runtime_error("[EigenDirect] NumericalIssue encountered."); + } + } + +} // namespace polysolve::linear diff --git a/src/polysolve/linear/SPQR.hpp b/src/polysolve/linear/SPQR.hpp new file mode 100644 index 00000000..acedb6c6 --- /dev/null +++ b/src/polysolve/linear/SPQR.hpp @@ -0,0 +1,96 @@ +#pragma once +#ifdef POLYSOLVE_WITH_SPQR +#include +#include +#include "EigenSolver.hpp" +#include "Solver.hpp" +namespace polysolve::linear +{ + template <> + void EigenDirect>::analyze_pattern(const StiffnessMatrix &A, const int precond_num); + template <> + void EigenDirect>::factorize(const StiffnessMatrix &A); + + class SPQRSolver : public EigenDirect> + { + + StiffnessMatrix matrixQ() const; + }; +} // namespace polysolve::linear +namespace Eigen +{ + template + struct SPQR_QSparseProduct; + namespace internal + { + template + struct traits> + { + typedef typename Derived::PlainObject ReturnType; + }; + } // namespace internal + template <> + struct SPQRMatrixQReturnType> + { + + using SPQRType = SPQR; + SPQRMatrixQReturnType(const SPQRType &spqr) : m_spqr(spqr) {} + template + SPQR_QProduct operator*(const MatrixBase &other) + { + return SPQR_QProduct(m_spqr, other.derived(), false); + } + template + SPQR_QSparseProduct operator*(const SparseMatrixBase &other) + { + return SPQR_QSparseProduct(m_spqr, other.derived(), false); + } + SPQRMatrixQTransposeReturnType adjoint() const + { + return SPQRMatrixQTransposeReturnType(m_spqr); + } + // To use for operations with the transpose of Q + SPQRMatrixQTransposeReturnType transpose() const + { + return SPQRMatrixQTransposeReturnType(m_spqr); + } + const SPQRType &m_spqr; + }; + + template + struct SPQR_QSparseProduct : ReturnByValue> + { + struct SPQRTypeWrap : public SPQRType + { + using SPQRType::m_H; + using SPQRType::m_HPinv; + using SPQRType::m_HTau; + }; + typedef typename SPQRType::Scalar Scalar; + typedef typename SPQRType::StorageIndex StorageIndex; + // Define the constructor to get reference to argument types + SPQR_QSparseProduct(const SPQRType &spqr, const Derived &other, bool transpose) : m_spqr(spqr), m_other(other), m_transpose(transpose) {} + + const SPQRTypeWrap &spqr_w() const { return reinterpret_cast(m_spqr); } + + inline Index rows() const { return m_transpose ? m_spqr.rows() : m_spqr.cols(); } + inline Index cols() const { return m_other.cols(); } + // Assign to a vector + template + void evalTo(ResType &res) const + { + cholmod_sparse y_cd; + cholmod_sparse *x_cd; + int method = m_transpose ? SPQR_QTX : SPQR_QX; + cholmod_common *cc = m_spqr.cholmodCommon(); + y_cd = viewAsCholmod(m_other.const_cast_derived()); + x_cd = SuiteSparseQR_qmult(method, spqr_w().m_H, spqr_w().m_HTau, spqr_w().m_HPinv, &y_cd, cc); + res = viewAsEigen(*x_cd); + cholmod_l_free_sparse(&x_cd, cc); + } + const SPQRType &m_spqr; + const Derived &m_other; + bool m_transpose; + }; +} // namespace Eigen +#endif diff --git a/src/polysolve/linear/Solver.cpp b/src/polysolve/linear/Solver.cpp index e7c0c259..1083b730 100644 --- a/src/polysolve/linear/Solver.cpp +++ b/src/polysolve/linear/Solver.cpp @@ -28,22 +28,7 @@ template #endif #ifdef POLYSOLVE_WITH_SPQR -#include -namespace polysolve::linear { - template <> - void EigenDirect>::analyze_pattern(const StiffnessMatrix& A, const int precond_num) { - m_Solver.compute(A); - } - template <> - void EigenDirect>::factorize(const StiffnessMatrix &A) - { - m_Solver.compute(A); - if (m_Solver.info() == Eigen::NumericalIssue) - { - throw std::runtime_error("[EigenDirect] NumericalIssue encountered."); - } - } -} +#include "SPQR.hpp" #endif #ifdef POLYSOLVE_WITH_SUPERLU #include diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8f7b56ca..193a2103 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -33,6 +33,10 @@ endif() include(polyfem-data) target_link_libraries(unit_tests PRIVATE polyfem::data) +if(POLYSOLVE_WITH_SPQR) + target_link_libraries(unit_tests PRIVATE SuiteSparse::SPQR) +endif() + ################################################################################ # Register tests ################################################################################ diff --git a/tests/test_linear_solver.cpp b/tests/test_linear_solver.cpp index 77a9f169..c6bfd9af 100644 --- a/tests/test_linear_solver.cpp +++ b/tests/test_linear_solver.cpp @@ -1,6 +1,7 @@ ////////////////////////////////////////////////////////////////////////// #include #include +#include #include @@ -856,3 +857,67 @@ TEST_CASE("cusolverdn_5cubes", "[solver]") REQUIRE(err < 1e-8); } } + +TEST_CASE("spqr_sparse_product", "[solver]") +{ + + Eigen::MatrixXd A(4, 4); + for (int i = 0; i < 4; i++) + { + A(i, i) = 1.0; + } + A(0, 1) = 1.0; + A(3, 0) = 1.0; + std::unique_ptr solver; + try + { + solver = Solver::create("Eigen::SPQR", ""); + } + catch (const std::exception &) + { + return; + } + + using Type = EigenDirect>; + + Type *typed_solver = dynamic_cast(solver.get()); + REQUIRE(typed_solver != nullptr); + // solver->set_parameters(params); + + for (int i = 0; i < 5; ++i) + { + A = Eigen::MatrixXd::Random(5, 5); + auto As = A.sparseView().eval(); + + // do a qr so i have a q + Eigen::SPQR spqr(As); + auto Q = spqr.matrixQ(); + + // get a random matrix to multiply against + Eigen::MatrixXd B(A.rows(), 5); + B.setRandom(); + Eigen::MatrixXd dense = Q * B; + + // make a sparse version + auto Bs = B.sparseView().eval(); + StiffnessMatrix sparse = Q * Bs; + + // check that the result of the product is the same + CHECK((dense - sparse).norm() < 1e-10); + + // try to extract the Q matrix as a dense matrix + Eigen::MatrixXd Id = Eigen::MatrixXd::Identity(A.rows(), A.rows()); + Eigen::MatrixXd denseQ = Q * Id; + + // use the product with B to get a weak equivalence once again + Eigen::MatrixXd dense2 = denseQ * B; + CHECK((dense2 - dense).norm() < 1e-10); + + // now try using a sparse product + StiffnessMatrix I(A.rows(), A.rows()); + I.setIdentity(); + StiffnessMatrix myQ = Q * I; + StiffnessMatrix sparse2 = myQ * Bs; + CHECK((sparse2 - sparse).norm() < 1e-10); + } +}