@@ -40,26 +40,27 @@ namespace cubool {
40
40
assert (this ->getNrows () == M);
41
41
assert (this ->getNcols () == N);
42
42
43
+ if (!accumulate) {
44
+ // Clear all values
45
+ this ->mMatrixImpl .zero_dim ();
46
+ }
47
+
43
48
if (a->isMatrixEmpty () || b->isMatrixEmpty ()) {
44
- // A or B has no values
49
+ // Return empty matrix
45
50
return ;
46
51
}
47
52
48
- CHECK_RAISE_ERROR (accumulate, NotImplemented, " Supported only accumulated multiplication" );
53
+ // Ensure csr proper csr format even if empty
54
+ a->resizeStorageToDim ();
55
+ b->resizeStorageToDim ();
56
+ this ->resizeStorageToDim ();
49
57
50
- if (accumulate) {
51
- // Ensure csr proper csr format even if empty
52
- a->resizeStorageToDim ();
53
- b->resizeStorageToDim ();
54
- this ->resizeStorageToDim ();
58
+ // Call backend r = c + a * b implementation, as C this is passed
59
+ nsparse::spgemm_functor_t <bool , index, DeviceAlloc<index>> spgemmFunctor;
60
+ auto result = spgemmFunctor (mMatrixImpl , a->mMatrixImpl , b->mMatrixImpl );
55
61
56
- // Call backend r = c + a * b implementation, as C this is passed
57
- nsparse::spgemm_functor_t <bool , index, DeviceAlloc<index>> spgemmFunctor;
58
- auto result = spgemmFunctor (mMatrixImpl , a->mMatrixImpl , b->mMatrixImpl );
59
-
60
- // Assign result to this
61
- this ->mMatrixImpl = std::move (result);
62
- }
62
+ // Assign result to this
63
+ this ->mMatrixImpl = std::move (result);
63
64
}
64
65
65
66
}
0 commit comments