Skip to content

Parallelize query adapter code #154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/raglite/_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def answer_evals(
num_evals: int = 100,
*,
config: RAGLiteConfig | None = None,
max_workers: int | None = None,
) -> "pd.DataFrame":
"""Read evals from the database and answer them with RAG."""
try:
Expand All @@ -203,19 +204,50 @@ def answer_evals(
error_message = "To use the `answer_evals` function, please install the `ragas` extra."
raise ModuleNotFoundError(error_message) from import_error

from concurrent.futures import ThreadPoolExecutor, as_completed

# Read evals from the database.
with Session(create_database_engine(config := config or RAGLiteConfig())) as session:
evals = session.exec(select(Eval).limit(num_evals)).all()
# Answer evals with RAG.
answers: list[str] = []
contexts: list[list[str]] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):

def process_eval(eval_):
    """Answer a single eval with RAG and collect its supporting contexts."""
    # Retrieve the chunk spans relevant to this eval's question.
    spans = retrieve_context(query=eval_.question, config=config)
    # Build the RAG prompt and stream the response into a single string.
    prompt = add_context(user_prompt=eval_.question, context=spans)
    generated = "".join(rag([prompt], config=config))
    # Serialise each chunk span for the evaluation dataframe.
    span_texts = [str(span) for span in spans]
    return eval_, generated, span_texts

# Process evals in parallel
answers: list[str] = []
contexts: list[list[str]] = []
eval_results: list[tuple] = []

# Use ThreadPoolExecutor for parallel processing
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_eval = {executor.submit(process_eval, eval_): eval_ for eval_ in evals}

# Process completed tasks with progress bar
with tqdm(total=len(evals), desc="Answering evals", unit="eval", dynamic_ncols=True) as pbar:
for future in as_completed(future_to_eval):
try:
eval_, answer, context = future.result()
eval_results.append((eval_, answer, context))
pbar.update(1)
except Exception as exc:
eval_ = future_to_eval[future]
print(f'Eval {eval_.question} generated an exception: {exc}')
pbar.update(1)

# Sort results to maintain original order
eval_results.sort(key=lambda x: evals.index(x[0]))

# Extract results
for eval_, answer, context in eval_results:
answers.append(answer)
contexts.append([str(chunk_span) for chunk_span in chunk_spans])
contexts.append(context)

# Collect the answered evals.
answered_evals: dict[str, list[str] | list[list[str]]] = {
"question": [eval_.question for eval_ in evals],
Expand Down
169 changes: 112 additions & 57 deletions src/raglite/_query_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def update_query_adapter(
optimize_top_k: int = 40,
optimize_gap: float = 0.05,
config: RAGLiteConfig | None = None,
max_workers: int | None = None,
) -> FloatMatrix:
"""Compute an optimal query adapter and update the database with it.

Expand Down Expand Up @@ -138,8 +139,11 @@ def update_query_adapter(
FloatMatrix
The query adapter.
"""
from concurrent.futures import ThreadPoolExecutor, as_completed

config = config or RAGLiteConfig()
config_no_query_adapter = replace(config, vector_search_query_adapter=False)

with Session(engine := create_database_engine(config)) as session:
# Get random evals from the database.
chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
Expand All @@ -150,63 +154,113 @@ def update_query_adapter(
if len(evals) == 0:
error_message = "First run `insert_evals()` to generate evals."
raise ValueError(error_message)
# Construct the query and target matrices.
Q = np.zeros((0, len(chunk_embedding.embedding)))
T = np.zeros_like(Q)
for eval_ in tqdm(
evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False
):
# Embed the question.
q = embed_strings([eval_.question], config=config)[0]
# Retrieve chunks that would be used to answer the question.
chunk_ids, _ = vector_search(
q, num_results=optimize_top_k, config=config_no_query_adapter
)
retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
# Skip this eval if it doesn't contain both relevant and irrelevant chunks.
is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks])
if not np.any(is_relevant) or not np.any(~is_relevant):
continue
# Extract the positive and negative chunk embeddings.
P = np.vstack(
[
chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
for chunk in np.array(retrieved_chunks)[is_relevant]
]
)
N = np.vstack(
[
chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
for chunk in np.array(retrieved_chunks)[~is_relevant]
]
)
# Compute the optimal target vector t for this query embedding q.
t = _optimize_query_target(q, P, N, α=optimize_gap)
Q = np.vstack([Q, q[np.newaxis, :]])
T = np.vstack([T, t[np.newaxis, :]])
# Normalise the rows of Q and T.
Q /= np.linalg.norm(Q, axis=1, keepdims=True)
if config.vector_search_distance_metric == "cosine":
T /= np.linalg.norm(T, axis=1, keepdims=True)
# Compute the optimal unconstrained query adapter M.
n, d = Q.shape
M = (1 / n) * T.T @ Q
if n < d or np.linalg.matrix_rank(Q) < d:
M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q
# Compute the optimal constrained query adapter A* from M, given the distance metric.
A_star: FloatMatrix
if config.vector_search_distance_metric == "dot":
# Use the relaxed Procrustes solution.
A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
elif config.vector_search_distance_metric == "cosine":
# Use the orthogonal Procrustes solution.
U, _, VT = np.linalg.svd(M, full_matrices=False)
A_star = U @ VT
else:
error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
raise ValueError(error_message)
# Store the optimal query adapter in the database.

