diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py
index ab2296c..1c93630 100644
--- a/src/raglite/_eval.py
+++ b/src/raglite/_eval.py
@@ -195,6 +195,7 @@ def answer_evals(
     num_evals: int = 100,
     *,
     config: RAGLiteConfig | None = None,
+    max_workers: int | None = None,
 ) -> "pd.DataFrame":
     """Read evals from the database and answer them with RAG."""
     try:
@@ -203,19 +204,50 @@ def answer_evals(
         error_message = "To use the `answer_evals` function, please install the `ragas` extra."
         raise ModuleNotFoundError(error_message) from import_error
+    from concurrent.futures import ThreadPoolExecutor, as_completed
+
     # Read evals from the database.
     with Session(create_database_engine(config := config or RAGLiteConfig())) as session:
         evals = session.exec(select(Eval).limit(num_evals)).all()
-    # Answer evals with RAG.
-    answers: list[str] = []
-    contexts: list[list[str]] = []
-    for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
+
+    def process_eval(eval_: Eval) -> tuple[Eval, str, list[str]]:
+        """Answer a single eval with RAG."""
         chunk_spans = retrieve_context(query=eval_.question, config=config)
         messages = [add_context(user_prompt=eval_.question, context=chunk_spans)]
         response = rag(messages, config=config)
         answer = "".join(response)
+        eval_contexts = [str(chunk_span) for chunk_span in chunk_spans]
+        return eval_, answer, eval_contexts
+
+    # Answer the evals in parallel.
+    answers: list[str] = []
+    contexts: list[list[str]] = []
+    eval_results: list[tuple[Eval, str, list[str]]] = []
+
+    # Use a ThreadPoolExecutor for parallel processing.
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        future_to_eval = {executor.submit(process_eval, eval_): eval_ for eval_ in evals}
+
+        # Consume completed tasks with a progress bar.
+        with tqdm(
+            total=len(evals), desc="Answering evals", unit="eval", dynamic_ncols=True
+        ) as pbar:
+            for future in as_completed(future_to_eval):
+                eval_ = future_to_eval[future]
+                try:
+                    _, answer, context = future.result()
+                    eval_results.append((eval_, answer, context))
+                except Exception as exc:
+                    print(f"Eval {eval_.question} generated an exception: {exc}")
+                    # Record a placeholder so answers and contexts stay aligned with evals.
+                    eval_results.append((eval_, "", []))
+                finally:
+                    pbar.update(1)
+
+    # Sort the results back into the original eval order.
+    eval_results.sort(key=lambda result: evals.index(result[0]))
+
+    # Extract the answers and contexts.
+    for _, answer, context in eval_results:
         answers.append(answer)
-        contexts.append([str(chunk_span) for chunk_span in chunk_spans])
+        contexts.append(context)
+
     # Collect the answered evals.
     answered_evals: dict[str, list[str] | list[list[str]]] = {
         "question": [eval_.question for eval_ in evals],
diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py
index 47957cc..83220f5 100644
--- a/src/raglite/_query_adapter.py
+++ b/src/raglite/_query_adapter.py
@@ -44,6 +44,7 @@ def update_query_adapter(
     optimize_top_k: int = 40,
     optimize_gap: float = 0.05,
     config: RAGLiteConfig | None = None,
+    max_workers: int | None = None,
 ) -> FloatMatrix:
     """Compute an optimal query adapter and update the database with it.
@@ -138,8 +139,11 @@ def update_query_adapter(
     FloatMatrix
         The query adapter.
     """
+    from concurrent.futures import ThreadPoolExecutor, as_completed
+
     config = config or RAGLiteConfig()
     config_no_query_adapter = replace(config, vector_search_query_adapter=False)
+
     with Session(engine := create_database_engine(config)) as session:
         # Get random evals from the database.
         chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
@@ -150,63 +154,113 @@ def update_query_adapter(
         if len(evals) == 0:
             error_message = "First run `insert_evals()` to generate evals."
             raise ValueError(error_message)
-        # Construct the query and target matrices.
-        Q = np.zeros((0, len(chunk_embedding.embedding)))
-        T = np.zeros_like(Q)
-        for eval_ in tqdm(
-            evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False
-        ):
-            # Embed the question.
-            q = embed_strings([eval_.question], config=config)[0]
-            # Retrieve chunks that would be used to answer the question.
-            chunk_ids, _ = vector_search(
-                q, num_results=optimize_top_k, config=config_no_query_adapter
-            )
-            retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
-            retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
-            # Skip this eval if it doesn't contain both relevant and irrelevant chunks.
-            is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks])
-            if not np.any(is_relevant) or not np.any(~is_relevant):
-                continue
-            # Extract the positive and negative chunk embeddings.
-            P = np.vstack(
-                [
-                    chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
-                    for chunk in np.array(retrieved_chunks)[is_relevant]
-                ]
-            )
-            N = np.vstack(
-                [
-                    chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
-                    for chunk in np.array(retrieved_chunks)[~is_relevant]
-                ]
-            )
-            # Compute the optimal target vector t for this query embedding q.
-            t = _optimize_query_target(q, P, N, α=optimize_gap)
-            Q = np.vstack([Q, q[np.newaxis, :]])
-            T = np.vstack([T, t[np.newaxis, :]])
-        # Normalise the rows of Q and T.
-        Q /= np.linalg.norm(Q, axis=1, keepdims=True)
-        if config.vector_search_distance_metric == "cosine":
-            T /= np.linalg.norm(T, axis=1, keepdims=True)
-        # Compute the optimal unconstrained query adapter M.
-        n, d = Q.shape
-        M = (1 / n) * T.T @ Q
-        if n < d or np.linalg.matrix_rank(Q) < d:
-            M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q
-        # Compute the optimal constrained query adapter A* from M, given the distance metric.
-        A_star: FloatMatrix
-        if config.vector_search_distance_metric == "dot":
-            # Use the relaxed Procrustes solution.
-            A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
-        elif config.vector_search_distance_metric == "cosine":
-            # Use the orthogonal Procrustes solution.
-            U, _, VT = np.linalg.svd(M, full_matrices=False)
-            A_star = U @ VT
-        else:
-            error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
-            raise ValueError(error_message)
-        # Store the optimal query adapter in the database.
+
+    def process_eval(eval_):
+        """Compute the query and target vectors for a single eval."""
+        # Create a new session for this thread, since sessions are not thread-safe.
+        with Session(create_database_engine(config)) as thread_session:
+            try:
+                # Embed the question.
+                q = embed_strings([eval_.question], config=config)[0]
+                # Retrieve chunks that would be used to answer the question.
+                chunk_ids, _ = vector_search(
+                    q, num_results=optimize_top_k, config=config_no_query_adapter
+                )
+                retrieved_chunks = thread_session.exec(
+                    select(Chunk).where(col(Chunk.id).in_(chunk_ids))
+                ).all()
+                retrieved_chunks = sorted(
+                    retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id)
+                )
+                # Skip this eval if it doesn't contain both relevant and irrelevant chunks.
+                is_relevant = np.array(
+                    [chunk.id in eval_.chunk_ids for chunk in retrieved_chunks]
+                )
+                if not np.any(is_relevant) or not np.any(~is_relevant):
+                    return None
+                # Extract the positive and negative chunk embeddings.
+                P = np.vstack(
+                    [
+                        chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
+                        for chunk in np.array(retrieved_chunks)[is_relevant]
+                    ]
+                )
+                N = np.vstack(
+                    [
+                        chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
+                        for chunk in np.array(retrieved_chunks)[~is_relevant]
+                    ]
+                )
+                # Compute the optimal target vector t for this query embedding q.
+                t = _optimize_query_target(q, P, N, α=optimize_gap)
+                return q, t
+            except Exception as e:
+                print(f"Error processing eval {eval_.question}: {e}")
+                return None
+
+    # Process the evals in parallel.
+    results = []
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks along with their original index.
+        future_to_eval_idx = {
+            executor.submit(process_eval, eval_): (eval_, idx)
+            for idx, eval_ in enumerate(evals)
+        }
+
+        # Consume completed tasks with a progress bar.
+        with tqdm(
+            total=len(evals), desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False
+        ) as pbar:
+            for future in as_completed(future_to_eval_idx):
+                eval_, idx = future_to_eval_idx[future]
+                try:
+                    result = future.result()
+                    if result is not None:  # Only keep evals that yielded a (q, t) pair.
+                        results.append((idx, result))
+                except Exception as exc:
+                    print(f"Eval {eval_.question} generated an exception: {exc}")
+                finally:
+                    pbar.update(1)
+
+    if not results:
+        error_message = "No evals produced a query-target pair; cannot compute a query adapter."
+        raise ValueError(error_message)
+
+    # Sort by original index to restore the original eval order.
+    results.sort(key=lambda result: result[0])
+
+    # Construct the query and target matrices.
+    Q = np.vstack([q[np.newaxis, :] for _, (q, t) in results])
+    T = np.vstack([t[np.newaxis, :] for _, (q, t) in results])
+    # Cast to float32 so the linear algebra below doesn't run in reduced precision.
+    Q = Q.astype(np.float32)
+    T = T.astype(np.float32)
+
+    # Normalise the rows of Q and T.
+    Q /= np.linalg.norm(Q, axis=1, keepdims=True)
+    if config.vector_search_distance_metric == "cosine":
+        T /= np.linalg.norm(T, axis=1, keepdims=True)
+
+    # Compute the optimal unconstrained query adapter M.
+    n, d = Q.shape
+    M = (1 / n) * T.T @ Q
+    if n < d or np.linalg.matrix_rank(Q) < d:
+        M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q
+
+    # Compute the optimal constrained query adapter A* from M, given the distance metric.
+    A_star: FloatMatrix
+    if config.vector_search_distance_metric == "dot":
+        # Use the relaxed Procrustes solution.
+        A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
+    elif config.vector_search_distance_metric == "cosine":
+        # Use the orthogonal Procrustes solution.
+        U, _, VT = np.linalg.svd(M, full_matrices=False)
+        A_star = U @ VT
+    else:
+        error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
+        raise ValueError(error_message)
+
+    # Store the optimal query adapter in the database.
+    with Session(engine := create_database_engine(config)) as session:
         index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
         index_metadata.metadata_["query_adapter"] = A_star
         flag_modified(index_metadata, "metadata_")
@@ -216,4 +270,5 @@ def update_query_adapter(
         session.execute(text("CHECKPOINT;"))
     # Clear the index metadata cache to allow the new query adapter to be used.
     IndexMetadata._get.cache_clear()  # noqa: SLF001
+    return A_star
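For reference, a minimal usage sketch of the new `max_workers` parameter (the database URL and eval count below are hypothetical; it assumes `insert_evals()` has already populated the evals table):

```python
# Usage sketch for the new max_workers parameter (hypothetical values).
from raglite import RAGLiteConfig
from raglite._eval import answer_evals
from raglite._query_adapter import update_query_adapter

config = RAGLiteConfig(db_url="duckdb:///raglite.db")  # Hypothetical database URL.

# Answer 50 evals with at most 8 concurrent RAG calls.
answered_evals_df = answer_evals(num_evals=50, config=config, max_workers=8)

# Optimize the query adapter with the same degree of parallelism.
A_star = update_query_adapter(config=config, max_workers=8)
```

With `max_workers=None` (the default), `ThreadPoolExecutor` falls back to its own default of `min(32, os.cpu_count() + 4)` threads, so existing callers keep working unchanged. An alternative worth considering: `executor.map(process_eval, evals)` yields results in input order and would remove the submit/`as_completed`/sort dance, at the cost of the progress bar advancing in submission order rather than as tasks complete.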
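And a small self-contained check of the orthogonal Procrustes step that the `"cosine"` branch relies on (plain NumPy; the names below are chosen for illustration): the SVD projection `U @ VT` is the orthogonal matrix nearest to `M` in Frobenius norm, so the resulting adapter preserves the norm of every query embedding.

```python
# Sketch: the orthogonal Procrustes solution used for the "cosine" metric.
import numpy as np

rng = np.random.default_rng(42)
d = 8
M = rng.normal(size=(d, d))  # Stand-in for the unconstrained adapter.

# Project M onto the nearest orthogonal matrix.
U, _, VT = np.linalg.svd(M, full_matrices=False)
A_star = U @ VT

# A_star is orthogonal, so it preserves the norm of any query embedding q.
assert np.allclose(A_star.T @ A_star, np.eye(d), atol=1e-10)
q = rng.normal(size=d)
assert np.isclose(np.linalg.norm(A_star @ q), np.linalg.norm(q))
```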