Skip to content

CholeskyMultiFactoriser #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 157 additions & 44 deletions include/tensor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ inline bool DTensor<T>::upload(const std::vector<T> &vec, StorageMode mode) {
size_t size = vec.size();
size_t thisSize = m_numRows * m_numCols * m_numMats;
// make sure vec is of right size
if (size != thisSize) throw std::invalid_argument("vec has wrong size");
if (size != thisSize) throw std::invalid_argument("[upload] vec has wrong size");
std::vector<T> vecCm(thisSize);
if (mode == StorageMode::rowMajor) {
rm2cm(vec, vecCm);
Expand Down Expand Up @@ -611,7 +611,7 @@ inline DTensor<double> DTensor<double>::tr() const {
template<typename T>
inline void DTensor<T>::deviceCopyTo(DTensor<T> &elsewhere) const {
if (elsewhere.numEl() < numEl()) {
throw std::invalid_argument("tensor does not fit into destination");
throw std::invalid_argument("[deviceCopyTo] tensor does not fit into destination");
}
gpuErrChk(cudaMemcpy(elsewhere.raw(),
m_d_data,
Expand All @@ -623,7 +623,7 @@ template<>
inline DTensor<double> &DTensor<double>::operator*=(double scalar) {
double alpha = scalar;
gpuErrChk(
cublasDscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
cublasDscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
return *this;
}

Expand All @@ -641,25 +641,25 @@ template<>
inline DTensor<float> &DTensor<float>::operator*=(float scalar) {
float alpha = scalar;
gpuErrChk(
cublasSscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
cublasSscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
return *this;
}

template<>
inline DTensor<double> &DTensor<double>::operator+=(const DTensor<double> &rhs) {
const double alpha = 1.;
gpuErrChk(
cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
1, m_d_data, 1));
cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
1, m_d_data, 1));
return *this;
}

template<>
inline DTensor<float> &DTensor<float>::operator+=(const DTensor<float> &rhs) {
const float alpha = 1.;
gpuErrChk(
cublasSaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
1, m_d_data, 1));
cublasSaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
1, m_d_data, 1));
return *this;
}

Expand All @@ -675,8 +675,8 @@ template<>
inline DTensor<double> &DTensor<double>::operator-=(const DTensor<double> &rhs) {
const double alpha = -1.;
gpuErrChk(
cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
1, m_d_data, 1));
cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
1, m_d_data, 1));
return *this;
}

