Skip to content

Commit ac991af

Browse files
authored
Add a member function to BsrMatrix to convert to Crs (#2809)
* Add a member function to BsrMatrix to convert to Crs Signed-off-by: James Foucar <jgfouca@sandia.gov> * Update docs Signed-off-by: James Foucar <jgfouca@sandia.gov> * Fix doc Signed-off-by: James Foucar <jgfouca@sandia.gov> * Fix mistakes in bsr test Signed-off-by: James Foucar <jgfouca@sandia.gov> * Fix veclen max Signed-off-by: James Foucar <jgfouca@sandia.gov> --------- Signed-off-by: James Foucar <jgfouca@sandia.gov>
1 parent d818ac6 commit ac991af

File tree

3 files changed

+202
-28
lines changed

3 files changed

+202
-28
lines changed

docs/source/API/sparse/bsr_matrix.rst

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,25 @@ unmanaged_block_const
321321
Return a const view of the i-th block in the matrix.
322322

323323

324+
convertToCrs
325+
^^^^^^^^^^^^
326+
327+
.. code:: cppkokkos
328+
329+
template <typename CrsMatrixType = KokkosSparse::CrsMatrix<ScalarType, OrdinalType, Device, MemoryTraits, SizeType>>
330+
CrsMatrixType convertToCrs() const;
331+
332+
Convert the Bsr into a CrsMatrix
333+
334+
The default return type will be a CrsMatrix with all the same template arguments
335+
as this Bsr, but you can provide your own type if needed. The only requirement
336+
is that the execution spaces match.
337+
338+
This is a host function.
339+
324340
Example
325341
=======
326342

327343
.. literalinclude:: ../../../../example/wiki/sparse/KokkosSparse_wiki_bsrmatrix_2.cpp
328344
:language: c++
329345
:lines: 16-
330-

sparse/src/KokkosSparse_BsrMatrix.hpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,96 @@ class BsrMatrix {
912912
return const_block_type(&values(i * blockDim_ * blockDim_), blockDim_, blockDim_);
913913
}
914914

915+
/// \brief Convert the Bsr into a CrsMatrix
916+
///
917+
/// The default return type will be a CrsMatrix with all the same template arguments
918+
/// as this Bsr, but you can provide your own type if needed. The only requirement
919+
/// is that the execution spaces match.
920+
///
921+
/// This is a host function.
922+
///
923+
template <typename CrsMatrixType = KokkosSparse::CrsMatrix<ScalarType, OrdinalType, Device, MemoryTraits, SizeType>>
924+
CrsMatrixType convertToCrs() const {
925+
using crs_size_t = typename CrsMatrixType::size_type;
926+
using crs_rowmap_t = typename CrsMatrixType::row_map_type::non_const_type;
927+
using crs_entries_t = typename CrsMatrixType::index_type::non_const_type;
928+
using crs_values_t = typename CrsMatrixType::values_type::non_const_type;
929+
using crs_exe_space_t = typename CrsMatrixType::execution_space;
930+
using policy_t = typename Kokkos::TeamPolicy<execution_space>;
931+
using member_t = typename policy_t::member_type;
932+
933+
static_assert(std::is_same_v<crs_exe_space_t, execution_space>,
934+
"We do not currently support converting a BsrMatrix to a CsrMatrix with a different execution space");
935+
936+
// Get size/dimension info from this Bsr. We will use crs_size_t for all int types
937+
const crs_size_t blockDim = this->blockDim();
938+
const crs_size_t blockSize = blockDim * blockDim;
939+
const crs_size_t numBlockRows = numRows();
940+
const crs_size_t numBlockCols = numCols();
941+
const crs_size_t numBlockEntries = nnz();
942+
943+
// Get graph and values from this Bsr
944+
const auto blockRowMap = graph.row_map;
945+
const auto blockEntries = graph.entries;
946+
const auto blockValues = values;
947+
948+
// Compute Csr row/col/entry sizes by multiplying Bsr sizes by block dimension
949+
const crs_size_t numCrsRows = numBlockRows * blockDim;
950+
const crs_size_t numCrsCols = numBlockCols * blockDim;
951+
const crs_size_t numCrsEntries = numBlockEntries * blockSize;
952+
953+
// Allocate CrsMatrix views
954+
crs_rowmap_t crsRowMap("crsRowMap", numCrsRows + 1);
955+
crs_entries_t crsEntries("crsEntries", numCrsEntries);
956+
crs_values_t crsValues("crsValues", numCrsEntries);
957+
958+
// Create the policy, we have 3 levels of parallelism available in the algorithm
959+
const crs_size_t maxvec = policy_t::vector_length_max();
960+
const crs_size_t veclen = (blockDim <= maxvec) ? blockDim : maxvec;
961+
policy_t policy(numBlockRows, Kokkos::AUTO(), veclen);
962+
963+
// Fill CrsMatrix row map, entries, and values
964+
Kokkos::parallel_for(
965+
"ConvertBsrToCrs", policy, KOKKOS_LAMBDA(const member_t& team) {
966+
const crs_size_t blockRow = team.league_rank();
967+
const crs_size_t blockRowStart = blockRowMap(blockRow);
968+
const crs_size_t blockRowEnd = blockRowMap(blockRow + 1);
969+
const crs_size_t blockRowCount = blockRowEnd - blockRowStart;
970+
971+
// Iterate over block entries in this row.
972+
Kokkos::parallel_for(
973+
Kokkos::TeamThreadRange(team, blockRowStart, blockRowEnd), [&](const crs_size_t& blockNnz) {
974+
const crs_size_t blockCol = blockEntries(blockNnz);
975+
const crs_size_t blockNum = blockNnz - blockRowStart;
976+
977+
// Iterate over block dim to get the unblocked rows
978+
Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, blockDim), [&](const crs_size_t& blockRowOffset) {
979+
const crs_size_t crsRow = blockRow * blockDim + blockRowOffset;
980+
// Each unblocked row has blockRowCount * blockDim items
981+
const crs_size_t crsRowStart = blockRowStart * blockSize + blockRowCount * blockDim * blockRowOffset;
982+
crsRowMap(crsRow) = crsRowStart;
983+
984+
// Iterate over block dim to get the unblocked cols
985+
for (crs_size_t blockColOffset = 0; blockColOffset < blockDim; ++blockColOffset) {
986+
const crs_size_t crsCol = blockCol * blockDim + blockColOffset;
987+
const crs_size_t crsNnz = crsRowStart + blockNum * blockDim + blockColOffset;
988+
crsEntries(crsNnz) = crsCol;
989+
crsValues(crsNnz) = blockValues(blockNnz * blockSize + blockRowOffset * blockDim + blockColOffset);
990+
}
991+
});
992+
});
993+
994+
// Finalize CrsMatrix row map
995+
if (blockRow == numBlockRows - 1) {
996+
crsRowMap(numCrsRows) = blockRowMap(numBlockRows) * blockSize;
997+
}
998+
});
999+
1000+
// Construct CrsMatrix
1001+
return CrsMatrixType("convertedFromBsrMatrix", numCrsRows, numCrsCols, crsEntries.extent(0), crsValues, crsRowMap,
1002+
crsEntries);
1003+
}
1004+
9151005
protected:
9161006
enum class valueOperation { ADD, ASSIGN };
9171007

