Skip to content

Commit 2680b74

Browse files
authored
feat: improve late chunking and optimize pgvector settings (#51)
1 parent 0fd1970 commit 2680b74

File tree

10 files changed

+84
-66
lines changed

10 files changed

+84
-66
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_
4747
Next, it is optional but recommended to install [an accelerated llama-cpp-python precompiled binary](https://github.com/abetlen/llama-cpp-python?tab=readme-ov-file#supported-backends) with:
4848

4949
```sh
50-
# Configure which llama-cpp-python precompiled binary to install (⚠️ only v0.2.88 is supported right now):
51-
LLAMA_CPP_PYTHON_VERSION=0.2.88
50+
# Configure which llama-cpp-python precompiled binary to install (⚠️ only v0.3.2 is supported right now):
51+
LLAMA_CPP_PYTHON_VERSION=0.3.2
5252
PYTHON_VERSION=310
5353
ACCELERATOR=metal|cu121|cu122|cu123|cu124
5454
PLATFORM=macosx_11_0_arm64|linux_x86_64|win_amd64
@@ -116,7 +116,7 @@ my_config = RAGLiteConfig(
116116
my_config = RAGLiteConfig(
117117
db_url="sqlite:///raglite.sqlite",
118118
llm="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@8192",
119-
embedder="llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf",
119+
embedder="llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@1024", # A context size of 1024 tokens is the sweet spot for bge-m3.
120120
)
121121
```
122122

@@ -281,7 +281,7 @@ You can specify the database URL, LLM, and embedder directly in the Chainlit fro
281281
raglite chainlit \
282282
--db_url sqlite:///raglite.sqlite \
283283
--llm llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096 \
284-
--embedder llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf
284+
--embedder llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@1024
285285
```
286286

287287
To use an API-based LLM, make sure to include your credentials in a `.env` file or supply them inline:

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ services:
4343
- dev
4444

4545
postgres:
46-
image: pgvector/pgvector:pg16
46+
image: pgvector/pgvector:pg17
4747
environment:
4848
POSTGRES_USER: raglite_user
4949
POSTGRES_PASSWORD: raglite_password

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ spacy = ">=3.7.0,<3.8.0"
3333
# Large Language Models:
3434
huggingface-hub = ">=0.22.0"
3535
litellm = ">=1.47.1"
36-
llama-cpp-python = ">=0.2.88"
36+
llama-cpp-python = ">=0.3.2"
3737
pydantic = ">=2.7.0"
3838
# Approximate Nearest Neighbors:
3939
pynndescent = ">=0.5.12"

src/raglite/_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ class RAGLiteConfig:
3333
# Embedder config used for indexing.
3434
embedder: str = field(
3535
default_factory=lambda: ( # Nomic-embed may be better if only English is used.
36-
"llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf"
36+
"llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf@1024"
3737
if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4 # noqa: PLR2004
38-
else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf"
38+
else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024"
3939
)
4040
)
4141
embedder_normalize: bool = True

src/raglite/_database.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -331,22 +331,20 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
331331
with Session(engine) as session:
332332
metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
333333
session.execute(
334-
text(
335-
"""
334+
text("""
336335
CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
337-
"""
338-
)
336+
""")
339337
)
340338
session.execute(
341-
text(
342-
f"""
339+
text(f"""
343340
CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
344341
USING hnsw (
345-
(embedding::halfvec({embedding_dim}))
346-
halfvec_{metrics[config.vector_search_index_metric]}_ops
342+
(embedding::halfvec({embedding_dim}))
343+
halfvec_{metrics[config.vector_search_index_metric]}_ops
347344
);
348-
"""
349-
)
345+
SET hnsw.ef_search = {20 * 4 * 8};
346+
SET hnsw.iterative_scan = {'relaxed_order' if config.reranker else 'strict_order'};
347+
""")
350348
)
351349
session.commit()
352350
elif db_backend == "sqlite":
@@ -355,39 +353,31 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
355353
# [1] https://www.sqlite.org/fts5.html#external_content_tables
356354
with Session(engine) as session:
357355
session.execute(
358-
text(
359-
"""
356+
text("""
360357
CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
361-
"""
362-
)
358+
""")
363359
)
364360
session.execute(
365-
text(
366-
"""
361+
text("""
367362
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
368363
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
369364
END;
370-
"""
371-
)
365+
""")
372366
)
373367
session.execute(
374-
text(
375-
"""
368+
text("""
376369
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
377370
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
378371
END;
379-
"""
380-
)
372+
""")
381373
)
382374
session.execute(
383-
text(
384-
"""
375+
text("""
385376
CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
386377
INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
387378
INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
388379
END;
389-
"""
390-
)
380+
""")
391381
)
392382
session.commit()
393383
return engine

src/raglite/_litellm.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Add support for llama-cpp-python models to LiteLLM."""
22

33
import asyncio
4+
import contextlib
45
import logging
6+
import os
57
import warnings
68
from collections.abc import AsyncIterator, Callable, Iterator
79
from functools import cache
10+
from io import StringIO
811
from typing import Any, ClassVar, cast
912

1013
import httpx
@@ -28,7 +31,8 @@
2831
from raglite._config import RAGLiteConfig
2932

3033
# Reduce the logging level for LiteLLM and flashrank.
31-
logging.getLogger("litellm").setLevel(logging.WARNING)
34+
os.environ["LITELLM_LOG"] = "WARNING"
35+
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
3236
logging.getLogger("flashrank").setLevel(logging.WARNING)
3337

3438

@@ -96,14 +100,21 @@ def llm(model: str, **kwargs: Any) -> Llama:
96100
filename, n_ctx_str = filename_n_ctx
97101
n_ctx = int(n_ctx_str)
98102
# Load the LLM.
99-
with warnings.catch_warnings(): # Filter huggingface_hub warning about HF_TOKEN.
103+
with (
104+
contextlib.redirect_stderr(StringIO()), # Filter spurious llama.cpp output.
105+
warnings.catch_warnings(), # Filter huggingface_hub warning about HF_TOKEN.
106+
):
100107
warnings.filterwarnings("ignore", category=UserWarning)
101108
llm = Llama.from_pretrained(
102109
repo_id=repo_id,
103110
filename=filename,
104111
n_ctx=n_ctx,
105112
n_gpu_layers=-1,
106113
verbose=False,
114+
# Workaround to enable long context embedding models [1].
115+
# [1] https://github.com/abetlen/llama-cpp-python/issues/1762
116+
n_batch=n_ctx if n_ctx > 0 else 1024,
117+
n_ubatch=n_ctx if n_ctx > 0 else 1024,
107118
**kwargs,
108119
)
109120
# Enable caching.

src/raglite/_search.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626

2727

2828
def vector_search(
29-
query: str | FloatMatrix, *, num_results: int = 3, config: RAGLiteConfig | None = None
29+
query: str | FloatMatrix,
30+
*,
31+
num_results: int = 3,
32+
oversample: int = 8,
33+
config: RAGLiteConfig | None = None,
3034
) -> tuple[list[ChunkId], list[float]]:
3135
"""Search chunks using ANN vector search."""
3236
# Read the config.
@@ -57,7 +61,9 @@ def vector_search(
5761
)
5862
distance = distance_func(query_embedding).label("distance")
5963
results = session.exec(
60-
select(ChunkEmbedding.chunk_id, distance).order_by(distance).limit(8 * num_results)
64+
select(ChunkEmbedding.chunk_id, distance)
65+
.order_by(distance)
66+
.limit(oversample * num_results)
6167
)
6268
chunk_ids_, distance = zip(*results, strict=True)
6369
chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance)
@@ -70,7 +76,7 @@ def vector_search(
7076
from pynndescent import NNDescent
7177

7278
multi_vector_indices, distance = cast(NNDescent, index).query(
73-
query_embedding[np.newaxis, :], k=8 * num_results
79+
query_embedding[np.newaxis, :], k=oversample * num_results
7480
)
7581
similarity = 1 - distance[0, :]
7682
# Transform the multi-vector indices into chunk indices, and then to chunk ids.
@@ -105,36 +111,32 @@ def keyword_search(
105111
if db_backend == "postgresql":
106112
# Convert the query to a tsquery [1].
107113
# [1] https://www.postgresql.org/docs/current/textsearch-controls.html
108-
query_escaped = re.sub(r"[&|!():<>\"]", " ", query)
114+
query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", " ", query)
109115
tsv_query = " | ".join(query_escaped.split())
110116
# Perform keyword search with tsvector.
111-
statement = text(
112-
"""
117+
statement = text("""
113118
SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
114119
FROM chunk
115120
WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
116121
ORDER BY score DESC
117122
LIMIT :limit;
118-
"""
119-
)
123+
""")
120124
results = session.execute(statement, params={"query": tsv_query, "limit": num_results})
121125
elif db_backend == "sqlite":
122126
# Convert the query to an FTS5 query [1].
123127
# [1] https://www.sqlite.org/fts5.html#full_text_query_syntax
124-
query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", query)
128+
query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", " ", query)
125129
fts5_query = " OR ".join(query_escaped.split())
126130
# Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we
127131
# negate them to make them positive.
128132
# [1] https://www.sqlite.org/fts5.html#the_bm25_function
129-
statement = text(
130-
"""
133+
statement = text("""
131134
SELECT chunk.id as chunk_id, -bm25(keyword_search_chunk_index) as score
132135
FROM chunk JOIN keyword_search_chunk_index ON chunk.rowid = keyword_search_chunk_index.rowid
133136
WHERE keyword_search_chunk_index MATCH :match
134137
ORDER BY score DESC
135138
LIMIT :limit;
136-
"""
137-
)
139+
""")
138140
results = session.execute(statement, params={"match": fts5_query, "limit": num_results})
139141
# Unpack the results.
140142
results = list(results) # type: ignore[assignment]
@@ -162,12 +164,12 @@ def reciprocal_rank_fusion(
162164

163165

164166
def hybrid_search(
165-
query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None
167+
query: str, *, num_results: int = 3, oversample: int = 4, config: RAGLiteConfig | None = None
166168
) -> tuple[list[ChunkId], list[float]]:
167169
"""Search chunks by combining ANN vector search with BM25 keyword search."""
168170
# Run both searches.
169-
vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config)
170-
ks_chunk_ids, _ = keyword_search(query, num_results=num_rerank, config=config)
171+
vs_chunk_ids, _ = vector_search(query, num_results=oversample * num_results, config=config)
172+
ks_chunk_ids, _ = keyword_search(query, num_results=oversample * num_results, config=config)
171173
# Combine the results with Reciprocal Rank Fusion (RRF).
172174
chunk_ids, hybrid_score = reciprocal_rank_fusion([vs_chunk_ids, ks_chunk_ids])
173175
chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results]

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def database(request: pytest.FixtureRequest) -> str:
6969
scope="session",
7070
params=[
7171
pytest.param(
72-
"llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf",
72+
"llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance.
7373
id="bge_m3",
7474
),
7575
pytest.param(

tests/test_rerank.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
"""Test RAGLite's reranking functionality."""
22

3+
import random
4+
from typing import TypeVar
5+
36
import pytest
47
from rerankers.models.flashrank_ranker import FlashRankRanker
58
from rerankers.models.ranker import BaseRanker
9+
from scipy.stats import kendalltau
610

711
from raglite import RAGLiteConfig, hybrid_search, rerank_chunks, retrieve_chunks
812
from raglite._database import Chunk
913

14+
T = TypeVar("T")
15+
16+
17+
def kendall_tau(a: list[T], b: list[T]) -> float:
18+
"""Measure the Kendall rank correlation coefficient between two lists."""
19+
τ: float = kendalltau(range(len(a)), [a.index(el) for el in b])[0] # noqa: PLC2401
20+
return τ
21+
1022

1123
@pytest.fixture(
1224
params=[
@@ -40,16 +52,19 @@ def test_reranker(
4052
)
4153
# Search for a query.
4254
query = "What does it mean for two events to be simultaneous?"
43-
chunk_ids, _ = hybrid_search(query, num_results=3, config=raglite_test_config)
55+
chunk_ids, _ = hybrid_search(query, num_results=20, config=raglite_test_config)
4456
# Retrieve the chunks.
4557
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
4658
assert all(isinstance(chunk, Chunk) for chunk in chunks)
4759
assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
48-
# Rerank the chunks given an inverted chunk order.
49-
reranked_chunks = rerank_chunks(query, chunks[::-1], config=raglite_test_config)
50-
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
51-
assert reranked_chunks[0] == chunks[0]
52-
# Test that we can also rerank given the chunk_ids only.
53-
reranked_chunks = rerank_chunks(query, chunk_ids[::-1], config=raglite_test_config)
54-
if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
55-
assert reranked_chunks[0] == chunks[0]
60+
# Randomly shuffle the chunks.
61+
random.seed(42)
62+
chunks_random = random.sample(chunks, len(chunks))
63+
# Rerank the chunks starting from a pathological order and verify that it improves the ranking.
64+
for arg in (chunks[::-1], chunk_ids[::-1]):
65+
reranked_chunks = rerank_chunks(query, arg, config=raglite_test_config)
66+
if reranker:
67+
τ_search = kendall_tau(chunks, reranked_chunks) # noqa: PLC2401
68+
τ_inverse = kendall_tau(chunks[::-1], reranked_chunks) # noqa: PLC2401
69+
τ_random = kendall_tau(chunks_random, reranked_chunks) # noqa: PLC2401
70+
assert τ_search >= τ_random >= τ_inverse

0 commit comments

Comments
 (0)