Skip to content

Commit f698844

Browse files
committed
v0.9.9 fix incorrect vector/matrix empty size
1 parent 39e271b commit f698844

File tree

6 files changed

+69
-6
lines changed

6 files changed

+69
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
### Version 0.9.9
2+
3+
* Fix #34: Empty vector/matrix multiplication returns correct shape for result
4+
15
### Version 0.9.8
26

37
* Fix #31: Remove ndim check from sparse vector ctypes wrapper to accommodate np.matrix arrays that can't be flattened

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from setuptools import setup, find_packages
33

44
DISTNAME = 'sparse_dot_mkl'
5-
VERSION = '0.9.8'
5+
VERSION = '0.9.9'
66
DESCRIPTION = "Intel MKL wrapper for sparse matrix multiplication"
77
MAINTAINER = 'Chris Jackson'
88
MAINTAINER_EMAIL = 'cj59@nyu.edu'

sparse_dot_mkl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '0.9.8'
1+
__version__ = '0.9.9'
22

33

44
from sparse_dot_mkl.sparse_dot import (

sparse_dot_mkl/_sparse_vector.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,21 @@ def _sparse_dot_vector(
137137
_sanity_check(mv_a, mv_b, allow_vector=True)
138138

139139
if _empty_output_check(mv_a, mv_b):
140-
output_arr = _out_matrix(
141-
(mv_a.shape[0],) if mv_b.ndim == 1 else (mv_a.shape[0], 1),
142-
_type_check(mv_a, mv_b, cast=cast, convert=False), out_arr=out
143-
)
140+
141+
if _is_dense_vector(mv_b):
142+
output_arr = _out_matrix(
143+
(mv_a.shape[0],) if mv_b.ndim == 1 else (mv_a.shape[0], 1),
144+
_type_check(mv_a, mv_b, cast=cast, convert=False),
145+
out_arr=out
146+
)
147+
148+
elif _is_dense_vector(mv_a):
149+
output_arr = _out_matrix(
150+
(mv_b.shape[1],) if mv_a.ndim == 1 else (1, mv_b.shape[1]),
151+
_type_check(mv_a, mv_b, cast=cast, convert=False),
152+
out_arr=out
153+
)
154+
144155
if out is None or (out_scalar is not None and not out_scalar):
145156
output_arr.fill(0)
146157
elif out_scalar is not None:

sparse_dot_mkl/tests/test_mkl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def make_vector(n, complex=False):
4747

4848
MATRIX_1, MATRIX_2 = make_matrixes(200, 100, 300, 0.05)
4949
MATRIX_1_EMPTY = _spsparse.csr_matrix((200, 300), dtype=np.float64)
50+
MATRIX_2_EMPTY = _spsparse.csr_matrix((300, 100), dtype=np.float64)
51+
5052
VECTOR = make_vector(300)
5153

5254

sparse_dot_mkl/tests/test_sparse_vector.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from sparse_dot_mkl import dot_product_mkl
66
from sparse_dot_mkl.tests.test_mkl import (
77
MATRIX_1,
8+
MATRIX_1_EMPTY,
9+
MATRIX_2_EMPTY,
810
MATRIX_2,
911
VECTOR,
1012
make_matrixes,
@@ -67,6 +69,7 @@ def test_mult_1d_float32(self):
6769
mat3_np = np.dot(self.mat1_d, self.mat2_d)
6870

6971
np_almost_equal(mat3_np, mat3)
72+
7073

7174
def test_mult_1d_float32_out(self):
7275
mat3_np = np.dot(self.mat1_d, self.mat2_d)
@@ -193,6 +196,17 @@ class TestSparseVectorMultiplicationBSR(TestSparseVectorMultiplication):
193196
sparse_args = {"blocksize": (10, 10)}
194197

195198

199+
class TestSparseVectorMultiplicationCSREmpty(TestSparseVectorMultiplication):
200+
201+
@classmethod
202+
def setUpClass(cls):
203+
cls.MATRIX_1, cls.MATRIX_2, cls.VECTOR = (
204+
MATRIX_1_EMPTY,
205+
MATRIX_2_EMPTY,
206+
np.zeros_like(VECTOR),
207+
)
208+
209+
196210
class TestSparseVectorMultiplicationCOO(unittest.TestCase):
197211
def setUp(self):
198212
self.mat1 = _spsparse.coo_matrix(MATRIX_1).copy()
@@ -270,6 +284,25 @@ class TestVectorSparseMultiplicationBSR(TestVectorSparseMultiplication):
270284
sparse_func = _spsparse.bsr_matrix
271285

272286

287+
class TestVectorSparseMultiplicationCSREmpty(TestVectorSparseMultiplication):
288+
289+
@classmethod
290+
def setUpClass(cls):
291+
cls.MATRIX_1, cls.MATRIX_2, cls.VECTOR = (
292+
MATRIX_1_EMPTY,
293+
MATRIX_2_EMPTY,
294+
np.zeros_like(VECTOR),
295+
)
296+
297+
298+
class TestVectorSparseMultiplicationCSCEmpty(TestVectorSparseMultiplicationCSREmpty):
299+
sparse_func = _spsparse.csc_matrix
300+
301+
302+
class TestVectorSparseMultiplicationBSREmpty(TestVectorSparseMultiplicationCSREmpty):
303+
sparse_func = _spsparse.bsr_matrix
304+
305+
273306
class TestVectorVectorMultplication(unittest.TestCase):
274307
@classmethod
275308
def setUpClass(cls):
@@ -379,6 +412,19 @@ class TestVectorSparseMultiplicationArrayCSC(TestVectorSparseMultiplication):
379412
class TestVectorSparseMultiplicationArrayBSR(TestVectorSparseMultiplication):
380413
sparse_func = bsr_array
381414

415+
class TestVectorSparseMultiplicationCSREmptyArray(TestVectorSparseMultiplicationCSREmpty):
416+
sparse_func = csr_array
417+
sparse_args = {}
418+
419+
class TestVectorSparseMultiplicationCSCEmptyArray(TestVectorSparseMultiplicationCSREmpty):
420+
sparse_func = csc_array
421+
sparse_args = {}
422+
423+
class TestVectorSparseMultiplicationBSREmptyArray(TestVectorSparseMultiplicationCSREmpty):
424+
sparse_func = bsr_array
425+
sparse_args = {"blocksize": (10, 10)}
426+
427+
382428
class TestSparseVectorMultiplicationArrayComplex(
383429
_ComplexMixin,
384430
TestSparseVectorMultiplicationArray

0 commit comments

Comments
 (0)