sparse/unit_test/Test_Sparse_BsrMatrix.hpp

Lines changed: 96 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,14 @@ struct TestFunctor {
189189
auto row_ptr = iblockrow.local_row_in_block(blk, lrow);
190190
for (auto lcol = 0; lcol < A.blockDim(); ++lcol) {
191191
auto entry = iblockrow.local_block_value(blk, lrow, lcol);
192-
// std::cout << "check0: " << ( entry == row_ptr[lcol] );
193-
// std::cout << "check1: " << ( entry == view_blk(lrow,lcol) );
194-
check0 = check0 && (entry == row_ptr[lcol]);
195-
check1 = check1 && (entry == view_blk(lrow, lcol));
192+
check0 &= (entry == row_ptr[lcol]);
193+
check1 &= (entry == view_blk(lrow, lcol));
196194
} // end local col in row
197195
} // end local row in blk
198196
} // end blk
199197
}
200-
d_results(0) = check0;
201-
d_results(1) = check1;
198+
d_results(0) &= check0;
199+
d_results(1) &= check1;
202200

203201
// Test BsrRowViewConst
204202
{
@@ -210,16 +208,14 @@ struct TestFunctor {
210208
auto row_ptr = iblockrow.local_row_in_block(blk, lrow);
211209
for (auto lcol = 0; lcol < A.blockDim(); ++lcol) {
212210
auto entry = iblockrow.local_block_value(blk, lrow, lcol);
213-
check2 = check2 && (entry == row_ptr[lcol]);
214-
check3 = check3 && (entry == view_blk(lrow, lcol));
211+
check2 &= (entry == row_ptr[lcol]);
212+
check3 &= (entry == view_blk(lrow, lcol));
215213
} // end local col in row
216214
} // end local row in blk
217215
} // end blk
218216
}
219-
d_results(0) = check0;
220-
d_results(1) = check1;
221-
d_results(2) = check2;
222-
d_results(3) = check3;
217+
d_results(2) &= check2;
218+
d_results(3) &= check3;
223219
} // end for blk rows
224220

