Skip to content

Commit cd967d1

Browse files
authored
fix: improve unpacking of keyword search results (#46)
1 parent 776e0c5 commit cd967d1

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/raglite/_search.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,10 @@ def keyword_search(
129129
""")
130130
results = session.execute(statement, params={"match": fts5_query, "limit": num_results})
131131
# Unpack the results.
132-
chunk_ids, keyword_score = zip(*results, strict=True)
133-
chunk_ids, keyword_score = list(chunk_ids), list(keyword_score) # type: ignore[assignment]
134-
return chunk_ids, keyword_score # type: ignore[return-value]
132+
results = list(results) # type: ignore[assignment]
133+
chunk_ids = [result.chunk_id for result in results]
134+
keyword_score = [result.score for result in results]
135+
return chunk_ids, keyword_score
135136

136137

137138
def reciprocal_rank_fusion(

tests/test_search.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,14 @@ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod)
4646
# Extend the chunks with their neighbours and group them into contiguous segments.
4747
segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
4848
assert all(isinstance(segment, str) for segment in segments)
49+
50+
51+
def test_search_no_results(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:
52+
"""Test searching for a query with no keyword search results."""
53+
query = "supercalifragilisticexpialidocious"
54+
num_results = 5
55+
chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
56+
num_results_expected = 0 if search_method == keyword_search else num_results
57+
assert len(chunk_ids) == len(scores) == num_results_expected
58+
assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
59+
assert all(isinstance(score, float) for score in scores)

0 commit comments

Comments
 (0)