Commit 7b76933

Authored by felixpetschko, pre-commit-ci[bot], and grst
Improved result matrix stacking for Hamming GPU implementation (#617)
* additional numba implementation for gpu result matrix stacking
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* Changelog updated for GPU hamming result matrix block stacking with Numba
* Lower-bound numba

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Gregor Sturm <mail@gregor-sturm.de>
1 parent fbe2310 commit 7b76933

3 files changed: +41 -5 lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][].
 
 ## [Unreleased]
 
+### Performance improvements
+
+- The stacking of the result matrix blocks of the GPU implementation of the Hamming distance metric has been reimplemented with Numba ([#617](https://github.com/scverse/scirpy/pull/617)).
+
 ### Fixes
 
 - Ensure that clonotype network plots don't have any axis ticks ([#607](https://github.com/scverse/scirpy/pull/607)).

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ dependencies = [
     "logomaker!=0.8.5",
     "mudata>=0.2.3",
     "networkx>=2.5",
-    "numba>=0.41",
+    "numba>=0.57",
     "numpy>=1.17",
     "pandas>=1.5,!=2.1.2",
     "pooch>=1.7",

src/scirpy/ir_dist/metrics.py

Lines changed: 36 additions & 4 deletions
@@ -1107,12 +1107,44 @@ def calc_block_gpu(
         Current number: {num_elements}, Maximum number: {np.iinfo(np.int32).max}.
         Consider choosing a smaller cutoff to resolve this issue."""
 
-        result_sparse = result_blocks[0]
-        for i in range(1, len(result_blocks)):
-            result_sparse += result_blocks[i]
+        @nb.njit
+        def csr_union_numba(block_data, block_indices, block_indptrs, num_rows, num_elements):
+            data = np.empty(num_elements, dtype=block_data[0].dtype)
+            indices = np.empty(num_elements, dtype=block_indices[0].dtype)
+            indptr = np.zeros(num_rows + 1, dtype=np.int32)
 
-        row_element_counts_gpu = np.diff(result_sparse.indptr)
+            ptr = 0
+            for row in range(num_rows):
+                for b in range(len(block_indptrs)):
+                    start = block_indptrs[b][row]
+                    end = block_indptrs[b][row + 1]
+                    count = end - start
+
+                    for j in range(count):
+                        data[ptr + j] = block_data[b][start + j]
+                        indices[ptr + j] = block_indices[b][start + j]
+
+                    ptr += count
+                indptr[row + 1] = ptr
+
+            return data, indices, indptr
+
+        def csr_union(blocks):
+            num_rows = blocks[0].shape[0]
+            num_elements = sum(b.nnz for b in blocks)
 
+            block_data = [b.data for b in blocks]
+            block_indices = [b.indices for b in blocks]
+            block_indptrs = [b.indptr for b in blocks]
+
+            data, indices, indptr = csr_union_numba(block_data, block_indices, block_indptrs, num_rows, num_elements)
+
+            shape = blocks[0].shape
+            return csr_matrix((data, indices, indptr), shape=shape)
+
+        result_sparse = csr_union(result_blocks)
+
+        row_element_counts_gpu = np.diff(result_sparse.indptr)
         result_sparse.sort_indices()
 
         # Returns the results in a way that fits the current interface, could be improved later
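The row-wise stacking introduced in this hunk can be illustrated without Numba or SciPy. The following is a minimal pure-Python sketch, not the code from the commit: blocks are represented as hypothetical `(data, indices, indptr)` triples of plain lists, and the helper `to_dense` exists only for checking the result. It assumes that all blocks share the same shape and that no two blocks store the same (row, column) position, which is the condition under which concatenating rows gives the same matrix as summing the blocks.

```python
def csr_union(blocks, num_rows):
    """Concatenate same-shaped CSR blocks row by row.

    Each block is a (data, indices, indptr) triple of plain Python lists.
    Assumption: no (row, column) position is stored in more than one block.
    """
    data, indices, indptr = [], [], [0]
    for row in range(num_rows):
        for block_data, block_indices, block_indptr in blocks:
            start, end = block_indptr[row], block_indptr[row + 1]
            # Copy this block's entries for the current row.
            data.extend(block_data[start:end])
            indices.extend(block_indices[start:end])
        # After visiting every block, the row ends at the current length.
        indptr.append(len(data))
    return data, indices, indptr


def to_dense(data, indices, indptr, num_cols):
    """Expand a CSR triple into a list of dense rows (for checking only)."""
    dense = []
    for row in range(len(indptr) - 1):
        dense_row = [0] * num_cols
        for k in range(indptr[row], indptr[row + 1]):
            dense_row[indices[k]] += data[k]
        dense.append(dense_row)
    return dense


# Two hypothetical 3x4 blocks with non-overlapping entries.
block_a = ([1, 2, 3], [0, 1, 0], [0, 2, 2, 3])  # (0,0)=1, (0,1)=2, (2,0)=3
block_b = ([4, 5], [3, 2], [0, 1, 2, 2])        # (0,3)=4, (1,2)=5

data, indices, indptr = csr_union([block_a, block_b], num_rows=3)

# Per-row element counts are the differences of consecutive indptr entries,
# mirroring np.diff(result_sparse.indptr) in the diff above.
row_counts = [indptr[i + 1] - indptr[i] for i in range(3)]

# The union equals the element-wise sum of the dense blocks.
dense_a = to_dense(*block_a, num_cols=4)
dense_b = to_dense(*block_b, num_cols=4)
dense_sum = [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(dense_a, dense_b)]
assert to_dense(data, indices, indptr, num_cols=4) == dense_sum
```

Within a row, entries appear in block order, so column indices are not guaranteed to be sorted across blocks; in the diff, `result_sparse.sort_indices()` is called on the stacked matrix.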
