From f3cd82ef2a58c9e803ef1085d4666ed0d1633124 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Tue, 18 Mar 2025 17:53:40 +0100 Subject: [PATCH 1/4] pick the changes for merge-path Csr spmv from https://github.com/ginkgo-project/ginkgo/pull/1497 Co-authored-by: Luka Stanisic --- contributors.txt | 1 + omp/matrix/csr_kernels.cpp | 137 +++++++++++++++++++++++++++++++++++-- 2 files changed, 134 insertions(+), 4 deletions(-) diff --git a/contributors.txt b/contributors.txt index fc7a831b468..4a4964123fa 100644 --- a/contributors.txt +++ b/contributors.txt @@ -25,4 +25,5 @@ Nguyen Phuong University of Tennessee, Knoxville Olenik Gregor HPSim Ribizel Tobias Karlsruhe Institute of Technology Riemer Lukas Karlsruhe Institute of Technology +Luka Stanisic Huawei Technologies Duesseldorf GmbH Tsai Yuhsiang National Taiwan University diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 9c626d31004..290593bd619 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -42,12 +42,125 @@ namespace omp { namespace csr { +/** + * Computes the begin offsets into A and B for the specific diagonal + */ +template +inline void merge_path_search( + const int diagonal, ///< [in]The diagonal to search + const IndexType* A, ///< [in]List A + const int A_len, ///< [in]Length of A + const int B_len, ///< [in]Length of B + int& path_coordinate_x, ///< [out] (x) coordinate where diagonal intersects + ///< the merge path + int& path_coordinate_y) ///< [out] (y) coordinate where diagonal intersects + ///< the merge path +{ + auto x_min = std::max(diagonal - B_len, 0); + auto x_max = std::min(diagonal, A_len); + + while (x_min < x_max) { + auto x_pivot = x_min + ((x_max - x_min) / 2); + if (A[x_pivot] <= (diagonal - x_pivot - 1)) { + x_min = x_pivot + 1; // Contract range up A (down B) + } else { + x_max = x_pivot; // Contract range down A (up B) + } + } + + path_coordinate_x = std::min(x_min, A_len); + path_coordinate_y = diagonal - x_min; +} + 
template -void spmv(std::shared_ptr exec, - const matrix::Csr* a, - const matrix::Dense* b, - matrix::Dense* c) +void merge_spmv(std::shared_ptr exec, + const matrix::Csr* a, + const matrix::Dense* b, + matrix::Dense* c) +{ + using arithmetic_type = + highest_precision; + + auto row_ptrs = a->get_const_row_ptrs(); + auto col_idxs = a->get_const_col_idxs(); + + const auto a_vals = + acc::helper::build_const_rrm_accessor(a); + const auto b_vals = + acc::helper::build_const_rrm_accessor(b); + auto c_vals = acc::helper::build_rrm_accessor(c); + + // Merge-SpMV variables + const auto num_rows = a->get_size()[0]; + const auto nnz = a->get_num_stored_elements(); + const size_type num_threads = omp_get_max_threads(); + const IndexType* row_end_offsets = + row_ptrs + 1; // Merge list A: row end offsets + const auto num_merge_items = num_rows + nnz; // Merge path total length + const auto items_per_thread = (num_merge_items + num_threads - 1) / + num_threads; // Merge items per thread + array row_carry_out{exec, num_threads}; + array value_carry_out{exec, num_threads}; + auto row_carry_out_ptr = row_carry_out.get_data(); + auto value_carry_out_ptr = value_carry_out.get_data(); + + for (size_type j = 0; j < c->get_size()[1]; ++j) { +#pragma omp parallel for schedule(static) + for (size_type tid = 0; tid < num_threads; tid++) { + const auto start_diagonal = + std::min(items_per_thread * tid, num_merge_items); + const auto end_diagonal = + std::min(start_diagonal + items_per_thread, num_merge_items); + int thread_coord_x, thread_coord_y, thread_coord_end_x, + thread_coord_end_y; + + merge_path_search(start_diagonal, row_end_offsets, num_rows, nnz, + thread_coord_x, thread_coord_y); + merge_path_search(end_diagonal, row_end_offsets, num_rows, nnz, + thread_coord_end_x, thread_coord_end_y); + + // Consume merge items, whole rows first + for (; thread_coord_x < thread_coord_end_x; thread_coord_x++) { + auto sum = zero(); + for (; thread_coord_y < row_end_offsets[thread_coord_x]; + 
thread_coord_y++) { + arithmetic_type val = a_vals(thread_coord_y); + auto col = col_idxs[thread_coord_y]; + sum += val * b_vals(col, j); + } + c_vals(thread_coord_x, j) = sum; + } + + // Consume partial portion of thread's last row + auto sum = zero(); + for (; thread_coord_y < thread_coord_end_y; thread_coord_y++) { + arithmetic_type val = a_vals(thread_coord_y); + auto col = col_idxs[thread_coord_y]; + sum += val * b_vals(col, j); + } + + // Save carry-outs + row_carry_out_ptr[tid] = thread_coord_end_x; + value_carry_out_ptr[tid] = sum; + } + + // Carry-out fix-up (rows spanning multiple threads) + for (int tid = 0; tid < num_threads - 1; tid++) { + if (row_carry_out_ptr[tid] < num_rows) { + c_vals(row_carry_out_ptr[tid], j) += value_carry_out_ptr[tid]; + } + } + } +} + + +template +void classical_spmv(std::shared_ptr exec, + const matrix::Csr* a, + const matrix::Dense* b, + matrix::Dense* c) { using arithmetic_type = highest_precision; @@ -77,6 +190,22 @@ void spmv(std::shared_ptr exec, } } +template +void spmv(std::shared_ptr exec, + const matrix::Csr* a, + const matrix::Dense* b, + matrix::Dense* c) +{ + if (c->get_size()[0] == 0 || c->get_size()[1] == 0) { + // empty output: nothing to do + } else if (a->get_strategy()->get_name() == "merge_path") { + merge_spmv(exec, a, b, c); + } else { + classical_spmv(exec, a, b, c); + } +} + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( GKO_DECLARE_CSR_SPMV_KERNEL); From 4abb71a28bdff007d692e674fabba8c062582a6a Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. 
Tsai" Date: Tue, 18 Mar 2025 18:06:09 +0100 Subject: [PATCH 2/4] apply the comments on the pr https://github.com/ginkgo-project/ginkgo/pull/1497 --- omp/matrix/csr_kernels.cpp | 46 ++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 290593bd619..25f92316fd6 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -47,17 +47,17 @@ namespace csr { */ template inline void merge_path_search( - const int diagonal, ///< [in]The diagonal to search - const IndexType* A, ///< [in]List A - const int A_len, ///< [in]Length of A - const int B_len, ///< [in]Length of B - int& path_coordinate_x, ///< [out] (x) coordinate where diagonal intersects - ///< the merge path - int& path_coordinate_y) ///< [out] (y) coordinate where diagonal intersects - ///< the merge path + const IndexType diagonal, ///< [in]The diagonal to search + const IndexType* A, ///< [in]List A + const IndexType a_len, ///< [in]Length of A + const IndexType b_len, ///< [in]Length of B + IndexType& path_coordinate_x, ///< [out] (x) coordinate where diagonal + ///< intersects the merge path + IndexType& path_coordinate_y) ///< [out] (y) coordinate where diagonal + ///< intersects the merge path { - auto x_min = std::max(diagonal - B_len, 0); - auto x_max = std::min(diagonal, A_len); + auto x_min = std::max(diagonal - b_len, zero()); + auto x_max = std::min(diagonal, a_len); while (x_min < x_max) { auto x_pivot = x_min + ((x_max - x_min) / 2); @@ -68,10 +68,11 @@ inline void merge_path_search( } } - path_coordinate_x = std::min(x_min, A_len); + path_coordinate_x = std::min(x_min, a_len); path_coordinate_y = diagonal - x_min; } + template void merge_spmv(std::shared_ptr exec, @@ -98,22 +99,27 @@ void merge_spmv(std::shared_ptr exec, const IndexType* row_end_offsets = row_ptrs + 1; // Merge list A: row end offsets const auto num_merge_items = num_rows + nnz; // Merge path total length - 
const auto items_per_thread = (num_merge_items + num_threads - 1) / - num_threads; // Merge items per thread + const auto items_per_thread = + ceildiv(num_merge_items, num_threads); // Merge items per thread array row_carry_out{exec, num_threads}; array value_carry_out{exec, num_threads}; auto row_carry_out_ptr = row_carry_out.get_data(); auto value_carry_out_ptr = value_carry_out.get_data(); for (size_type j = 0; j < c->get_size()[1]; ++j) { + // TODO: It uses static from the observation of the previous + // experiments. Check it with different system and different kinds of + // schedule. #pragma omp parallel for schedule(static) for (size_type tid = 0; tid < num_threads; tid++) { - const auto start_diagonal = - std::min(items_per_thread * tid, num_merge_items); - const auto end_diagonal = - std::min(start_diagonal + items_per_thread, num_merge_items); - int thread_coord_x, thread_coord_y, thread_coord_end_x, - thread_coord_end_y; + const auto start_diagonal = static_cast( + std::min(items_per_thread * tid, num_merge_items)); + const auto end_diagonal = static_cast( + std::min(start_diagonal + items_per_thread, num_merge_items)); + IndexType thread_coord_x; + IndexType thread_coord_y; + IndexType thread_coord_end_x; + IndexType thread_coord_end_y; merge_path_search(start_diagonal, row_end_offsets, num_rows, nnz, thread_coord_x, thread_coord_y); @@ -146,6 +152,8 @@ void merge_spmv(std::shared_ptr exec, } // Carry-out fix-up (rows spanning multiple threads) + // The last thread does not carry out partial result becaust it must + // compute the result till the last row end. for (int tid = 0; tid < num_threads - 1; tid++) { if (row_carry_out_ptr[tid] < num_rows) { c_vals(row_carry_out_ptr[tid], j) += value_carry_out_ptr[tid]; From 257809854ddb53ae7a094264528de68fe41b5973 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. 
Tsai" Date: Wed, 19 Mar 2025 11:52:04 +0100 Subject: [PATCH 3/4] add merge_path to advanced spmv and enable the test for openmp --- omp/matrix/csr_kernels.cpp | 145 ++++++++++++++++++------------------- test/matrix/matrix.cpp | 15 +++- 2 files changed, 84 insertions(+), 76 deletions(-) diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 25f92316fd6..4f78821a939 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -44,41 +44,44 @@ namespace csr { /** * Computes the begin offsets into A and B for the specific diagonal + * + * @param diagonal the diagonal to search + * @param end_row_offsets the ending of row offsets of A + * @param a_len the length of A (the number of rows) + * @param b_len the length of B (the number of stored elements) + * + * @return a pair which contains (x, y) coordinate where diagonal intersects the + * merge path */ template -inline void merge_path_search( - const IndexType diagonal, ///< [in]The diagonal to search - const IndexType* A, ///< [in]List A - const IndexType a_len, ///< [in]Length of A - const IndexType b_len, ///< [in]Length of B - IndexType& path_coordinate_x, ///< [out] (x) coordinate where diagonal - ///< intersects the merge path - IndexType& path_coordinate_y) ///< [out] (y) coordinate where diagonal - ///< intersects the merge path +inline std::pair merge_path_search( + const IndexType diagonal, const IndexType* end_row_offsets, + const IndexType a_len, const IndexType b_len) { auto x_min = std::max(diagonal - b_len, zero()); auto x_max = std::min(diagonal, a_len); while (x_min < x_max) { auto x_pivot = x_min + ((x_max - x_min) / 2); - if (A[x_pivot] <= (diagonal - x_pivot - 1)) { + if (end_row_offsets[x_pivot] <= (diagonal - x_pivot - 1)) { x_min = x_pivot + 1; // Contract range up A (down B) } else { x_max = x_pivot; // Contract range down A (up B) } } - path_coordinate_x = std::min(x_min, a_len); - path_coordinate_y = diagonal - x_min; + return std::make_pair(std::min(x_min, 
a_len), diagonal - x_min); } template + typename OutputValueType, typename IndexType, typename AlphaOp, + typename BetaOp> void merge_spmv(std::shared_ptr exec, const matrix::Csr* a, const matrix::Dense* b, - matrix::Dense* c) + matrix::Dense* c, AlphaOp alpha_op, + BetaOp beta_op) { using arithmetic_type = highest_precision; @@ -93,68 +96,65 @@ void merge_spmv(std::shared_ptr exec, auto c_vals = acc::helper::build_rrm_accessor(c); // Merge-SpMV variables - const auto num_rows = a->get_size()[0]; - const auto nnz = a->get_num_stored_elements(); - const size_type num_threads = omp_get_max_threads(); - const IndexType* row_end_offsets = - row_ptrs + 1; // Merge list A: row end offsets - const auto num_merge_items = num_rows + nnz; // Merge path total length + const auto num_rows = static_cast(a->get_size()[0]); + const auto nnz = static_cast(a->get_num_stored_elements()); + const auto num_threads = static_cast(omp_get_max_threads()); + // Merge list A: row end offsets + const IndexType* row_end_offsets = row_ptrs + 1; + // Merge path total length + const auto num_merge_items = num_rows + nnz; + // Merge items per thread const auto items_per_thread = - ceildiv(num_merge_items, num_threads); // Merge items per thread + static_cast(ceildiv(num_merge_items, num_threads)); array row_carry_out{exec, num_threads}; array value_carry_out{exec, num_threads}; auto row_carry_out_ptr = row_carry_out.get_data(); auto value_carry_out_ptr = value_carry_out.get_data(); + // TODO: parallelize with number of cols, too. for (size_type j = 0; j < c->get_size()[1]; ++j) { // TODO: It uses static from the observation of the previous // experiments. Check it with different system and different kinds of // schedule. 
#pragma omp parallel for schedule(static) - for (size_type tid = 0; tid < num_threads; tid++) { - const auto start_diagonal = static_cast( - std::min(items_per_thread * tid, num_merge_items)); - const auto end_diagonal = static_cast( - std::min(start_diagonal + items_per_thread, num_merge_items)); - IndexType thread_coord_x; - IndexType thread_coord_y; - IndexType thread_coord_end_x; - IndexType thread_coord_end_y; - - merge_path_search(start_diagonal, row_end_offsets, num_rows, nnz, - thread_coord_x, thread_coord_y); - merge_path_search(end_diagonal, row_end_offsets, num_rows, nnz, - thread_coord_end_x, thread_coord_end_y); - + for (IndexType tid = 0; tid < num_threads; tid++) { + const auto start_diagonal = + std::min(items_per_thread * tid, num_merge_items); + const auto end_diagonal = + std::min(start_diagonal + items_per_thread, num_merge_items); + + auto [x, y] = merge_path_search(start_diagonal, row_end_offsets, + num_rows, nnz); + auto [end_x, end_y] = + merge_path_search(end_diagonal, row_end_offsets, num_rows, nnz); // Consume merge items, whole rows first - for (; thread_coord_x < thread_coord_end_x; thread_coord_x++) { + for (; x < end_x; x++) { auto sum = zero(); - for (; thread_coord_y < row_end_offsets[thread_coord_x]; - thread_coord_y++) { - arithmetic_type val = a_vals(thread_coord_y); - auto col = col_idxs[thread_coord_y]; + for (; y < row_end_offsets[x]; y++) { + arithmetic_type val = a_vals(y); + auto col = col_idxs[y]; sum += val * b_vals(col, j); } - c_vals(thread_coord_x, j) = sum; + c_vals(x, j) = alpha_op(sum) + beta_op(c_vals(x, j)); } // Consume partial portion of thread's last row auto sum = zero(); - for (; thread_coord_y < thread_coord_end_y; thread_coord_y++) { - arithmetic_type val = a_vals(thread_coord_y); - auto col = col_idxs[thread_coord_y]; + for (; y < end_y; y++) { + arithmetic_type val = a_vals(y); + auto col = col_idxs[y]; sum += val * b_vals(col, j); } // Save carry-outs - row_carry_out_ptr[tid] = thread_coord_end_x; - 
value_carry_out_ptr[tid] = sum; + row_carry_out_ptr[tid] = end_x; + value_carry_out_ptr[tid] = alpha_op(sum); } // Carry-out fix-up (rows spanning multiple threads) // The last thread does not carry out partial result becaust it must // compute the result till the last row end. - for (int tid = 0; tid < num_threads - 1; tid++) { + for (IndexType tid = 0; tid < num_threads - 1; tid++) { if (row_carry_out_ptr[tid] < num_rows) { c_vals(row_carry_out_ptr[tid], j) += value_carry_out_ptr[tid]; } @@ -164,11 +164,11 @@ void merge_spmv(std::shared_ptr exec, template + typename OutputValueType, typename IndexType, typename Function> void classical_spmv(std::shared_ptr exec, const matrix::Csr* a, const matrix::Dense* b, - matrix::Dense* c) + matrix::Dense* c, Function lambda) { using arithmetic_type = highest_precision; @@ -193,7 +193,7 @@ void classical_spmv(std::shared_ptr exec, sum += val * b_vals(col, j); } - c_vals(row, j) = sum; + c_vals(row, j) = lambda(sum, c_vals(row, j)); } } } @@ -205,12 +205,16 @@ void spmv(std::shared_ptr exec, const matrix::Dense* b, matrix::Dense* c) { + using arithmetic_type = + highest_precision; if (c->get_size()[0] == 0 || c->get_size()[1] == 0) { // empty output: nothing to do } else if (a->get_strategy()->get_name() == "merge_path") { - merge_spmv(exec, a, b, c); + merge_spmv( + exec, a, b, c, [](auto val) { return val; }, + [](auto) { return zero(); }); } else { - classical_spmv(exec, a, b, c); + classical_spmv(exec, a, b, c, [](auto sum, auto) { return sum; }); } } @@ -229,29 +233,22 @@ void advanced_spmv(std::shared_ptr exec, { using arithmetic_type = highest_precision; - - auto row_ptrs = a->get_const_row_ptrs(); - auto col_idxs = a->get_const_col_idxs(); auto valpha = static_cast(alpha->at(0, 0)); auto vbeta = static_cast(beta->at(0, 0)); - - const auto a_vals = - acc::helper::build_const_rrm_accessor(a); - const auto b_vals = - acc::helper::build_const_rrm_accessor(b); - auto c_vals = acc::helper::build_rrm_accessor(c); -#pragma omp 
parallel for - for (size_type row = 0; row < a->get_size()[0]; ++row) { - for (size_type j = 0; j < c->get_size()[1]; ++j) { - auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta; - for (size_type k = row_ptrs[row]; - k < static_cast(row_ptrs[row + 1]); ++k) { - arithmetic_type val = a_vals(k); - auto col = col_idxs[k]; - sum += valpha * val * b_vals(col, j); - } - c_vals(row, j) = sum; - } + if (c->get_size()[0] == 0 || c->get_size()[1] == 0) { + // empty output: nothing to do + } else if (a->get_strategy()->get_name() == "merge_path") { + merge_spmv( + exec, a, b, c, [valpha](auto val) { return valpha * val; }, + [vbeta](auto val) { + return is_zero(vbeta) ? zero(vbeta) : val * vbeta; + }); + } else { + classical_spmv(exec, a, b, c, [valpha, vbeta](auto sum, auto orig_val) { + auto scaled_orig_val = + is_zero(vbeta) ? zero(vbeta) : orig_val * vbeta; + return valpha * sum + scaled_orig_val; + }); } } diff --git a/test/matrix/matrix.cpp b/test/matrix/matrix.cpp index 2c45f841628..f016365bc46 100644 --- a/test/matrix/matrix.cpp +++ b/test/matrix/matrix.cpp @@ -130,7 +130,7 @@ struct CsrWithDefaultStrategy : CsrBase { #if defined(GKO_COMPILING_CUDA) || defined(GKO_COMPILING_HIP) || \ - defined(GKO_COMPILING_DPCPP) + defined(GKO_COMPILING_DPCPP) || defined(GKO_COMPILING_OMP) struct CsrWithClassicalStrategy : CsrBase { @@ -177,6 +177,14 @@ struct CsrWithMergePathStrategy : CsrBase { } }; + +#endif + + +#if defined(GKO_COMPILING_CUDA) || defined(GKO_COMPILING_HIP) || \ + defined(GKO_COMPILING_DPCPP) + + struct CsrWithSparselibStrategy : CsrBase { static std::unique_ptr create( std::shared_ptr exec, gko::dim<2> size) @@ -827,8 +835,11 @@ class Matrix : public CommonTestFixture { using MatrixTypes = ::testing::Types< DenseWithDefaultStride, DenseWithCustomStride, Coo, CsrWithDefaultStrategy, #if defined(GKO_COMPILING_CUDA) || defined(GKO_COMPILING_HIP) || \ - defined(GKO_COMPILING_DPCPP) + defined(GKO_COMPILING_DPCPP) || defined(GKO_COMPILING_OMP) 
CsrWithClassicalStrategy, CsrWithMergePathStrategy, +#endif +#if defined(GKO_COMPILING_CUDA) || defined(GKO_COMPILING_HIP) || \ + defined(GKO_COMPILING_DPCPP) CsrWithSparselibStrategy, CsrWithLoadBalanceStrategy, CsrWithAutomaticalStrategy, #endif From 0e09c20baea8ffedc625e1fc2799a93367fe868d Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Mon, 26 May 2025 12:50:42 +0200 Subject: [PATCH 4/4] improve documentation and naming Co-authored-by: Marcel Koch --- omp/matrix/csr_kernels.cpp | 51 +++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 4f78821a939..3710ed5bd7a 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -43,10 +43,15 @@ namespace csr { /** - * Computes the begin offsets into A and B for the specific diagonal + * Computes the begin offsets into A (the number of rows) and B (the number of + * stored elements), which intersect the diagonal line. The diagonal line + * consists of the points reached after `diagonal` merge steps from the + * starting point. * - * @param diagonal the diagonal to search - * @param end_row_offsets the ending of row offsets of A + * @param diagonal the diagonal line to search + * @param row_end_ptrs the pointer to the row end offsets of A. + * row_end_ptrs[i] gives the ending in the value/column + * array for row_i.
* @param a_len the length of A (the number of rows) * @param b_len the length of B (the number of stored elements) * @@ -55,7 +60,7 @@ namespace csr { */ template inline std::pair merge_path_search( - const IndexType diagonal, const IndexType* end_row_offsets, + const IndexType diagonal, const IndexType* row_end_ptrs, const IndexType a_len, const IndexType b_len) { auto x_min = std::max(diagonal - b_len, zero()); @@ -63,7 +68,7 @@ inline std::pair merge_path_search( while (x_min < x_max) { auto x_pivot = x_min + ((x_max - x_min) / 2); - if (end_row_offsets[x_pivot] <= (diagonal - x_pivot - 1)) { + if (row_end_ptrs[x_pivot] <= (diagonal - x_pivot - 1)) { x_min = x_pivot + 1; // Contract range up A (down B) } else { x_max = x_pivot; // Contract range down A (up B) @@ -99,17 +104,17 @@ void merge_spmv(std::shared_ptr exec, const auto num_rows = static_cast(a->get_size()[0]); const auto nnz = static_cast(a->get_num_stored_elements()); const auto num_threads = static_cast(omp_get_max_threads()); - // Merge list A: row end offsets - const IndexType* row_end_offsets = row_ptrs + 1; + // Merge list A: row end ptr + const IndexType* row_end_ptrs = row_ptrs + 1; // Merge path total length const auto num_merge_items = num_rows + nnz; // Merge items per thread const auto items_per_thread = static_cast(ceildiv(num_merge_items, num_threads)); - array row_carry_out{exec, num_threads}; - array value_carry_out{exec, num_threads}; - auto row_carry_out_ptr = row_carry_out.get_data(); - auto value_carry_out_ptr = value_carry_out.get_data(); + array row_carry_over(exec, num_threads); + array value_carry_over(exec, num_threads); + auto row_carry_over_ptr = row_carry_over.get_data(); + auto value_carry_over_ptr = value_carry_over.get_data(); // TODO: parallelize with number of cols, too. 
for (size_type j = 0; j < c->get_size()[1]; ++j) { @@ -123,14 +128,14 @@ void merge_spmv(std::shared_ptr exec, const auto end_diagonal = std::min(start_diagonal + items_per_thread, num_merge_items); - auto [x, y] = merge_path_search(start_diagonal, row_end_offsets, - num_rows, nnz); + auto [x, y] = + merge_path_search(start_diagonal, row_end_ptrs, num_rows, nnz); auto [end_x, end_y] = - merge_path_search(end_diagonal, row_end_offsets, num_rows, nnz); + merge_path_search(end_diagonal, row_end_ptrs, num_rows, nnz); // Consume merge items, whole rows first for (; x < end_x; x++) { auto sum = zero(); - for (; y < row_end_offsets[x]; y++) { + for (; y < row_end_ptrs[x]; y++) { arithmetic_type val = a_vals(y); auto col = col_idxs[y]; sum += val * b_vals(col, j); @@ -146,17 +151,17 @@ void merge_spmv(std::shared_ptr exec, sum += val * b_vals(col, j); } - // Save carry-outs - row_carry_out_ptr[tid] = end_x; - value_carry_out_ptr[tid] = alpha_op(sum); + // Save carry over + row_carry_over_ptr[tid] = end_x; + value_carry_over_ptr[tid] = alpha_op(sum); } - // Carry-out fix-up (rows spanning multiple threads) - // The last thread does not carry out partial result becaust it must - // compute the result till the last row end. + // Carry over fix-up (rows spanning multiple threads) + // The carry over from thread `tid` to `tid + 1` is added by the thread + // `tid`, thus the last thread has no work. for (IndexType tid = 0; tid < num_threads - 1; tid++) { - if (row_carry_out_ptr[tid] < num_rows) { - c_vals(row_carry_out_ptr[tid], j) += value_carry_out_ptr[tid]; + if (row_carry_over_ptr[tid] < num_rows) { + c_vals(row_carry_over_ptr[tid], j) += value_carry_over_ptr[tid]; } } }