
Commit 6b49ced

fix: improve structured output extraction and query adapter updates (#34)
1 parent 714c68f commit 6b49ced

4 files changed: +96 -13 lines

src/raglite/_extract.py

Lines changed: 31 additions & 6 deletions
@@ -2,7 +2,8 @@
 
 from typing import Any, TypeVar
 
-from litellm import completion
+import litellm
+from litellm import completion, get_supported_openai_params  # type: ignore[attr-defined]
 from pydantic import BaseModel, ValidationError
 
 from raglite._config import RAGLiteConfig
@@ -33,17 +34,39 @@ class MyNameResponse(BaseModel):
     # Load the default config if not provided.
     config = config or RAGLiteConfig()
     # Update the system prompt with the JSON schema of the return type to help the LLM.
-    system_prompt = (
-        return_type.system_prompt.strip() + "\n",  # type: ignore[attr-defined]
-        "Format your response according to this JSON schema:\n",
-        return_type.model_json_schema(),
+    system_prompt = "\n".join(
+        (
+            return_type.system_prompt.strip(),  # type: ignore[attr-defined]
+            "Format your response according to this JSON schema:",
+            str(return_type.model_json_schema()),
+        )
+    )
+    # Constrain the reponse format to the JSON schema if it's supported by the LLM [1].
+    # [1] https://docs.litellm.ai/docs/completion/json_mode
+    # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
+    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
+    response_format: dict[str, Any] | None = (
+        {
+            "type": "json_schema",
+            "json_schema": {
+                "name": return_type.__name__,
+                "description": return_type.__doc__ or "",
+                "schema": return_type.model_json_schema(),
+            },
+        }
+        if "response_format"
+        in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [])
+        else None
     )
     # Concatenate the user prompt if it is a list of strings.
     if isinstance(user_prompt, list):
         user_prompt = "\n\n".join(
             f'<context index="{i}">\n{chunk.strip()}\n</context>'
             for i, chunk in enumerate(user_prompt)
         )
+    # Enable JSON schema validation.
+    enable_json_schema_validation = litellm.enable_json_schema_validation
+    litellm.enable_json_schema_validation = True
     # Extract structured data from the unstructured input.
     for _ in range(config.llm_max_tries):
         response = completion(
@@ -52,7 +75,7 @@ class MyNameResponse(BaseModel):
                 {"role": "system", "content": system_prompt},
                 {"role": "user", "content": user_prompt},
             ],
-            response_format={"type": "json_object", "schema": return_type.model_json_schema()},
+            response_format=response_format,
             **kwargs,
         )
         try:
@@ -66,4 +89,6 @@ class MyNameResponse(BaseModel):
     else:
         error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message) from last_exception
+    # Restore the previous JSON schema validation setting.
+    litellm.enable_json_schema_validation = enable_json_schema_validation
     return instance
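
For reference, a minimal sketch (not part of this diff) of the response_format payload that the updated code builds when litellm reports that the target model supports a response_format parameter. MyNameResponse is a hypothetical model, named after the example in the function's docstring:

# Hedged illustration only: mirrors the gating and payload construction in the diff above.
from pydantic import BaseModel
from litellm import get_supported_openai_params  # type: ignore[attr-defined]

class MyNameResponse(BaseModel):
    """Extract my name from the input."""

    my_name: str

supported = get_supported_openai_params(model="gpt-4o-mini") or []
response_format = (
    {
        "type": "json_schema",
        "json_schema": {
            "name": MyNameResponse.__name__,
            "description": MyNameResponse.__doc__ or "",
            "schema": MyNameResponse.model_json_schema(),
        },
    }
    if "response_format" in supported
    else None
)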

src/raglite/_litellm.py

Lines changed: 20 additions & 6 deletions
@@ -129,6 +129,24 @@ def llm(model: str, **kwargs: Any) -> Llama:
         )
         return llm
 
+    def _translate_openai_params(self, optional_params: dict[str, Any]) -> dict[str, Any]:
+        # Filter out unsupported OpenAI parameters.
+        llama_cpp_python_params = {
+            k: v for k, v in optional_params.items() if k in self.supported_openai_params
+        }
+        # Translate OpenAI's response_format [1] to llama-cpp-python's response_format [2].
+        # [1] https://platform.openai.com/docs/guides/structured-outputs
+        # [2] https://github.com/abetlen/llama-cpp-python#json-schema-mode
+        if (
+            "response_format" in llama_cpp_python_params
+            and "json_schema" in llama_cpp_python_params["response_format"]
+        ):
+            llama_cpp_python_params["response_format"] = {
+                "type": "json_object",
+                "schema": llama_cpp_python_params["response_format"]["json_schema"]["schema"],
+            }
+        return llama_cpp_python_params
+
     def completion(  # noqa: PLR0913
         self,
         model: str,
@@ -149,9 +167,7 @@ def completion(  # noqa: PLR0913
         client: HTTPHandler | None = None,
     ) -> ModelResponse:
         llm = self.llm(model)
-        llama_cpp_python_params = {
-            k: v for k, v in optional_params.items() if k in self.supported_openai_params
-        }
+        llama_cpp_python_params = self._translate_openai_params(optional_params)
         response = cast(
             CreateChatCompletionResponse,
             llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
@@ -184,9 +200,7 @@ def streaming(  # noqa: PLR0913
         client: HTTPHandler | None = None,
     ) -> Iterator[GenericStreamingChunk]:
         llm = self.llm(model)
-        llama_cpp_python_params = {
-            k: v for k, v in optional_params.items() if k in self.supported_openai_params
-        }
+        llama_cpp_python_params = self._translate_openai_params(optional_params)
         stream = cast(
             Iterator[CreateChatCompletionStreamResponse],
             llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
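
To make the shape change explicit, here is a standalone sketch (hypothetical helper and input, not RAGLite code) of the translation that _translate_openai_params performs: OpenAI's nested json_schema object becomes llama-cpp-python's flat {"type": "json_object", "schema": ...} form:

from typing import Any

def translate_response_format(openai_format: dict[str, Any]) -> dict[str, Any]:
    # Unwrap OpenAI's {"type": "json_schema", "json_schema": {"schema": ...}} into
    # llama-cpp-python's {"type": "json_object", "schema": ...}; pass anything else through.
    if "json_schema" in openai_format:
        return {"type": "json_object", "schema": openai_format["json_schema"]["schema"]}
    return openai_format

example = {
    "type": "json_schema",
    "json_schema": {"name": "LoginResponse", "schema": {"type": "object"}},
}
assert translate_response_format(example) == {
    "type": "json_object",
    "schema": {"type": "object"},
}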

src/raglite/_query_adapter.py

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,7 @@
 """Compute and update an optimal query adapter."""
 
 import numpy as np
+from sqlalchemy.orm.attributes import flag_modified
 from sqlmodel import Session, col, select
 from tqdm.auto import tqdm
 
@@ -157,6 +158,7 @@ def update_query_adapter(  # noqa: PLR0915, C901
            raise ValueError(error_message)
        # Store the optimal query adapter in the database.
        index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
-       index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
+       index_metadata.metadata_["query_adapter"] = A_star
+       flag_modified(index_metadata, "metadata_")
        session.add(index_metadata)
        session.commit()
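
For context on the flag_modified call, a minimal sketch under assumptions (a hypothetical Item model, not RAGLite's IndexMetadata): SQLAlchemy does not detect in-place mutations of a plain JSON column, so the attribute has to be flagged dirty before the commit, which is what the change above does:

from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import JSON, Column, Field, Session, SQLModel, create_engine

class Item(SQLModel, table=True):
    id: str = Field(primary_key=True)
    metadata_: dict = Field(default_factory=dict, sa_column=Column(JSON))

engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    item = Item(id="default", metadata_={})
    session.add(item)
    session.commit()
    # In-place mutation of a JSON column is invisible to SQLAlchemy's change tracking...
    item.metadata_["query_adapter"] = [[1.0, 0.0], [0.0, 1.0]]
    # ...so mark the attribute as modified before committing.
    flag_modified(item, "metadata_")
    session.add(item)
    session.commit()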

tests/test_extract.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+"""Test RAGLite's structured output extraction."""
+
+from typing import ClassVar
+
+import pytest
+from pydantic import BaseModel, Field
+
+from raglite import RAGLiteConfig
+from raglite._extract import extract_with_llm
+
+
+@pytest.fixture(
+    params=[
+        pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"),
+        pytest.param("gpt-4o-mini", id="openai"),
+    ],
+)
+def llm(
+    request: pytest.FixtureRequest,
+) -> str:
+    """Get an LLM to test RAGLite with."""
+    llm: str = request.param
+    return llm
+
+
+def test_extract(llm: str) -> None:
+    """Test extracting structured data."""
+    # Set the LLM.
+    config = RAGLiteConfig(llm=llm)
+
+    # Extract structured data.
+    class LoginResponse(BaseModel):
+        username: str = Field(..., description="The username.")
+        password: str = Field(..., description="The password.")
+        system_prompt: ClassVar[str] = "Extract the username and password from the input."
+
+    username, password = "cypher", "steak"
+    login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config)
+    # Validate the response.
+    assert isinstance(login_response, LoginResponse)
+    assert login_response.username == username
+    assert login_response.password == password
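
Note that the system_prompt ClassVar on LoginResponse is what the updated extract_with_llm reads via return_type.system_prompt in the _extract.py diff above, and the parametrized llm fixture exercises both code paths touched by this commit: the local llama-cpp-python default (RAGLiteConfig().llm) and an OpenAI model (gpt-4o-mini).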
