Skip to content

Commit 423abdf

Browse files
committed
[Code] Cuda csr matrix: accum logic fixes
1 parent 0557755 commit 423abdf

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

cubool/include/cubool/cubool.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ typedef enum cuBool_Hint {
9191
/** Logging hint: log includes all types of messages */
9292
CUBOOL_HINT_LOG_ALL = 0x128,
9393
/** No duplicates in the build data */
94-
CUBOOL_HINT_NO_DUPLICATES = 0x256
94+
CUBOOL_HINT_NO_DUPLICATES = 0x256,
95+
/** Performs time measurement and logs elapsed operation time */
96+
CUBOOL_HINT_TIME_CHECK = 0x512
9597
} cuBool_Hint;
9698

9799
/** Hit mask */

cubool/sources/cuda/matrix_csr_multiply.cu

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,27 @@ namespace cubool {
4040
assert(this->getNrows() == M);
4141
assert(this->getNcols() == N);
4242

43+
if (!accumulate) {
44+
// Clear all values
45+
this->mMatrixImpl.zero_dim();
46+
}
47+
4348
if (a->isMatrixEmpty() || b->isMatrixEmpty()) {
44-
// A or B has no values
49+
// Return empty matrix
4550
return;
4651
}
4752

48-
CHECK_RAISE_ERROR(accumulate, NotImplemented, "Supported only accumulated multiplication");
53+
// Ensure csr proper csr format even if empty
54+
a->resizeStorageToDim();
55+
b->resizeStorageToDim();
56+
this->resizeStorageToDim();
4957

50-
if (accumulate) {
51-
// Ensure csr proper csr format even if empty
52-
a->resizeStorageToDim();
53-
b->resizeStorageToDim();
54-
this->resizeStorageToDim();
58+
// Call backend r = c + a * b implementation, as C this is passed
59+
nsparse::spgemm_functor_t<bool, index, DeviceAlloc<index>> spgemmFunctor;
60+
auto result = spgemmFunctor(mMatrixImpl, a->mMatrixImpl, b->mMatrixImpl);
5561

56-
// Call backend r = c + a * b implementation, as C this is passed
57-
nsparse::spgemm_functor_t<bool, index, DeviceAlloc<index>> spgemmFunctor;
58-
auto result = spgemmFunctor(mMatrixImpl, a->mMatrixImpl, b->mMatrixImpl);
59-
60-
// Assign result to this
61-
this->mMatrixImpl = std::move(result);
62-
}
62+
// Assign result to this
63+
this->mMatrixImpl = std::move(result);
6364
}
6465

6566
}

0 commit comments

Comments
 (0)