def process_eval(eval_):
    """Process a single evaluation to compute query and target vectors.

    Returns a tuple ``(q, t)`` of the question embedding and its optimal
    target vector, or ``None`` when the eval is skipped (its retrieved
    chunks are all relevant or all irrelevant) or an error occurred.
    """
    # Create a new session for this thread: SQLAlchemy sessions are not
    # thread-safe, so the outer session cannot be shared across workers.
    with Session(create_database_engine(config)) as thread_session:
        try:
            # Embed the question.
            q = embed_strings([eval_.question], config=config)[0]

            # Retrieve chunks that would be used to answer the question.
            # The query adapter is disabled here so retrieval reflects the
            # unadapted embedding space that is being optimized.
            chunk_ids, _ = vector_search(
                q, num_results=optimize_top_k, config=config_no_query_adapter
            )
            retrieved_chunks = thread_session.exec(
                select(Chunk).where(col(Chunk.id).in_(chunk_ids))
            ).all()
            # Restore the ranking order produced by vector_search (the SQL
            # IN(...) query does not preserve it).
            retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))

            # Skip this eval if it doesn't contain both relevant and irrelevant chunks.
            is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks])
            if not np.any(is_relevant) or not np.any(~is_relevant):
                return None  # Skip this eval

            # Extract the positive and negative chunk embeddings: for each
            # chunk, keep only the embedding row most similar to q.
            P = np.vstack(
                [
                    chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
                    for chunk in np.array(retrieved_chunks)[is_relevant]
                ]
            )
            N = np.vstack(
                [
                    chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
                    for chunk in np.array(retrieved_chunks)[~is_relevant]
                ]
            )

            # Compute the optimal target vector t for this query embedding q.
            t = _optimize_query_target(q, P, N, α=optimize_gap)
            return q, t

        # Best-effort by design: report the failure and drop this eval so
        # one bad eval cannot abort the whole optimization run.
        except Exception as e:
            print(f"Error processing eval {eval_.question}: {e}")
            return None

# Process evaluations in parallel
results = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks with their original index
future_to_eval_idx = {
executor.submit(process_eval, eval_): (eval_, idx)
Copy link
Preview

Copilot AI Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Calling create_database_engine inside each thread for every eval can be expensive; consider creating the engine once outside and passing a shared session factory to avoid repeated engine initialization.

Suggested change
executor.submit(process_eval, eval_): (eval_, idx)
executor.submit(process_eval, eval_, engine): (eval_, idx)

Copilot uses AI. Check for mistakes.

for idx, eval_ in enumerate(evals)
}

# Process completed tasks with progress bar
with tqdm(total=len(evals), desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False) as pbar:
for future in as_completed(future_to_eval_idx):
try:
result = future.result()
if result is not None: # Only add valid results
eval_, idx = future_to_eval_idx[future]
results.append((idx, result))
pbar.update(1)
except Exception as exc:
eval_, idx = future_to_eval_idx[future]
print(f'Eval {eval_.question} generated an exception: {exc}')
pbar.update(1)

# Extract Q and T matrices from results (maintaining original order)
if not results:
raise ValueError("No valid evaluations were processed.")

# Sort by original index to maintain order
results.sort(key=lambda x: x[0])

Q = np.vstack([result[1][0][np.newaxis, :] for result in results])
T = np.vstack([result[1][1][np.newaxis, :] for result in results])
# Ensure arrays are float32 or float64 for linear algebra operations
Q = Q.astype(np.float32)
T = T.astype(np.float32)
# Normalise the rows of Q and T.
Q /= np.linalg.norm(Q, axis=1, keepdims=True)
if config.vector_search_distance_metric == "cosine":
T /= np.linalg.norm(T, axis=1, keepdims=True)

# Compute the optimal unconstrained query adapter M.
n, d = Q.shape
M = (1 / n) * T.T @ Q
if n < d or np.linalg.matrix_rank(Q) < d:
M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q

# Compute the optimal constrained query adapter A* from M, given the distance metric.
A_star: FloatMatrix
if config.vector_search_distance_metric == "dot":
# Use the relaxed Procrustes solution.
A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
elif config.vector_search_distance_metric == "cosine":
# Use the orthogonal Procrustes solution.
U, _, VT = np.linalg.svd(M, full_matrices=False)
A_star = U @ VT
else:
error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
raise ValueError(error_message)

# Store the optimal query adapter in the database.
with Session(engine := create_database_engine(config)) as session:
index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
index_metadata.metadata_["query_adapter"] = A_star
flag_modified(index_metadata, "metadata_")
Expand All @@ -216,4 +270,5 @@ def update_query_adapter(
session.execute(text("CHECKPOINT;"))
# Clear the index metadata cache to allow the new query adapter to be used.
IndexMetadata._get.cache_clear() # noqa: SLF001

return A_star
Loading