225221
// Test sumIntoValues
@@ -243,14 +239,14 @@ struct TestFunctor {
243239
auto row_ptr = iblockrow.local_row_in_block(relBlk, lrow);
244240
for (auto lcol = 0; lcol < A.blockDim(); ++lcol) {
245241
auto entry = iblockrow.local_block_value(relBlk, lrow, lcol);
246-
check0 = check0 && (entry == row_ptr[lcol]);
247-
check1 = check1 && (entry == view_blk(lrow, lcol));
248-
check2 = check2 && (entry == result[lrow * A.blockDim() + lcol]);
242+
check0 &= (entry == row_ptr[lcol]);
243+
check1 &= (entry == view_blk(lrow, lcol));
244+
check2 &= (entry == result[lrow * A.blockDim() + lcol]);
249245
} // end local col in row
250246
} // end local row in blk
251-
d_results(4) = check0;
252-
d_results(5) = check1;
253-
d_results(6) = check2;
247+
d_results(4) &= check0;
248+
d_results(5) &= check1;
249+
d_results(6) &= check2;
254250
}
255251

256252
// Test replaceValues
@@ -273,19 +269,79 @@ struct TestFunctor {
273269
auto row_ptr = iblockrow.local_row_in_block(relBlk, lrow);
274270
for (auto lcol = 0; lcol < A.blockDim(); ++lcol) {
275271
auto entry = iblockrow.local_block_value(relBlk, lrow, lcol);
276-
check0 = check0 && (entry == row_ptr[lcol]);
277-
check1 = check1 && (entry == view_blk(lrow, lcol));
278-
check2 = check2 && (entry == valsreplace[lrow * A.blockDim() + lcol]);
272+
check0 &= (entry == row_ptr[lcol]);
273+
check1 &= (entry == view_blk(lrow, lcol));
274+
check2 &= (entry == valsreplace[lrow * A.blockDim() + lcol]);
279275
} // end local col in row
280276
} // end local row in blk
281-
d_results(7) = check0;
282-
d_results(8) = check1;
283-
d_results(9) = check2;
277+
d_results(7) &= check0;
278+
d_results(8) &= check1;
279+
d_results(9) &= check2;
284280
}
285-
286281
} // end operator()(i)
287282
}; // end TestFunctor
288283