Expand Down Expand Up @@ -745,13 +745,13 @@ inline void DTensor<double>::leastSquares(DTensor &B) {
size_t batchSize = numMats();
size_t nColsB = B.numCols();
if (B.numRows() != m_numRows)
throw std::invalid_argument("Least squares rhs rows does not equal lhs rows");
throw std::invalid_argument("[Least squares] rhs rows does not equal lhs rows");
if (nColsB != 1)
throw std::invalid_argument("Least squares rhs are not vectors");
throw std::invalid_argument("[Least squares] rhs are not vectors");
if (B.numMats() != batchSize)
throw std::invalid_argument("Least squares rhs numMats does not equal lhs numMats");
throw std::invalid_argument("[Least squares] rhs numMats does not equal lhs numMats");
if (m_numCols > m_numRows)
throw std::invalid_argument("Least squares supports square or tall matrices only");
throw std::invalid_argument("[Least squares] supports square or tall matrices only");
int info = 0;
DTensor<int> infoArray(batchSize);
DTensor<double *> As = pointersToMatrices();
Expand All @@ -775,13 +775,13 @@ inline void DTensor<float>::leastSquares(DTensor &B) {
size_t batchSize = numMats();
size_t nColsB = B.numCols();
if (B.numRows() != m_numRows)
throw std::invalid_argument("Least squares rhs rows does not equal lhs rows");
throw std::invalid_argument("[Least squares] rhs rows does not equal lhs rows");
if (nColsB != 1)
throw std::invalid_argument("Least squares rhs are not vectors");
throw std::invalid_argument("[Least squares] rhs are not vectors");
if (B.numMats() != batchSize)
throw std::invalid_argument("Least squares rhs numMats does not equal lhs numMats");
throw std::invalid_argument("[Least squares] rhs numMats does not equal lhs numMats");
if (m_numCols > m_numRows)
throw std::invalid_argument("Least squares supports square or tall matrices only");
throw std::invalid_argument("[Least squares] supports square or tall matrices only");
int info = 0;
DTensor<int> infoArray(batchSize);
DTensor<float *> As = pointersToMatrices();
Expand Down Expand Up @@ -899,7 +899,7 @@ private:
*/
void checkMatrix(DTensor<T> &tensor) const {
if (tensor.numRows() < tensor.numCols()) {
throw std::invalid_argument("your matrix is fat (no offence)");
throw std::invalid_argument("[svd] your matrix is fat (no offence)");
}
};

Expand Down Expand Up @@ -1021,17 +1021,17 @@ inline bool Svd<double>::factorise() {
if (m_computeU)
Ui = std::make_unique<DTensor<double>>(*m_U, 2, i, i);
gpuErrChk(
cusolverDnDgesvd(Session::getInstance().cuSolverHandle(),
(m_computeU) ? 'A' : 'N', 'A',
m, n,
Ai.raw(), m,
Si.raw(),
(m_computeU) ? Ui->raw() : nullptr, m,
Vtri.raw(), n,
m_workspace->raw(),
m_lwork,
nullptr, // rwork (used only if SVD fails)
m_info->raw()));
cusolverDnDgesvd(Session::getInstance().cuSolverHandle(),
(m_computeU) ? 'A' : 'N', 'A',
m, n,
Ai.raw(), m,
Si.raw(),
(m_computeU) ? Ui->raw() : nullptr, m,
Vtri.raw(), n,
m_workspace->raw(),
m_lwork,
nullptr, // rwork (used only if SVD fails)
m_info->raw()));
info = info && ((*m_info)(0, 0, 0) == 0);
}
return info;
Expand All @@ -1051,17 +1051,17 @@ inline bool Svd<float>::factorise() {
if (m_computeU)
Ui = std::make_unique<DTensor<float>>(*m_U, 2, i, i);
gpuErrChk(
cusolverDnSgesvd(Session::getInstance().cuSolverHandle(),
(m_computeU) ? 'A' : 'N', 'A',
m, n,
Ai.raw(), m,
Si.raw(),
(m_computeU) ? Ui->raw() : nullptr, m,
Vtri.raw(), n,
m_workspace->raw(),
m_lwork,
nullptr, // rwork (used only if SVD fails)
m_info->raw()));
cusolverDnSgesvd(Session::getInstance().cuSolverHandle(),
(m_computeU) ? 'A' : 'N', 'A',
m, n,
Ai.raw(), m,
Si.raw(),
(m_computeU) ? Ui->raw() : nullptr, m,
Vtri.raw(), n,
m_workspace->raw(),
m_lwork,
nullptr, // rwork (used only if SVD fails)
m_info->raw()));
info = info && ((*m_info)(0, 0, 0) == 0);
}
return info;
Expand Down Expand Up @@ -1095,8 +1095,8 @@ private:
public:

