Skip to content

Commit f6023f5

Browse files
authored
fix: support embedding with LiteLLM for Ragas (#56)
1 parent abb4d1b commit f6023f5

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

src/raglite/_eval.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,16 +210,34 @@ def evaluate(
210210
try:
211211
from datasets import Dataset
212212
from langchain_community.chat_models import ChatLiteLLM
213-
from langchain_community.embeddings import LlamaCppEmbeddings
214213
from langchain_community.llms import LlamaCpp
215214
from ragas import RunConfig
216215
from ragas import evaluate as ragas_evaluate
216+
from ragas.embeddings import BaseRagasEmbeddings
217217

218+
from raglite._config import RAGLiteConfig
219+
from raglite._embed import embed_sentences
218220
from raglite._litellm import LlamaCppPythonLLM
219221
except ImportError as import_error:
220222
error_message = "To use the `evaluate` function, please install the `ragas` extra."
221223
raise ImportError(error_message) from import_error
222224

225+
class RAGLiteRagasEmbeddings(BaseRagasEmbeddings):
    """Adapter that exposes RAGLite's embedder through the Ragas embeddings interface."""

    def __init__(self, config: RAGLiteConfig | None = None):
        # Fall back to the default RAGLite configuration when none is supplied.
        self.config = config or RAGLiteConfig()

    def embed_query(self, text: str) -> list[float]:
        # A single query is embedded as a one-element batch; return its only row.
        return embed_sentences([text], config=self.config)[0].tolist()  # type: ignore[no-any-return]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # Embed the whole batch in one call and convert the result to nested lists.
        return embed_sentences(texts, config=self.config).tolist()  # type: ignore[no-any-return]
223241
# Create a set of answered evals if not provided.
224242
config = config or RAGLiteConfig()
225243
answered_evals_df = (
@@ -239,23 +257,12 @@ def evaluate(
239257
)
240258
else:
241259
lc_llm = ChatLiteLLM(model=config.llm) # type: ignore[call-arg]
242-
# Load the embedder.
243-
if not config.embedder.startswith("llama-cpp-python"):
244-
error_message = "Currently, only `llama-cpp-python` embedders are supported."
245-
raise NotImplementedError(error_message)
246-
embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True)
247-
lc_embedder = LlamaCppEmbeddings( # type: ignore[call-arg]
248-
model_path=embedder.model_path,
249-
n_batch=embedder.n_batch,
250-
n_ctx=embedder.n_ctx(),
251-
n_gpu_layers=-1,
252-
verbose=embedder.verbose,
253-
)
260+
embedder = RAGLiteRagasEmbeddings(config=config)
254261
# Evaluate the answered evals with Ragas.
255262
evaluation_df = ragas_evaluate(
256263
dataset=Dataset.from_pandas(answered_evals_df),
257264
llm=lc_llm,
258-
embeddings=lc_embedder,
265+
embeddings=embedder,
259266
run_config=RunConfig(max_workers=1),
260267
).to_pandas()
261268
return evaluation_df

0 commit comments

Comments
 (0)