6
6
from collections import defaultdict
7
7
from collections .abc import Sequence
8
8
from itertools import groupby
9
- from typing import cast
10
9
11
10
import numpy as np
12
11
from langdetect import LangDetectException , detect
@@ -66,23 +65,32 @@ def vector_search(
66
65
.order_by (distance )
67
66
.limit (oversample * num_results )
68
67
)
69
- chunk_ids_ , distance = zip (* results , strict = True )
70
- chunk_ids , similarity = np .asarray (chunk_ids_ ), 1.0 - np .asarray (distance )
68
+ results = list (results ) # type: ignore[assignment]
69
+ chunk_ids = np .asarray ([result [0 ] for result in results ])
70
+ similarity = 1.0 - np .asarray ([result [1 ] for result in results ])
71
71
elif db_backend == "sqlite" :
72
72
# Load the NNDescent index.
73
73
index = index_metadata .get ("index" )
74
- ids = np .asarray (index_metadata .get ("chunk_ids" ))
75
- cumsum = np .cumsum (np .asarray (index_metadata .get ("chunk_sizes" )))
74
+ ids = np .asarray (index_metadata .get ("chunk_ids" , [] ))
75
+ cumsum = np .cumsum (np .asarray (index_metadata .get ("chunk_sizes" , [] )))
76
76
# Find the neighbouring multi-vector indices.
77
77
from pynndescent import NNDescent
78
78
79
- multi_vector_indices , distance = cast (NNDescent , index ).query (
80
- query_embedding [np .newaxis , :], k = oversample * num_results
81
- )
82
- similarity = 1 - distance [0 , :]
83
- # Transform the multi-vector indices into chunk indices, and then to chunk ids.
84
- chunk_indices = np .searchsorted (cumsum , multi_vector_indices [0 , :], side = "right" ) + 1
85
- chunk_ids = np .asarray ([ids [chunk_index - 1 ] for chunk_index in chunk_indices ])
79
+ if isinstance (index , NNDescent ) and len (ids ) and len (cumsum ):
80
+ # Query the index.
81
+ multi_vector_indices , distance = index .query (
82
+ query_embedding [np .newaxis , :], k = oversample * num_results
83
+ )
84
+ similarity = 1 - distance [0 , :]
85
+ # Transform the multi-vector indices into chunk indices, and then to chunk ids.
86
+ chunk_indices = np .searchsorted (cumsum , multi_vector_indices [0 , :], side = "right" ) + 1
87
+ chunk_ids = np .asarray ([ids [chunk_index - 1 ] for chunk_index in chunk_indices ])
88
+ else :
89
+ # Empty result set if there is no index or if no chunks are indexed.
90
+ chunk_ids , similarity = np .array ([], dtype = np .intp ), np .array ([])
91
+ # Exit early if there are no search results.
92
+ if not len (chunk_ids ):
93
+ return [], []
86
94
# Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with
87
95
# fewer hits are padded with the minimum similarity of the result set.
88
96
unique_chunk_ids , counts = np .unique (chunk_ids , return_counts = True )
@@ -157,6 +165,9 @@ def reciprocal_rank_fusion(
157
165
chunk_id_index = {chunk_id : i for i , chunk_id in enumerate (ranking )}
158
166
for chunk_id in chunk_ids :
159
167
chunk_id_score [chunk_id ] += 1 / (k + chunk_id_index .get (chunk_id , len (chunk_id_index )))
168
+ # Exit early if there are no results to fuse.
169
+ if not chunk_id_score :
170
+ return [], []
160
171
# Rank RRF results according to descending RRF score.
161
172
rrf_chunk_ids , rrf_score = zip (
162
173
* sorted (chunk_id_score .items (), key = lambda x : x [1 ], reverse = True ), strict = True
@@ -181,6 +192,8 @@ def retrieve_chunks(
181
192
chunk_ids : list [ChunkId ], * , config : RAGLiteConfig | None = None
182
193
) -> list [Chunk ]:
183
194
"""Retrieve chunks by their ids."""
195
+ if not chunk_ids :
196
+ return []
184
197
config = config or RAGLiteConfig ()
185
198
engine = create_database_engine (config )
186
199
with Session (engine ) as session :
@@ -207,8 +220,8 @@ def rerank_chunks(
207
220
if all (isinstance (chunk_id , ChunkId ) for chunk_id in chunk_ids )
208
221
else chunk_ids
209
222
)
210
- # Early exit if no reranker is configured.
211
- if not config .reranker :
223
+ # Exit early if no reranker is configured or if the input is empty .
224
+ if not config .reranker or not chunks :
212
225
return chunks
213
226
# Select the reranker.
214
227
if isinstance (config .reranker , Sequence ):
@@ -243,6 +256,9 @@ def retrieve_chunk_spans(
243
256
Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as
244
257
determined by the order in which they are provided to this function.
245
258
"""
259
+ # Exit early if the input is empty.
260
+ if not chunk_ids :
261
+ return []
246
262
# Retrieve the chunks.
247
263
config = config or RAGLiteConfig ()
248
264
chunks : list [Chunk ] = (
0 commit comments