diff --git a/include/tensor.cuh b/include/tensor.cuh
index a154d63..aeb9c60 100644
--- a/include/tensor.cuh
+++ b/include/tensor.cuh
@@ -17,7 +17,6 @@
  */
 #define DEFAULT_FPX double
 #define THREADS_PER_BLOCK 512
-#define DIM2BLOCKS(n) ((n) / THREADS_PER_BLOCK + ((n) % THREADS_PER_BLOCK != 0))
 #if (__cplusplus >= 201703L) ///< if c++17 or above
 #define TEMPLATE_WITH_TYPE_T template<typename T>
 #else
@@ -29,6 +28,19 @@
 #define TEMPLATE_CONSTRAINT_REQUIRES_FPX
 #endif
 
+/**
+ * Determines the number of blocks needed for a given number of tasks, n,
+ * and number of threads per block
+ *
+ * @param n problem size
+ * @param threads_per_block threads per block (defaults to THREADS_PER_BLOCK)
+ * @return number of blocks
+ */
+constexpr size_t numBlocks(size_t n, size_t threads_per_block=THREADS_PER_BLOCK) {
+    return (n / threads_per_block + (n % threads_per_block != 0));
+}
+
+
 /**
  * Check for errors when calling GPU functions
  */
@@ -338,6 +350,8 @@ public:
      */
     void addAB(const DTensor &A, const DTensor &B, T alpha = 1, T beta = 0);
 
+    void reshape(size_t newNumRows, size_t newNumCols, size_t newNumMats = 1);
+
     /* ------------- OPERATORS ------------- */
 
     DTensor &operator=(const DTensor &other);
@@ -384,6 +398,21 @@ public:
 
 }; /* END OF DTENSOR */
 
+template<typename T>
+void DTensor<T>::reshape(size_t newNumRows, size_t newNumCols, size_t newNumMats) {
+    size_t newNumElements = newNumRows * newNumCols * newNumMats;
+    if (numEl() != newNumElements) {
+        char errMessage[256];
+        sprintf(errMessage,
+                "DTensor[%lu x %lu x %lu] with %lu elements cannot be reshaped into DTensor[%lu x %lu x %lu] (%lu elements)",
+                numRows(), numCols(), numMats(), numEl(), newNumRows, newNumCols, newNumMats, newNumElements);
+        throw std::invalid_argument(errMessage);
+    }
+    m_numRows = newNumRows;
+    m_numCols = newNumCols;
+    m_numMats = newNumMats;
+}
+
 template<typename T>
 DTensor<T>::DTensor(size_t m, size_t n, size_t k, bool zero) {
     m_numRows = m;
@@ -623,7 +652,7 @@ template<>
 inline DTensor<double> &DTensor<double>::operator*=(double scalar) {
     double alpha = scalar;
     gpuErrChk(
-        cublasDscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
+            cublasDscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
     return *this;
 }
 
@@ -641,7 +670,7 @@ template<>
 inline DTensor<float> &DTensor<float>::operator*=(float scalar) {
     float alpha = scalar;
     gpuErrChk(
-        cublasSscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
+            cublasSscal(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, m_d_data, 1));
     return *this;
 }
 
@@ -649,8 +678,8 @@ template<>
 inline DTensor<double> &DTensor<double>::operator+=(const DTensor<double> &rhs) {
     const double alpha = 1.;
     gpuErrChk(
-        cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
-                    1, m_d_data, 1));
+            cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
+                        1, m_d_data, 1));
     return *this;
 }
 
@@ -658,8 +687,8 @@ template<>
 inline DTensor<float> &DTensor<float>::operator+=(const DTensor<float> &rhs) {
     const float alpha = 1.;
     gpuErrChk(
-        cublasSaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
-                    1, m_d_data, 1));
+            cublasSaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
+                        1, m_d_data, 1));
     return *this;
 }
 
@@ -675,8 +704,8 @@ template<>
 inline DTensor<double> &DTensor<double>::operator-=(const DTensor<double> &rhs) {
     const double alpha = -1.;
     gpuErrChk(
-        cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
-                    1, m_d_data, 1));
+            cublasDaxpy(Session::getInstance().cuBlasHandle(), m_numRows * m_numCols * m_numMats, &alpha, rhs.m_d_data,
+                        1, m_d_data, 1));
     return *this;
 }
 
@@ -987,7 +1016,7 @@ public:
         for (size_t i = 0; i < m_rank->numMats(); i++) {
            DTensor<T> Si(*m_S, 2, i, i);
            DTensor<unsigned int> rankI(*m_rank, 2, i, i);
-           k_countNonzeroSingularValues<<<DIM2BLOCKS(numElS), THREADS_PER_BLOCK>>>(Si.raw(), numElS,
+           k_countNonzeroSingularValues<<<numBlocks(numElS), THREADS_PER_BLOCK>>>(Si.raw(), numElS,
                                                                                    rankI.raw(), epsilon);
         }
         return *m_rank;
@@ -1021,17 +1050,17 @@ inline bool Svd<double>::factorise() {
         if (m_computeU) Ui = std::make_unique<DTensor<double>>(*m_U, 2, i, i);
         gpuErrChk(
-            cusolverDnDgesvd(Session::getInstance().cuSolverHandle(),
-                             (m_computeU) ? 'A' : 'N', 'A',
-                             m, n,
-                             Ai.raw(), m,
-                             Si.raw(),
-                             (m_computeU) ? Ui->raw() : nullptr, m,
-                             Vtri.raw(), n,
-                             m_workspace->raw(),
-                             m_lwork,
-                             nullptr, // rwork (used only if SVD fails)
-                             m_info->raw()));
+                cusolverDnDgesvd(Session::getInstance().cuSolverHandle(),
+                                 (m_computeU) ? 'A' : 'N', 'A',
+                                 m, n,
+                                 Ai.raw(), m,
+                                 Si.raw(),
+                                 (m_computeU) ? Ui->raw() : nullptr, m,
+                                 Vtri.raw(), n,
+                                 m_workspace->raw(),
+                                 m_lwork,
+                                 nullptr, // rwork (used only if SVD fails)
+                                 m_info->raw()));
         info = info && ((*m_info)(0, 0, 0) == 0);
     }
     return info;
 }
@@ -1051,17 +1080,17 @@ inline bool Svd<float>::factorise() {
         if (m_computeU) Ui = std::make_unique<DTensor<float>>(*m_U, 2, i, i);
         gpuErrChk(
-            cusolverDnSgesvd(Session::getInstance().cuSolverHandle(),
-                             (m_computeU) ? 'A' : 'N', 'A',
-                             m, n,
-                             Ai.raw(), m,
-                             Si.raw(),
-                             (m_computeU) ? Ui->raw() : nullptr, m,
-                             Vtri.raw(), n,
-                             m_workspace->raw(),
-                             m_lwork,
-                             nullptr, // rwork (used only if SVD fails)
-                             m_info->raw()));
+                cusolverDnSgesvd(Session::getInstance().cuSolverHandle(),
+                                 (m_computeU) ? 'A' : 'N', 'A',
+                                 m, n,
+                                 Ai.raw(), m,
+                                 Si.raw(),
+                                 (m_computeU) ? Ui->raw() : nullptr, m,
+                                 Vtri.raw(), n,
+                                 m_workspace->raw(),
+                                 m_lwork,
+                                 nullptr, // rwork (used only if SVD fails)
+                                 m_info->raw()));
         info = info && ((*m_info)(0, 0, 0) == 0);
     }
     return info;
 }
@@ -1304,7 +1333,7 @@ public:
      * @param A either matrices to be factorised or lower-triangular Cholesky decomposition matrices
      * @param factorised whether A is the original matrices or the factorised ones (default=original matrices)
      */
-    CholeskyBatchFactoriser(DTensor<T> &A, bool factorised=false) : m_factorisationDone(factorised) {
+    CholeskyBatchFactoriser(DTensor<T> &A, bool factorised = false) : m_factorisationDone(factorised) {
        if (A.numRows() != A.numCols()) throw std::invalid_argument("[CholeskyBatch] A must be square");
        m_matrix = &A;
        m_numRows = A.numRows();
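A note on the new helper: numBlocks(n) is a ceiling division, numBlocks(n) = n / threads_per_block + (n % threads_per_block != 0), e.g. numBlocks(1000) = 1000 / 512 + (1000 % 512 != 0) = 1 + 1 = 2. Being a constexpr function rather than a macro, it is type-checked and still usable in constant expressions. The sketch below shows how a launch site migrates from the removed DIM2BLOCKS macro; the kernel k_scale is hypothetical and merely stands in for the library's real kernels (e.g. k_countNonzeroSingularValues above), which follow the same pattern.

// Hypothetical kernel, for illustration only; not part of tensor.cuh.
template<typename T>
__global__ void k_scale(T *data, T alpha, size_t n) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;  // global thread index
    if (i < n) data[i] *= alpha;                       // guard the padded tail of the last block
}

template<typename T>
void scaleOnDevice(T *d_data, T alpha, size_t n) {
    // before: k_scale<<<DIM2BLOCKS(n), THREADS_PER_BLOCK>>>(d_data, alpha, n);
    k_scale<<<numBlocks(n), THREADS_PER_BLOCK>>>(d_data, alpha, n);
}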
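The new reshape is metadata-only: as the definition above shows, it verifies that the total element count is unchanged and then merely relabels m_numRows, m_numCols and m_numMats, so no device memory is allocated, copied or freed. A minimal usage sketch, using only the DTensor constructor visible in this diff:

DTensor<double> a(2, 3, 4, true);  // 2x3x4 tensor, 24 zero-initialised elements on the device
a.reshape(8, 3);                   // OK: 8 * 3 * 1 == 24 (newNumMats defaults to 1); data untouched
a.reshape(5, 5);                   // throws std::invalid_argument, since 25 != 24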