284+
template <class BsrMatrixType, class ResultsType>
285+
struct TestConvFunctor {
286+
typedef typename BsrMatrixType::value_type scalar_t;
287+
typedef typename BsrMatrixType::ordinal_type lno_t;
288+
typedef typename BsrMatrixType::size_type size_type;
289+
290+
// Members
291+
BsrMatrixType A_;
292+
BsrMatrixType A_bsr_conv_;
293+
ResultsType d_results_;
294+
295+
// Constructor
296+
TestConvFunctor(BsrMatrixType &A, ResultsType &d_results) : A_(A), d_results_(d_results) {
297+
auto A_crs = A_.convertToCrs();
298+
A_bsr_conv_ = BsrMatrixType(A_crs, A_.blockDim());
299+
}
300+
301+
KOKKOS_INLINE_FUNCTION
302+
void operator()(const int /*rid*/) const {
303+
// Test 1: Converting to Crs and then back to Bsr should give us an indentical bsr
304+
bool check0 = true;
305+
bool check1 = true;
306+
bool check2 = true;
307+
bool check3 = true;
308+
bool check4 = true;
309+
bool check5 = true;
310+
bool check6 = true;
311+
int result_idx = 0;
312+
313+
check0 = A_bsr_conv_.numRows() == A_.numRows();
314+
check1 = A_bsr_conv_.numCols() == A_.numCols();
315+
check2 = A_bsr_conv_.blockDim() == A_.blockDim();
316+
check3 = A_bsr_conv_.nnz() == A_.nnz();
317+
318+
const auto blockRowMap1 = A_.graph.row_map;
319+
const auto blockEntries1 = A_.graph.entries;
320+
const auto blockValues1 = A_.values;
321+
322+
const auto blockRowMap2 = A_.graph.row_map;
323+
const auto blockEntries2 = A_.graph.entries;
324+
const auto blockValues2 = A_.values;
325+
326+
for (lno_t i = 0; i < A_.numRows() + 1; ++i) {
327+
check4 &= (blockRowMap1(i) == blockRowMap2(i));
328+
}
329+
330+
for (size_type i = 0; i < A_.nnz(); ++i) {
331+
check5 &= (blockEntries1(i) == blockEntries2(i));
332+
check6 &= (blockValues1(i) == blockValues2(i));
333+
}
334+
335+
d_results_(result_idx++) = check0;
336+
d_results_(result_idx++) = check1;
337+
d_results_(result_idx++) = check2;
338+
d_results_(result_idx++) = check3;
339+
d_results_(result_idx++) = check4;
340+
d_results_(result_idx++) = check5;
341+
d_results_(result_idx++) = check6;
342+
} // end operator()(i)
343+
}; // end TestConvFunctor
344+
289345
} // namespace Test_Bsr
290346

291347
// Create a CrsMatrix and BsrMatrix and test member functions.
@@ -299,13 +355,26 @@ void testBsrMatrix() {
299355
crs_matrix_type crsA = makeCrsMatrix_BlockStructure<crs_matrix_type>();
300356
bsr_matrix_type A = makeBsrMatrix<bsr_matrix_type>();
301357

302-
const int num_entries = 10;
358+
static constexpr int num_entries = 10;
303359
typedef Kokkos::View<bool[num_entries], device> result_view_type;
304360
result_view_type d_results("d_results");
361+
Kokkos::deep_copy(d_results, true);
305362
auto h_results = Kokkos::create_mirror_view(d_results);
363+
Test_Bsr::TestFunctor<bsr_matrix_type, result_view_type> functor(A, d_results);
364+
365+
Kokkos::parallel_for("KokkosSparse::Test_Bsr::BsrMatrix", Kokkos::RangePolicy<typename device::execution_space>(0, 1),
366+
functor);
367+
368+
Kokkos::deep_copy(h_results, d_results);
369+
370+
for (decltype(h_results.extent(0)) i = 0; i < h_results.extent(0); ++i) {
371+
EXPECT_EQ(h_results[i], true);
372+
}
306373

374+
Kokkos::deep_copy(d_results, true);
375+
Test_Bsr::TestConvFunctor<bsr_matrix_type, result_view_type> conv_functor(A, d_results);
307376
Kokkos::parallel_for("KokkosSparse::Test_Bsr::BsrMatrix", Kokkos::RangePolicy<typename device::execution_space>(0, 1),
308-
Test_Bsr::TestFunctor<bsr_matrix_type, result_view_type>(A, d_results));
377+
conv_functor);
309378

310379
Kokkos::deep_copy(h_results, d_results);
311380

0 commit comments

Comments
 (0)