diff --git a/src/raglite/_bench.py b/src/raglite/_bench.py
index 3356c3c..2c3846f 100644
--- a/src/raglite/_bench.py
+++ b/src/raglite/_bench.py
@@ -2,7 +2,7 @@
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Generator
+from collections.abc import Iterator
 from dataclasses import replace
 from functools import cached_property
 from pathlib import Path
@@ -60,13 +60,15 @@ def trec_run_filename(self) -> str:
     def trec_run_filepath(self) -> Path:
         return self.cwd / self.trec_run_filename

-    def score(self) -> Generator[ScoredDoc, None, None]:
+    def score(self) -> Iterator[ScoredDoc]:
         """Read or compute a TREC run."""
         if self.trec_run_filepath.exists():
             yield from read_trec_run(self.trec_run_filepath.as_posix())  # type: ignore[no-untyped-call]
             return
         if not self.search("q0", next(self.dataset.queries_iter()).text):
             self.insert_documents()
+        if hasattr(self, "prescore"):
+            self.prescore()
         with self.trec_run_filepath.open(mode="w") as trec_run_file:
             for query in tqdm(
                 self.dataset.queries_iter(),
@@ -113,23 +115,37 @@ def insert_documents(self, max_workers: int | None = None) -> None:
         ]
         insert_documents(documents, max_workers=max_workers, config=self.config)

-    def update_query_adapter(self, num_evals: int = 1024) -> None:
+    def prescore(self) -> None:
+        from sqlalchemy import func, select
+        from sqlmodel import Session
+
         from raglite import insert_evals, update_query_adapter
-        from raglite._database import IndexMetadata
+        from raglite._database import Eval, IndexMetadata, create_database_engine
+
+        if not self.config.vector_search_query_adapter:
+            return

-        if (
-            self.config.vector_search_query_adapter
-            and IndexMetadata.get(config=self.config).get("query_adapter") is None
-        ):
-            insert_evals(num_evals=num_evals, config=self.config)
+        required_evals = 1024
+        with Session(create_database_engine(self.config)) as session:
+            num_evals = session.execute(select(func.count()).select_from(Eval)).scalar_one()
+        if num_evals < required_evals:
+            insert_evals(num_evals=required_evals - num_evals, config=self.config)
+        if IndexMetadata.get(config=self.config).get("query_adapter") is None:
             update_query_adapter(config=self.config)

     def search(self, query_id: str, query: str, *, num_results: int = 10) -> list[ScoredDoc]:
-        from raglite import retrieve_chunks, vector_search
+        from raglite import retrieve_chunks, search_and_rerank_chunks, vector_search

-        self.update_query_adapter()
-        chunk_ids, scores = vector_search(query, num_results=2 * num_results, config=self.config)
-        chunks = retrieve_chunks(chunk_ids, config=self.config)
+        if self.config.reranker:
+            chunks = search_and_rerank_chunks(
+                query=query, num_results=2 * num_results, config=self.config
+            )
+            scores = [1 / rank for rank in range(1, len(chunks) + 1)]
+        else:
+            chunk_ids, scores = vector_search(
+                query, num_results=2 * num_results, config=self.config
+            )
+            chunks = retrieve_chunks(chunk_ids, config=self.config)
         scored_docs = [
             ScoredDoc(query_id=query_id, doc_id=chunk.document.id, score=score)
             for chunk, score in zip(chunks, scores, strict=True)
diff --git a/src/raglite/_cli.py b/src/raglite/_cli.py
index 34fcd24..a3621dc 100644
--- a/src/raglite/_cli.py
+++ b/src/raglite/_cli.py
@@ -132,25 +132,31 @@ def bench(
     ),
 ) -> None:
     """Run benchmark."""
-    import ir_datasets
-    import ir_measures
-    import pandas as pd
-
-    from raglite._bench import (
-        IREvaluator,
-        LlamaIndexEvaluator,
-        OpenAIVectorStoreEvaluator,
-        RAGLiteEvaluator,
-    )
+    try:
+        import ir_datasets
+        import ir_measures
+        import pandas as pd
+        from rerankers import Reranker
+
+        from raglite._bench import (
+            IREvaluator,
+            LlamaIndexEvaluator,
+            OpenAIVectorStoreEvaluator,
+            RAGLiteEvaluator,
+        )
+    except ModuleNotFoundError as import_error:
+        error_message = "To use the `bench` command, please install the `bench` extra."
+        raise ModuleNotFoundError(error_message) from import_error
     # Initialise the benchmark.
     evaluator: IREvaluator
     measures = [ir_measures.parse_measure(measure)]
     index, results = [], []
-    # Evaluate RAGLite (single-vector) + DuckDB HNSW + text-embedding-3-large.
+    # Evaluate RAGLite (single-vector) + DuckDB + text-embedding-3-large.
     chunk_max_size = 2048
     config = RAGLiteConfig(
-        embedder="text-embedding-3-large",
+        embedder=(embedder := "text-embedding-3-large"),
+        reranker=None,
         chunk_max_size=chunk_max_size,
         vector_search_multivector=False,
         vector_search_query_adapter=False,
@@ -161,9 +167,10 @@
     )
     index.append("RAGLite (single-vector)")
     results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
-    # Evaluate RAGLite (multi-vector) + DuckDB HNSW + text-embedding-3-large.
+    # Evaluate RAGLite (multi-vector) + DuckDB + text-embedding-3-large.
     config = RAGLiteConfig(
-        embedder="text-embedding-3-large",
+        embedder=embedder,
+        reranker=None,
         chunk_max_size=chunk_max_size,
         vector_search_multivector=True,
         vector_search_query_adapter=False,
@@ -174,10 +181,11 @@
     )
     index.append("RAGLite (multi-vector)")
     results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
-    # Evaluate RAGLite (query adapter) + DuckDB HNSW + text-embedding-3-large.
+    # Evaluate RAGLite (multi-vector; query adapter) + DuckDB + text-embedding-3-large.
     config = RAGLiteConfig(
         llm=(llm := "gpt-4.1"),
-        embedder="text-embedding-3-large",
+        embedder=embedder,
+        reranker=None,
         chunk_max_size=chunk_max_size,
         vector_search_multivector=True,
         vector_search_query_adapter=True,
@@ -191,6 +199,29 @@
     )
     index.append("RAGLite (query adapter)")
     results.append(ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score()))
+    # Evaluate RAGLite (multi-vector; query adapter; reranker) + DuckDB + text-embedding-3-large.
+    if os.environ.get("CO_API_KEY"):
+        config = RAGLiteConfig(
+            llm=llm,
+            embedder=embedder,
+            reranker=Reranker(
+                "rerank-v3.5", model_type="cohere", api_key=os.environ["CO_API_KEY"], verbose=0
+            ),
+            chunk_max_size=chunk_max_size,
+            vector_search_multivector=True,
+            vector_search_query_adapter=True,
+        )
+        dataset = ir_datasets.load(dataset_name)
+        evaluator = RAGLiteEvaluator(
+            dataset,
+            insert_variant=f"multi-vector-{chunk_max_size // 4}t",
+            search_variant=f"query-adapter-{llm}-cohere-rerank-3.5",
+            config=config,
+        )
+        index.append("RAGLite (Cohere Rerank 3.5)")
+        results.append(
+            ir_measures.calc_aggregate(measures, dataset.qrels_iter(), evaluator.score())
+        )
     # Evaluate LlamaIndex + FAISS HNSW + text-embedding-3-large.
     dataset = ir_datasets.load(dataset_name)
     evaluator = LlamaIndexEvaluator(dataset)
diff --git a/src/raglite/_insert.py b/src/raglite/_insert.py
index bd63c29..206c488 100644
--- a/src/raglite/_insert.py
+++ b/src/raglite/_insert.py
@@ -170,6 +170,7 @@ def insert_documents(  # noqa: C901
                 session.flush()  # Flush changes to the database.
                 session.expunge_all()  # Release memory of flushed changes.
                 num_unflushed_embeddings = 0
+            pbar.set_postfix({"id": document_record.id})
             pbar.update()
         session.commit()
     if engine.dialect.name == "duckdb":
diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py
index 47957cc..3e699df 100644
--- a/src/raglite/_query_adapter.py
+++ b/src/raglite/_query_adapter.py
@@ -1,11 +1,16 @@
 """Compute and update an optimal query adapter."""

-# ruff: noqa: N806
+# ruff: noqa: N803, N806, PLC2401, PLR0913, RUF003

+import contextlib
+from collections.abc import Iterator
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import replace
+from functools import partial

 import numpy as np
-from scipy.optimize import lsq_linear
+from scipy.optimize import lsq_linear, minimize
+from scipy.special import expit
 from sqlalchemy import text
 from sqlalchemy.orm.attributes import flag_modified
 from sqlmodel import Session, col, select
@@ -18,31 +23,158 @@
 from raglite._typing import FloatMatrix, FloatVector


+def _extract_triplets(
+    eval_: Eval, *, optimize_top_k: int, config: RAGLiteConfig
+) -> tuple[FloatVector, FloatMatrix, FloatMatrix]:
+    with Session(create_database_engine(config)) as session:
+        # Embed the question.
+        q = embed_strings([eval_.question], config=config)[0]
+        # Retrieve chunks that would be used to answer the question.
+        chunk_ids, _ = vector_search(q, num_results=optimize_top_k, config=config)
+        retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
+        retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
+        # Skip this eval if it doesn't contain both relevant and irrelevant chunks.
+        is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks])
+        if not np.any(is_relevant) or not np.any(~is_relevant):
+            error_message = "Eval does not contain both relevant and irrelevant chunks."
+            raise ValueError(error_message)
+        # Extract the positive and negative chunk embeddings.
+        P = np.vstack(
+            [
+                chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
+                for chunk in np.array(retrieved_chunks)[is_relevant]
+            ]
+        )
+        N = np.vstack(
+            [
+                chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
+                for chunk in np.array(retrieved_chunks)[~is_relevant]
+            ]
+        )
+    return q, P, N
+
+
 def _optimize_query_target(
     q: FloatVector,
-    P: FloatMatrix,  # noqa: N803,
-    N: FloatMatrix,  # noqa: N803,
+    P: FloatMatrix,
+    N: FloatMatrix,
     *,
-    α: float = 0.05,  # noqa: PLC2401
+    α: float = 0.05,
 ) -> FloatVector:
     # Convert to double precision for the optimizer.
-    q_dtype = q.dtype
     q, P, N = q.astype(np.float64), P.astype(np.float64), N.astype(np.float64)
-    # Construct the constraint matrix D := P - (1 + α) * N.  # noqa: RUF003
+    # Construct the constraint matrix D := P - (1 + α) * N.
     D = np.reshape(P[:, np.newaxis, :] - (1.0 + α) * N[np.newaxis, :, :], (-1, P.shape[1]))
     # Solve the dual problem min_μ ½ ‖q + Dᵀ μ‖² s.t. μ ≥ 0.
     A, b = D.T, -q
-    μ_star = lsq_linear(A, b, bounds=(0.0, np.inf), tol=np.finfo(A.dtype).eps).x  # noqa: PLC2401
+    result = lsq_linear(A, b, bounds=(0.0, np.inf), tol=np.finfo(A.dtype).eps)
+    μ_star = result.x
     # Recover the primal solution q* = q + Dᵀ μ*.
-    q_star: FloatVector = (q + D.T @ μ_star).astype(q_dtype)
+    q_star: FloatVector = q + D.T @ μ_star
     return q_star


-def update_query_adapter(
+def _compute_query_adapter(
+    w: FloatVector, Q: FloatMatrix, T: FloatMatrix, PT: FloatMatrix, config: RAGLiteConfig
+) -> FloatMatrix:
+    # Compute the weighted query embeddings.
+    n, d = Q.shape
+    T_prime = Q + (w[:, np.newaxis] ** 2) * (T - Q)
+    M = (1.0 / n) * T_prime.T @ Q + PT
+    # Compute the optimal constrained query adapter A* from M, given the distance metric.
+    A_star: FloatMatrix
+    if config.vector_search_distance_metric == "dot":
+        # Use the relaxed Procrustes solution.
+        A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
+    elif config.vector_search_distance_metric == "cosine":
+        # Use the orthogonal Procrustes solution.
+        U, _, VT = np.linalg.svd(M, full_matrices=False)
+        A_star = U @ VT
+    else:
+        error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
+        raise ValueError(error_message)
+    return A_star
+
+
+def _compute_query_adapter_grad(
+    w: FloatVector,
+    Q: FloatMatrix,
+    T: FloatMatrix,
+    PT: FloatMatrix,
+    config: RAGLiteConfig,
+) -> Iterator[FloatMatrix]:
+    n, d = Q.shape
+    diff = T - Q
+    T_prime = Q + (w[:, np.newaxis] ** 2) * diff
+    M = (1.0 / n) * T_prime.T @ Q + PT
+    if config.vector_search_distance_metric == "dot":
+        fro = np.linalg.norm(M, ord="fro")
+        if fro <= np.sqrt(np.finfo(M.dtype).eps):
+            for _ in range(n):
+                yield np.zeros((d, d), dtype=M.dtype)
+            return
+        for i in range(n):
+            outer = 2.0 * w[i] * np.outer(diff[i], Q[i]) / n
+            inner = np.sum(outer * M)
+            yield (outer - (inner / fro**2) * M) / fro * np.sqrt(d)
+    elif config.vector_search_distance_metric == "cosine":
+        U, σ, VT = np.linalg.svd(M, full_matrices=False)
+        X = diff @ U
+        Y = Q @ VT.T
+        denom = σ[:, np.newaxis] + σ[np.newaxis, :]
+        denom[denom <= (np.finfo(M.dtype).eps ** (1 / 4)) * σ[0]] = np.inf
+        for i in range(n):
+            outer = 2.0 * w[i] * np.outer(X[i], Y[i])
+            core = (outer - outer.T) / denom / n
+            yield U @ core @ VT
+
+
+def _objective_function(
+    w: FloatVector,
+    Q_train: FloatMatrix,
+    T_train: FloatMatrix,
+    PT_train: FloatMatrix,
+    Q: FloatMatrix,
+    D: FloatMatrix,
+    config: RAGLiteConfig,
+) -> float:
+    A = _compute_query_adapter(w, Q_train, T_train, PT_train, config)
+    gaps = np.sum(D * (Q @ A.T), axis=1)
+    factor = 1.28 / (0.05 / 0.75)  # TODO: Use gap_margin here instead of 0.05.
+    neg_filter = expit(-factor * gaps)
+    mean_neg_gap = np.mean(gaps * neg_filter)
+    cost: float = -mean_neg_gap
+    return cost
+
+
+def _gradient(
+    w: FloatVector,
+    Q_train: FloatMatrix,
+    T_train: FloatMatrix,
+    PT_train: FloatMatrix,
+    Q: FloatMatrix,
+    D: FloatMatrix,
+    config: RAGLiteConfig,
+) -> FloatVector:
+    dAdw = _compute_query_adapter_grad(w, Q_train, T_train, PT_train, config)
+    A = _compute_query_adapter(w, Q_train, T_train, PT_train, config)
+    gaps = np.sum(D * (Q @ A.T), axis=1)
+    factor = 1.28 / (0.05 / 0.75)  # TODO: Use gap_margin here instead of 0.05.
+    neg_filter = expit(-factor * gaps)
+    weights = neg_filter - factor * gaps * neg_filter * (1.0 - neg_filter)
+    S = D.T @ (weights[:, np.newaxis] * Q)
+    grad = np.empty_like(w)
+    for i, dAdwi in enumerate(dAdw):
+        grad[i] = -np.sum(dAdwi * S) / len(Q)
+    return grad
+
+
+def update_query_adapter(  # noqa: PLR0915
     *,
     max_evals: int = 4096,
     optimize_top_k: int = 40,
-    optimize_gap: float = 0.05,
+    gap_margin: float = 0.05,
+    gap_max_iter: int = 40,
+    gap_tol: float = 1e-4,
+    gap_validation_size: float = 0.0,
+    max_workers: int | None = None,
     config: RAGLiteConfig | None = None,
 ) -> FloatMatrix:
     """Compute an optimal query adapter and update the database with it.
@@ -110,6 +242,15 @@ def update_query_adapter(
         μ* := argmin ½ ‖qᵢ + Dᵢᵀ μ‖² s.t.
               μ ≥ 0

+    Finally, we weight the optimal target vectors with a set of weights w* that are optimised to
+    maximise the gap between the positive and negative chunks in a validation set of evals:
+
+        w* := argmax ΣᵢΣₘₙ (pₘ⁽ᵛᵃˡ⁾ - nₙ⁽ᵛᵃˡ⁾)ᵀ A(w) qᵢ⁽ᵛᵃˡ⁾
+
+    where A(w) is the weighted query adapter computed from M(w) := (1 / n) T′ᵀ Q + P with
+    T′ := Q + diag(w)² (T - Q), T is the matrix of optimal target vectors t*, and P is the
+    passthrough matrix 𝕀 - Qᵀ (Q Qᵀ)⁺ Q that lets query vectors outside of the row space of Q
+    through unaffected.
+
     Parameters
     ----------
     max_evals
@@ -117,10 +258,18 @@
         rank-one update of the query adapter A.
     optimize_top_k
         The number of search results per eval to optimize.
-    optimize_gap
-        The strength of the query adapter, expressed as a nonnegative number. Should be large enough
-        to correct incorrectly ranked results, but small enough to not affect correctly ranked
-        results.
+    gap_margin
+        The margin α to use when computing the optimal query target t* for each query embedding qᵢ.
+    gap_max_iter
+        The maximum number of iterations to use to optimize the query target weights w*.
+    gap_tol
+        The tolerance to use when optimizing the query target weights w*.
+    gap_validation_size
+        The fraction of evals to use for optimizing the query target weights w*. The remaining evals
+        are used for training the unweighted query adapter.
+    max_workers
+        The maximum number of worker threads to use for triplet extraction and query target
+        optimization.
     config
         The RAGLite config to use to construct and store the query adapter.

@@ -130,6 +279,8 @@
         If no documents have been inserted into the database yet.
     ValueError
         If no evals have been inserted into the database yet.
+    ValueError
+        If no evals are usable for optimization.
     ValueError
         If the `config.vector_search_distance_metric` is not supported.

@@ -150,62 +301,109 @@
     if len(evals) == 0:
         error_message = "First run `insert_evals()` to generate evals."
         raise ValueError(error_message)
-    # Construct the query and target matrices.
-    Q = np.zeros((0, len(chunk_embedding.embedding)))
-    T = np.zeros_like(Q)
-    for eval_ in tqdm(
-        evals, desc="Optimizing evals", unit="eval", dynamic_ncols=True, leave=False
+    # Collect triplets (qᵢ, Pᵢ × Nᵢ) for each eval.
+    with (
+        ThreadPoolExecutor(max_workers=max_workers) as executor,
+        tqdm(
+            total=len(evals), desc="Extracting triplets", unit="eval", dynamic_ncols=True
+        ) as pbar,
     ):
-        # Embed the question.
-        q = embed_strings([eval_.question], config=config)[0]
-        # Retrieve chunks that would be used to answer the question.
-        chunk_ids, _ = vector_search(
-            q, num_results=optimize_top_k, config=config_no_query_adapter
-        )
-        retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
-        retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
-        # Skip this eval if it doesn't contain both relevant and irrelevant chunks.
-        is_relevant = np.array([chunk.id in eval_.chunk_ids for chunk in retrieved_chunks])
-        if not np.any(is_relevant) or not np.any(~is_relevant):
-            continue
-        # Extract the positive and negative chunk embeddings.
-        P = np.vstack(
-            [
-                chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
-                for chunk in np.array(retrieved_chunks)[is_relevant]
-            ]
-        )
-        N = np.vstack(
-            [
-                chunk.embedding_matrix[[np.argmax(chunk.embedding_matrix @ q)]]
-                for chunk in np.array(retrieved_chunks)[~is_relevant]
-            ]
+        q: list[FloatVector] = []
+        P: list[FloatMatrix] = []
+        N: list[FloatMatrix] = []
+        futures = [
+            executor.submit(
+                partial(
+                    _extract_triplets,
+                    optimize_top_k=optimize_top_k,
+                    config=config_no_query_adapter,
+                ),
+                eval_,
+            )
+            for eval_ in evals
+        ]
+        for future in as_completed(futures):
+            with contextlib.suppress(Exception):
+                pbar.update()
+                qi, Pi, Ni = future.result()
+                q.append(qi)
+                P.append(Pi)
+                N.append(Ni)
+    # Exit if there are no triplets to optimise.
+    if len(q) == 0:
+        error_message = "No evals found with incorrectly ranked results to optimize."
+        raise ValueError(error_message)
+    # Split into train and validation sets.
+    val_size = round(gap_validation_size * len(q))
+    train_size = len(q) - val_size
+    # Compute the optimal query targets T.
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        T_train = np.vstack(
+            list(
+                tqdm(
+                    executor.map(
+                        partial(_optimize_query_target, α=gap_margin),
+                        q[:train_size],
+                        P[:train_size],
+                        N[:train_size],
+                    ),
+                    total=train_size,
+                    desc="Optimizing query targets",
+                    unit="query",
+                    dynamic_ncols=True,
+                )
+            )
         )
-        # Compute the optimal target vector t for this query embedding q.
-        t = _optimize_query_target(q, P, N, α=optimize_gap)
-        Q = np.vstack([Q, q[np.newaxis, :]])
-        T = np.vstack([T, t[np.newaxis, :]])
     # Normalise the rows of Q and T.
+    Q = np.vstack(q).astype(np.float64)
     Q /= np.linalg.norm(Q, axis=1, keepdims=True)
     if config.vector_search_distance_metric == "cosine":
-        T /= np.linalg.norm(T, axis=1, keepdims=True)
-    # Compute the optimal unconstrained query adapter M.
-    n, d = Q.shape
-    M = (1 / n) * T.T @ Q
+        T_train /= np.linalg.norm(T_train, axis=1, keepdims=True)
+    # Initialise the query target weights w*.
+    w_star = np.ones(train_size)
+    # Compute a passthrough matrix.
+    Q_train = Q[:train_size]
+    n, d = Q_train.shape
     if n < d or np.linalg.matrix_rank(Q) < d:
-        M += np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q
-    # Compute the optimal constrained query adapter A* from M, given the distance metric.
-    A_star: FloatMatrix
-    if config.vector_search_distance_metric == "dot":
-        # Use the relaxed Procrustes solution.
-        A_star = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
-    elif config.vector_search_distance_metric == "cosine":
-        # Use the orthogonal Procrustes solution.
-        U, _, VT = np.linalg.svd(M, full_matrices=False)
-        A_star = U @ VT
+        PT_train = np.eye(d) - Q_train.T @ np.linalg.pinv(Q_train @ Q_train.T) @ Q_train
     else:
-        error_message = f"Unsupported metric: {config.vector_search_distance_metric}"
-        raise ValueError(error_message)
+        PT_train = np.zeros((d, d))
+    # Construct the delta matrix D := [pₘᵀ - nₙᵀ]ₘₙ with one row per positive-negative pair.
+    Q_full = []
+    D_full = []
+    for qi, Pi, Ni in zip(Q, P, N, strict=True):
+        D = np.reshape(Pi[:, np.newaxis, :] - Ni[np.newaxis, :, :], (-1, d))
+        D_full.append(D)
+        Q_full.append(np.repeat(qi[np.newaxis, :], D.shape[0], axis=0))
+    # Optimise the query target weights w*.
+    with tqdm(
+        total=gap_max_iter,
+        desc="Optimizing query target weights",
+        unit="iter",
+        dynamic_ncols=True,
+    ) as pbar:
+        result = minimize(
+            _objective_function,
+            jac=_gradient,
+            x0=np.ones(train_size),
+            args=(
+                Q_train,
+                T_train,
+                PT_train,
+                np.vstack(Q_full),
+                np.vstack(D_full),
+                config_no_query_adapter,
+            ),
+            method="L-BFGS-B",
+            callback=lambda intermediate_result: (
+                pbar.update(),
+                pbar.set_postfix({"gap": -intermediate_result.fun}),
+            ),
+            options={"ftol": gap_tol, "maxiter": gap_max_iter, "maxcor": 10, "maxls": 10},
+        )
+        w_star = result.x
+    # Compute the optimal query adapter.
+    A_star = _compute_query_adapter(w_star, Q_train, T_train, PT_train, config)
     # Store the optimal query adapter in the database.
     index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
     index_metadata.metadata_["query_adapter"] = A_star
diff --git a/tests/conftest.py b/tests/conftest.py
index 7b92894..dea77b9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,7 +3,7 @@
 import os
 import socket
 import tempfile
-from collections.abc import Generator
+from collections.abc import Iterator
 from pathlib import Path

 import pytest
@@ -48,7 +48,7 @@ def pytest_sessionstart(session: pytest.Session) -> None:


 @pytest.fixture(scope="session")
-def duckdb_url() -> Generator[str, None, None]:
+def duckdb_url() -> Iterator[str]:
     """Create a temporary DuckDB database file and return the database URL."""
     with tempfile.TemporaryDirectory() as temp_dir:
         db_file = Path(temp_dir) / "raglite_test.db"
diff --git a/tests/test_query_adapter.py b/tests/test_query_adapter.py
index 7b792e5..4021bb7 100644
--- a/tests/test_query_adapter.py
+++ b/tests/test_query_adapter.py
@@ -1,12 +1,15 @@
 """Test RAGLite's query adapter."""

 from dataclasses import replace
+from typing import Literal

 import numpy as np
 import pytest
+from scipy.optimize import check_grad

 from raglite import RAGLiteConfig, insert_evals, update_query_adapter, vector_search
 from raglite._database import IndexMetadata
+from raglite._query_adapter import _gradient, _objective_function


 @pytest.mark.slow
@@ -19,7 +22,7 @@ def test_query_adapter(raglite_test_config: RAGLiteConfig) -> None:
     Q = IndexMetadata.get("default", config=config_without_query_adapter).get("query_adapter")  # noqa: N806
     assert Q is None
     # Insert evals.
-    insert_evals(num_evals=2, max_chunks_per_eval=10, config=config_with_query_adapter)
+    insert_evals(num_evals=10, max_chunks_per_eval=10, config=config_with_query_adapter)
     # Update the query adapter.
     A = update_query_adapter(config=config_with_query_adapter)  # noqa: N806
     assert isinstance(A, np.ndarray)
@@ -38,3 +41,46 @@ def test_query_adapter(raglite_test_config: RAGLiteConfig) -> None:
     _, scores_qa = vector_search(query, config=config_with_query_adapter)
     _, scores_no_qa = vector_search(query, config=config_without_query_adapter)
     assert scores_qa != scores_no_qa
+
+
+@pytest.mark.parametrize(
+    "metric",
+    [
+        pytest.param("cosine", id="metric=cosine"),
+        pytest.param("dot", id="metric=dot"),
+    ],
+)
+@pytest.mark.parametrize(
+    "embedding_dim",
+    [
+        pytest.param(16, id="embedding_dim=16"),
+        pytest.param(128, id="embedding_dim=128"),
+    ],
+)
+@pytest.mark.parametrize(
+    "num_evals",
+    [
+        pytest.param(16, id="num_evals=16"),
+        pytest.param(128, id="num_evals=128"),
+    ],
+)
+def test_query_adapter_grad(
+    num_evals: int, embedding_dim: int, metric: Literal["cosine", "dot"]
+) -> None:
+    """Verify that the query adapter gradient is correct."""
+    # Generate test data.
+    num_val = round(0.2 * num_evals)
+    num_train = num_evals - num_val
+    rng = np.random.default_rng(42)
+    w0 = np.abs(rng.normal(size=num_train))
+    Q_train = rng.normal(size=(num_train, embedding_dim))  # noqa: N806
+    T_train = rng.normal(size=(num_train, embedding_dim))  # noqa: N806
+    PT_train = rng.normal(size=(embedding_dim, embedding_dim))  # noqa: N806
+    Q_val = rng.normal(size=(num_val, embedding_dim))  # noqa: N806
+    D_val = rng.normal(size=(num_val, embedding_dim))  # noqa: N806
+    config = RAGLiteConfig(vector_search_distance_metric=metric)
+    # Check the gradient.
+    l2_residual = check_grad(
+        _objective_function, _gradient, w0, Q_train, T_train, PT_train, Q_val, D_val, config
+    )
+    assert (l2_residual / len(w0)) <= 100 * np.sqrt(np.finfo(w0.dtype).eps)
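
Reviewer note, not part of the diff: the Procrustes construction in `_compute_query_adapter` is easy to poke at without a database. Below is a minimal numpy-only sketch of the same computation on synthetic data; the sizes, random matrices, and uniform weights are made up for illustration, and only the two metrics supported by this diff are shown.

# Standalone sketch of _compute_query_adapter's math (synthetic data; n < d so
# that the passthrough term is nontrivial).
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 8  # number of training triplets, embedding dimension

Q = rng.normal(size=(n, d))  # query embeddings qᵢ as rows
Q /= np.linalg.norm(Q, axis=1, keepdims=True)
T = rng.normal(size=(n, d))  # optimal query targets t*ᵢ as rows
T /= np.linalg.norm(T, axis=1, keepdims=True)
w = np.ones(n)  # per-query weights, optimised with L-BFGS-B in the diff

# Passthrough P := 𝕀 - Qᵀ (Q Qᵀ)⁺ Q lets queries outside the row space of Q
# through unaffected.
PT = np.eye(d) - Q.T @ np.linalg.pinv(Q @ Q.T) @ Q

# Weighted targets T′ := Q + diag(w²) (T - Q), then M := (1/n) T′ᵀ Q + P.
T_prime = Q + (w[:, np.newaxis] ** 2) * (T - Q)
M = (1.0 / n) * T_prime.T @ Q + PT

# "dot" metric: relaxed Procrustes solution (Frobenius-normalised M).
A_dot = M / np.linalg.norm(M, ord="fro") * np.sqrt(d)
# "cosine" metric: orthogonal Procrustes solution (nearest orthogonal matrix to M).
U, _, VT = np.linalg.svd(M, full_matrices=False)
A_cos = U @ VT
assert np.allclose(A_cos @ A_cos.T, np.eye(d), atol=1e-8)  # A_cos is orthogonal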
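
And a hedged usage sketch of the new `update_query_adapter` knobs. The keyword values shown are this diff's defaults except `gap_validation_size` and `max_workers`, which are illustrative; `my_config` is a hypothetical placeholder and assumes documents have already been inserted.

# Hypothetical usage of the new signature (not part of the diff).
from raglite import RAGLiteConfig, insert_evals, update_query_adapter

my_config = RAGLiteConfig(vector_search_query_adapter=True)  # placeholder config
insert_evals(num_evals=1024, config=my_config)  # evals drive the optimisation
A = update_query_adapter(
    optimize_top_k=40,  # search results per eval used to build (q, P, N) triplets
    gap_margin=0.05,  # margin α when computing each optimal target t*
    gap_max_iter=40,  # L-BFGS-B iteration budget for the weights w*
    gap_tol=1e-4,  # ftol for the weight optimisation
    gap_validation_size=0.2,  # fraction of evals used to optimise the weights w*
    max_workers=8,  # threads for triplet extraction and target optimisation
    config=my_config,
)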