
Commit 99d7f0d

fix: improve output for empty databases (#68)
1 parent 2e9bfaf commit 99d7f0d

7 files changed (+52 -66 lines)


README.md (2 additions, 2 deletions)

````diff
@@ -49,7 +49,7 @@ pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_
 Next, it is optional but recommended to install [an accelerated llama-cpp-python precompiled binary](https://github.com/abetlen/llama-cpp-python?tab=readme-ov-file#supported-backends) with:

 ```sh
-# Configure which llama-cpp-python precompiled binary to install (⚠️ only v0.3.2 is supported right now):
+# Configure which llama-cpp-python precompiled binary to install (⚠️ On macOS only v0.3.2 is supported right now):
 LLAMA_CPP_PYTHON_VERSION=0.3.2
 PYTHON_VERSION=310
 ACCELERATOR=metal|cu121|cu122|cu123|cu124
@@ -176,7 +176,7 @@ messages.append({
     "content": "How is intelligence measured?"
 })

-# Adaptively decide whether to retrieve and stream the response:
+# Adaptively decide whether to retrieve and then stream the response:
 chunk_spans = []
 stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=my_config)
 for update in stream:
````
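
The README hunk ends mid-loop. For context, a hedged sketch of how the snippet continues (the `print` loop and the `chunk_spans` post-processing are assumptions based on the surrounding README, not part of this diff):

```python
# Stream the adaptive RAG response to stdout (assumed continuation; not part
# of this commit's diff).
for update in stream:
    print(update, end="")
# If retrieval happened, the on_retrieval callback has filled chunk_spans with
# the retrieved ChunkSpan objects, whose source documents can then be inspected.
documents = [chunk_span.document for chunk_span in chunk_spans]
```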

src/raglite/_cli.py (1 addition, 1 deletion)

```diff
@@ -86,7 +86,7 @@ def install_mcp_server(
             "command": "uvx",
             "args": [
                 "--python",
-                "3.12",
+                "3.11",
                 "--with",
                 "numpy<2.0.0",  # TODO: Remove this constraint when uv no longer needs it to solve the environment.
                 "raglite",
```

src/raglite/_config.py (3 additions, 2 deletions)

```diff
@@ -5,6 +5,7 @@
 from dataclasses import dataclass, field
 from io import StringIO
 from pathlib import Path
+from typing import Literal

 from llama_cpp import llama_supports_gpu_offload
 from platformdirs import user_data_dir
@@ -48,8 +49,8 @@ class RAGLiteConfig:
     # Chunk config used to partition documents into chunks.
     chunk_max_size: int = 1440  # Max number of characters per chunk.
     # Vector search config.
-    vector_search_index_metric: str = "cosine"  # The query adapter supports "dot" and "cosine".
-    vector_search_query_adapter: bool = True
+    vector_search_index_metric: Literal["cosine", "dot", "l1", "l2"] = "cosine"
+    vector_search_query_adapter: bool = True  # Only supported for "cosine" and "dot" metrics.
     # Reranking config.
     reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = field(
         default_factory=lambda: (
```
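
The `Literal` annotation makes an invalid metric a type error rather than a silent misconfiguration. A minimal caller-side sketch under the new constraint (the keyword arguments are taken from the fields shown above; the database URL is an invented example):

```python
from raglite import RAGLiteConfig

# "l1" and "l2" are now accepted index metrics, but the query adapter only
# supports "cosine" and "dot", so it is disabled here (per the inline comment
# in the diff above).
my_config = RAGLiteConfig(
    db_url="sqlite:///raglite.db",  # Assumed example database URL.
    vector_search_index_metric="l2",
    vector_search_query_adapter=False,
)
```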

src/raglite/_rag.py (4 additions, 2 deletions)

```diff
@@ -27,7 +27,7 @@
 When responding, you MUST NOT reference the existence of the context, directly or indirectly.
 Instead, you MUST treat the context as if its contents are entirely part of your working memory.

-{context}
+<context>{context}</context>

 {user_prompt}
 """.strip()
@@ -91,7 +91,9 @@ def _get_tools(
     """Get tools to search the knowledge base if no RAG context is provided in the messages."""
     # Check if messages already contain RAG context or if the LLM supports tool use.
     final_message = messages[-1].get("content", "")
-    messages_contain_rag_context = any(s in final_message for s in ("</document>", "from_chunk_id"))
+    messages_contain_rag_context = any(
+        s in final_message for s in ("<context>", "<document>", "from_chunk_id")
+    )
     llm_supports_function_calling = supports_function_calling(config.llm)
     if not messages_contain_rag_context and not llm_supports_function_calling:
         error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling."
```

src/raglite/_search.py (30 additions, 14 deletions)

```diff
@@ -6,7 +6,6 @@
 from collections import defaultdict
 from collections.abc import Sequence
 from itertools import groupby
-from typing import cast

 import numpy as np
 from langdetect import LangDetectException, detect
@@ -66,23 +65,32 @@ def vector_search(
             .order_by(distance)
             .limit(oversample * num_results)
         )
-        chunk_ids_, distance = zip(*results, strict=True)
-        chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance)
+        results = list(results)  # type: ignore[assignment]
+        chunk_ids = np.asarray([result[0] for result in results])
+        similarity = 1.0 - np.asarray([result[1] for result in results])
     elif db_backend == "sqlite":
         # Load the NNDescent index.
         index = index_metadata.get("index")
-        ids = np.asarray(index_metadata.get("chunk_ids"))
-        cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes")))
+        ids = np.asarray(index_metadata.get("chunk_ids", []))
+        cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes", [])))
         # Find the neighbouring multi-vector indices.
         from pynndescent import NNDescent

-        multi_vector_indices, distance = cast(NNDescent, index).query(
-            query_embedding[np.newaxis, :], k=oversample * num_results
-        )
-        similarity = 1 - distance[0, :]
-        # Transform the multi-vector indices into chunk indices, and then to chunk ids.
-        chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1
-        chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices])
+        if isinstance(index, NNDescent) and len(ids) and len(cumsum):
+            # Query the index.
+            multi_vector_indices, distance = index.query(
+                query_embedding[np.newaxis, :], k=oversample * num_results
+            )
+            similarity = 1 - distance[0, :]
+            # Transform the multi-vector indices into chunk indices, and then to chunk ids.
+            chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1
+            chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices])
+        else:
+            # Empty result set if there is no index or if no chunks are indexed.
+            chunk_ids, similarity = np.array([], dtype=np.intp), np.array([])
+    # Exit early if there are no search results.
+    if not len(chunk_ids):
+        return [], []
     # Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with
     # fewer hits are padded with the minimum similarity of the result set.
     unique_chunk_ids, counts = np.unique(chunk_ids, return_counts=True)
@@ -157,6 +165,9 @@ def reciprocal_rank_fusion(
         chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
         for chunk_id in chunk_ids:
             chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index)))
+    # Exit early if there are no results to fuse.
+    if not chunk_id_score:
+        return [], []
     # Rank RRF results according to descending RRF score.
     rrf_chunk_ids, rrf_score = zip(
         *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True
@@ -181,6 +192,8 @@ def retrieve_chunks(
     chunk_ids: list[ChunkId], *, config: RAGLiteConfig | None = None
 ) -> list[Chunk]:
     """Retrieve chunks by their ids."""
+    if not chunk_ids:
+        return []
     config = config or RAGLiteConfig()
     engine = create_database_engine(config)
     with Session(engine) as session:
@@ -207,8 +220,8 @@ def rerank_chunks(
         if all(isinstance(chunk_id, ChunkId) for chunk_id in chunk_ids)
         else chunk_ids
     )
-    # Early exit if no reranker is configured.
-    if not config.reranker:
+    # Exit early if no reranker is configured or if the input is empty.
+    if not config.reranker or not chunks:
         return chunks
     # Select the reranker.
     if isinstance(config.reranker, Sequence):
@@ -243,6 +256,9 @@ def retrieve_chunk_spans(
     Chunk spans are ordered according to the aggregate relevance of their underlying chunks, as
     determined by the order in which they are provided to this function.
     """
+    # Exit early if the input is empty.
+    if not chunk_ids:
+        return []
     # Retrieve the chunks.
     config = config or RAGLiteConfig()
     chunks: list[Chunk] = (
```
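
The `reciprocal_rank_fusion` hunk scores each chunk id as the sum over rankings of 1 / (k + rank), with ids absent from a ranking penalized with the worst possible rank. A self-contained sketch of that scheme, including the new early exit; the function signature and the conventional `k=60` default are assumptions, not verified against the full source:

```python
from collections import defaultdict


def reciprocal_rank_fusion(rankings: list[list[str]], k: int = 60) -> tuple[list[str], list[float]]:
    # Accumulate 1 / (k + rank) for every chunk id over all rankings; ids that
    # are missing from a ranking get the worst possible rank in that ranking.
    chunk_id_score: defaultdict[str, float] = defaultdict(float)
    all_chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking}
    for ranking in rankings:
        chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
        for chunk_id in all_chunk_ids:
            chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index)))
    # Exit early if there are no results to fuse (the fix in this commit).
    if not chunk_id_score:
        return [], []
    # Rank fused results according to descending RRF score.
    fused = sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True)
    chunk_ids, scores = zip(*fused, strict=True)
    return list(chunk_ids), list(scores)


# Example: "b" appears near the top of both rankings and wins the fused ranking.
print(reciprocal_rank_fusion([["a", "b"], ["b", "c"]]))
```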

tests/test_rag.py (0 additions, 45 deletions)

```diff
@@ -1,13 +1,10 @@
 """Test RAGLite's RAG functionality."""

-import json
-
 from raglite import (
     RAGLiteConfig,
     create_rag_instruction,
     retrieve_rag_context,
 )
-from raglite._database import ChunkSpan
 from raglite._rag import rag


@@ -25,45 +22,3 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
     assert "event" in answer.lower()
     # Verify that no RAG context was retrieved through tool use.
     assert [message["role"] for message in messages] == ["user", "assistant"]
-
-
-def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
-    """Test Retrieval-Augmented Generation with automatic retrieval."""
-    # Answer a question that requires RAG.
-    user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
-    messages = [{"role": "user", "content": user_prompt}]
-    chunk_spans = []
-    stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
-    answer = ""
-    for update in stream:
-        assert isinstance(update, str)
-        answer += update
-    assert "event" in answer.lower()
-    # Verify that RAG context was retrieved automatically.
-    assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"]
-    assert json.loads(messages[-2]["content"])
-    assert chunk_spans
-    assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
-
-
-def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
-    """Test Retrieval-Augmented Generation with automatic retrieval."""
-    # Answer a question that does not require RAG.
-    user_prompt = "Is 7 a prime number? Answer with Yes or No only."
-    messages = [{"role": "user", "content": user_prompt}]
-    chunk_spans = []
-    stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
-    answer = ""
-    for update in stream:
-        assert isinstance(update, str)
-        answer += update
-    assert "yes" in answer.lower()
-    # Verify that no RAG context was retrieved.
-    if raglite_test_config.llm.startswith("llama-cpp-python"):
-        # Llama.cpp does not support streaming tool_choice="auto" yet, so instead we verify that the
-        # LLM indicates that the tool call request may be skipped by checking that content is empty.
-        assert [msg["role"] for msg in messages] == ["user", "assistant", "tool", "assistant"]
-        assert not json.loads(messages[-2]["content"])
-    else:
-        assert [msg["role"] for msg in messages] == ["user", "assistant"]
-    assert not chunk_spans
```

tests/test_search.py (12 additions, 0 deletions)

```diff
@@ -62,3 +62,15 @@ def test_search_no_results(raglite_test_config: RAGLiteConfig, search_method: Se
     assert len(chunk_ids) == len(scores) == num_results_expected
     assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
     assert all(isinstance(score, float) for score in scores)
+
+
+def test_search_empty_database(llm: str, embedder: str, search_method: SearchMethod) -> None:
+    """Test searching for a query with an empty database."""
+    raglite_test_config = RAGLiteConfig(db_url="sqlite:///:memory:", llm=llm, embedder=embedder)
+    query = "supercalifragilisticexpialidocious"
+    num_results = 5
+    chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
+    num_results_expected = 0
+    assert len(chunk_ids) == len(scores) == num_results_expected
+    assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
+    assert all(isinstance(score, float) for score in scores)
```
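
Taken together with the early exits above, the new test pins down the contract that searching an empty database yields empty results rather than an exception. A hedged end-to-end sketch, assuming `vector_search` is importable from the top-level `raglite` package as shown in the README (the test itself exercises this via its `search_method` fixture):

```python
# Assumed public API; mirrors what test_search_empty_database checks.
from raglite import RAGLiteConfig, vector_search

config = RAGLiteConfig(db_url="sqlite:///:memory:")
chunk_ids, scores = vector_search("anything", num_results=5, config=config)
assert chunk_ids == [] and scores == []  # Empty results, not an exception.
```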