CholeskyFactoriser(DTensor<T> &A) {
if (A.numMats() > 1) throw std::invalid_argument("3D tensors are not supported (for now); only matrices");
if (A.numRows() != A.numCols()) throw std::invalid_argument("Matrix A must be square for CF");
if (A.numMats() > 1) throw std::invalid_argument("[Cholesky] 3D tensors require `CholeskyBatchFactoriser");
if (A.numRows() != A.numCols()) throw std::invalid_argument("[Cholesky] Matrix A must be square");
m_matrix = &A;
computeWorkspaceSize();
m_workspace = std::make_unique<DTensor<T>>(m_workspaceSize);
Expand Down Expand Up @@ -1237,7 +1237,7 @@ template<typename T>
TEMPLATE_CONSTRAINT_REQUIRES_FPX
inline Nullspace<T>::Nullspace(DTensor<T> &a) {
size_t m = a.numRows(), n = a.numCols(), nMats = a.numMats();
if (m > n) throw std::invalid_argument("I was expecting a square or fat matrix");
if (m > n) throw std::invalid_argument("[nullspace] I was expecting a square or fat matrix");
m_nullspace = std::make_unique<DTensor<T>>(n, n, nMats, true);
m_projOp = std::make_unique<DTensor<T>>(n, n, nMats, true);
auto aTranspose = a.tr();
Expand Down Expand Up @@ -1276,4 +1276,117 @@ inline void Nullspace<T>::project(DTensor<T> &b) {
}


/* ================================================================================================
* CHOLESKY BATCH FACTORISATION (CBF)
* ================================================================================================ */

/**
 * Cholesky Factorisation (CF) for batches of matrices (tensors).
 * This object can factorise and then solve,
 * or you can provide pre-computed Cholesky lower-triangular matrices and then solve.
 * The `factorise`/`solve` methods are specialised (later in this file) for
 * `float` and `double` via the batched cusolver routines.
 * @tparam T data type of tensor (must be float or double)
 */
TEMPLATE_WITH_TYPE_T TEMPLATE_CONSTRAINT_REQUIRES_FPX
class CholeskyBatchFactoriser {

private:
    DTensor<T> *m_matrix; ///< Matrix to factorise or lower-triangular decomposed matrix (non-owning; must outlive this object)
    size_t m_numRows = 0; ///< Number of rows (== columns, matrices are square) in `m_matrix`
    size_t m_numMats = 0; ///< Number of matrices in `m_matrix` (batch size)
    std::unique_ptr<DTensor<int>> m_deviceInfo; ///< Info from cusolver functions (one status int per matrix in the batch)
    bool m_factorisationDone = false; ///< Whether `m_matrix` holds original or lower-triangular matrix

public:
    CholeskyBatchFactoriser() = delete; // a tensor to operate on is mandatory

    /**
     * Constructor
     * @param A either matrices to be factorised or lower-triangular Cholesky decomposition matrices
     * @param factorised whether A is the original matrices or the factorised ones (default=original matrices)
     * @throws std::invalid_argument if the matrices in A are not square
     */
    CholeskyBatchFactoriser(DTensor<T> &A, bool factorised=false) : m_factorisationDone(factorised) {
        if (A.numRows() != A.numCols()) throw std::invalid_argument("[CholeskyBatch] A must be square");
        m_matrix = &A;
        m_numRows = A.numRows();
        m_numMats = A.numMats();
        // One info entry per matrix; last arg presumably zero-initialises the
        // device buffer (mirrors DTensor usage elsewhere in this file) — TODO confirm
        m_deviceInfo = std::make_unique<DTensor<int>>(m_numMats, 1, 1, true);
    }

    /**
     * Factorise all matrices in tensor (batched).
     * In-place: overwrites `m_matrix` with the lower-triangular factors.
     * No-op if the data were supplied already factorised.
     */
    void factorise();

    /**
     * Solve the system A\b using Cholesky lower-triangular matrix of A (batched).
     * @param b vector in system Ax=b to be solved where x is the solution
     */
    void solve(DTensor<T> &b);

};

/**
 * Batched in-place Cholesky factorisation (double).
 * Overwrites each matrix of the batch with its lower-triangular Cholesky factor;
 * per-matrix status codes land in `m_deviceInfo`. No-op if already factorised.
 * Note: `inline` is required — this is a full specialisation at namespace scope
 * in a header, so without it every translation unit including this file would
 * emit a duplicate definition (ODR violation); all other specialisations in
 * this file carry `inline` for the same reason.
 */
template<>
inline void CholeskyBatchFactoriser<double>::factorise() {
    if (m_factorisationDone) return; // idempotent: factor is already in place
    DTensor<double *> ptrA = m_matrix->pointersToMatrices();
    gpuErrChk(cusolverDnDpotrfBatched(Session::getInstance().cuSolverHandle(),
                                      CUBLAS_FILL_MODE_LOWER,
                                      m_numRows,
                                      ptrA.raw(),
                                      m_numRows, // leading dimension == numRows (column-major, no padding)
                                      m_deviceInfo->raw(),
                                      m_numMats));
    m_factorisationDone = true;
}

/**
 * Batched in-place Cholesky factorisation (float).
 * Overwrites each matrix of the batch with its lower-triangular Cholesky factor;
 * per-matrix status codes land in `m_deviceInfo`. No-op if already factorised.
 * Note: `inline` is required — this is a full specialisation at namespace scope
 * in a header, so without it every translation unit including this file would
 * emit a duplicate definition (ODR violation); all other specialisations in
 * this file carry `inline` for the same reason.
 */
template<>
inline void CholeskyBatchFactoriser<float>::factorise() {
    if (m_factorisationDone) return; // idempotent: factor is already in place
    DTensor<float *> ptrA = m_matrix->pointersToMatrices();
    gpuErrChk(cusolverDnSpotrfBatched(Session::getInstance().cuSolverHandle(),
                                      CUBLAS_FILL_MODE_LOWER,
                                      m_numRows,
                                      ptrA.raw(),
                                      m_numRows, // leading dimension == numRows (column-major, no padding)
                                      m_deviceInfo->raw(),
                                      m_numMats));
    m_factorisationDone = true;
}

/**
 * Batched triangular solve of Ax=b (double) using the Cholesky factors;
 * cusolverDnDpotrsBatched writes the solution back into `b`.
 * `inline` prevents ODR violations: full specialisations at namespace scope in
 * a header are otherwise defined once per including translation unit.
 * @param b right-hand sides, one single-column matrix per batch item
 * @throws std::logic_error if factorise() has not run (and A was not supplied pre-factorised)
 * @throws std::invalid_argument if b has more than one column or its dimensions do not match A
 */
template<>
inline void CholeskyBatchFactoriser<double>::solve(DTensor<double> &b) {
    if (!m_factorisationDone) throw std::logic_error("[CholeskyBatchSolve] no factor to solve with");
    if (b.numCols() != 1) throw std::invalid_argument("[CholeskyBatchSolve] only supports `b` with one column");
    // validate dimensions up front (consistent with the checks in leastSquares):
    // a silent mismatch would make cusolver read/write out of bounds
    if (b.numRows() != m_numRows) throw std::invalid_argument("[CholeskyBatchSolve] b.numRows() does not equal A's numRows");
    if (b.numMats() != m_numMats) throw std::invalid_argument("[CholeskyBatchSolve] b.numMats() does not equal A's numMats");
    DTensor<double *> ptrA = m_matrix->pointersToMatrices();
    DTensor<double *> ptrB = b.pointersToMatrices();
    gpuErrChk(cusolverDnDpotrsBatched(Session::getInstance().cuSolverHandle(),
                                      CUBLAS_FILL_MODE_LOWER,
                                      m_numRows,
                                      1, ///< cusolver's batched potrs only supports nrhs = 1
                                      ptrA.raw(),
                                      m_numRows,
                                      ptrB.raw(),
                                      m_numRows,
                                      m_deviceInfo->raw(),
                                      m_numMats));
}

/**
 * Batched triangular solve of Ax=b (float) using the Cholesky factors;
 * cusolverDnSpotrsBatched writes the solution back into `b`.
 * `inline` prevents ODR violations: full specialisations at namespace scope in
 * a header are otherwise defined once per including translation unit.
 * @param b right-hand sides, one single-column matrix per batch item
 * @throws std::logic_error if factorise() has not run (and A was not supplied pre-factorised)
 * @throws std::invalid_argument if b has more than one column or its dimensions do not match A
 */
template<>
inline void CholeskyBatchFactoriser<float>::solve(DTensor<float> &b) {
    if (!m_factorisationDone) throw std::logic_error("[CholeskyBatchSolve] no factor to solve with");
    if (b.numCols() != 1) throw std::invalid_argument("[CholeskyBatchSolve] only supports `b` with one column");
    // validate dimensions up front (consistent with the checks in leastSquares):
    // a silent mismatch would make cusolver read/write out of bounds
    if (b.numRows() != m_numRows) throw std::invalid_argument("[CholeskyBatchSolve] b.numRows() does not equal A's numRows");
    if (b.numMats() != m_numMats) throw std::invalid_argument("[CholeskyBatchSolve] b.numMats() does not equal A's numMats");
    DTensor<float *> ptrA = m_matrix->pointersToMatrices();
    DTensor<float *> ptrB = b.pointersToMatrices();
    gpuErrChk(cusolverDnSpotrsBatched(Session::getInstance().cuSolverHandle(),
                                      CUBLAS_FILL_MODE_LOWER,
                                      m_numRows,
                                      1, ///< cusolver's batched potrs only supports nrhs = 1
                                      ptrA.raw(),
                                      m_numRows,
                                      ptrB.raw(),
                                      m_numRows,
                                      m_deviceInfo->raw(),
                                      m_numMats));
}

#endif
Loading