
Commit 9e48b20

fix: add and enable OpenAI strict mode (#55)
1 parent f6023f5 commit 9e48b20

3 files changed (+43 additions, −32 deletions)
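Context for the change: OpenAI's strict mode for structured outputs only accepts JSON schemas in which every object sets additionalProperties: false, and it supports only a subset of JSON schema keywords. Pydantic emits that flag when a model forbids extra attributes, which is what the three model_config additions below do. A minimal standalone sketch (not part of this commit) of the schema difference:

from pydantic import BaseModel, ConfigDict, Field

class Strict(BaseModel):
    """A model whose JSON schema is compatible with OpenAI's strict mode."""

    model_config = ConfigDict(extra="forbid")  # Emits "additionalProperties": false.
    question: str = Field(..., description="A question.")

class Lax(BaseModel):
    """A model whose JSON schema omits additionalProperties."""

    question: str = Field(..., description="A question.")

print(Strict.model_json_schema().get("additionalProperties"))  # False
print(Lax.model_json_schema().get("additionalProperties"))     # None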

src/raglite/_eval.py

Lines changed: 14 additions & 7 deletions
@@ -5,7 +5,7 @@
 import numpy as np
 import pandas as pd
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from sqlmodel import Session, func, select
 from tqdm.auto import tqdm, trange

@@ -25,10 +25,11 @@ def insert_evals(  # noqa: C901
     class QuestionResponse(BaseModel):
         """A specific question about the content of a set of document contexts."""

+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         question: str = Field(
-            ...,
-            description="A specific question about the content of a set of document contexts.",
-            min_length=1,
+            ..., description="A specific question about the content of a set of document contexts."
         )
         system_prompt: ClassVar[str] = """
 You are given a set of contexts extracted from a document.

@@ -85,7 +86,7 @@ def validate_question(cls, value: str) -> str:
         # Extract a question from the seed chunk's related chunks.
         try:
             question_response = extract_with_llm(
-                QuestionResponse, related_chunks, config=config
+                QuestionResponse, related_chunks, strict=True, config=config
             )
         except ValueError:
             continue

@@ -101,6 +102,9 @@ def validate_question(cls, value: str) -> str:
     class ContextEvalResponse(BaseModel):
         """Indicate whether the provided context can be used to answer a given question."""

+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         hit: bool = Field(
             ...,
             description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",

@@ -118,7 +122,7 @@ class ContextEvalResponse(BaseModel):
     ):
         try:
             context_eval_response = extract_with_llm(
-                ContextEvalResponse, str(candidate_chunk), config=config
+                ContextEvalResponse, str(candidate_chunk), strict=True, config=config
             )
         except ValueError:  # noqa: PERF203
             pass

@@ -132,10 +136,12 @@ class ContextEvalResponse(BaseModel):
     class AnswerResponse(BaseModel):
         """Answer a question using the provided context."""

+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         answer: str = Field(
             ...,
             description="A complete answer to the given question using the provided context.",
-            min_length=1,
         )
         system_prompt: ClassVar[str] = f"""
 You are given a set of contexts extracted from a document.

@@ -152,6 +158,7 @@ class AnswerResponse(BaseModel):
             answer_response = extract_with_llm(
                 AnswerResponse,
                 [str(relevant_chunk) for relevant_chunk in relevant_chunks],
+                strict=True,
                 config=config,
             )
         except ValueError:
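Aside from the model_config additions, note the dropped min_length=1 constraints: Pydantic renders min_length as the JSON schema keyword minLength, which OpenAI's strict mode does not support. A standalone sketch (not part of this commit) showing the keyword that would otherwise end up in the schema:

from pydantic import BaseModel, Field

class WithMinLength(BaseModel):
    question: str = Field(..., min_length=1)

# The property schema contains "minLength", which strict mode rejects:
print(WithMinLength.model_json_schema()["properties"]["question"])
# roughly: {'minLength': 1, 'title': 'Question', 'type': 'string'}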

src/raglite/_extract.py

Lines changed: 14 additions & 17 deletions
@@ -2,7 +2,6 @@

 from typing import Any, TypeVar

-import litellm
 from litellm import completion, get_supported_openai_params  # type: ignore[attr-defined]
 from pydantic import BaseModel, ValidationError

@@ -14,6 +13,7 @@
 def extract_with_llm(
     return_type: type[T],
     user_prompt: str | list[str],
+    strict: bool = False,  # noqa: FBT001,FBT002
     config: RAGLiteConfig | None = None,
     **kwargs: Any,
 ) -> T:

@@ -33,29 +33,31 @@ class MyNameResponse(BaseModel):
     """
     # Load the default config if not provided.
     config = config or RAGLiteConfig()
-    # Update the system prompt with the JSON schema of the return type to help the LLM.
-    system_prompt = "\n".join(
-        (
-            return_type.system_prompt.strip(),  # type: ignore[attr-defined]
-            "Format your response according to this JSON schema:",
-            str(return_type.model_json_schema()),
-        )
+    # Check if the LLM supports the response format.
+    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
+    llm_supports_response_format = "response_format" in (
+        get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []
     )
-    # Constrain the response format to the JSON schema if it's supported by the LLM [1].
+    # Update the system prompt with the JSON schema of the return type to help the LLM.
+    system_prompt = getattr(return_type, "system_prompt", "").strip()
+    if not llm_supports_response_format or llm_provider == "llama-cpp-python":
+        system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
+    # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
+    # is disabled by default because it only supports a subset of JSON schema features [2].
     # [1] https://docs.litellm.ai/docs/completion/json_mode
+    # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
     # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
-    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
     response_format: dict[str, Any] | None = (
         {
             "type": "json_schema",
             "json_schema": {
                 "name": return_type.__name__,
                 "description": return_type.__doc__ or "",
                 "schema": return_type.model_json_schema(),
+                "strict": strict,
             },
         }
-        if "response_format"
-        in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [])
+        if llm_supports_response_format
         else None
     )
     # Concatenate the user prompt if it is a list of strings.

@@ -64,9 +66,6 @@ class MyNameResponse(BaseModel):
         f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
         for i, chunk in enumerate(user_prompt)
     )
-    # Enable JSON schema validation.
-    enable_json_schema_validation = litellm.enable_json_schema_validation
-    litellm.enable_json_schema_validation = True
     # Extract structured data from the unstructured input.
     for _ in range(config.llm_max_tries):
         response = completion(

@@ -89,6 +88,4 @@ class MyNameResponse(BaseModel):
     else:
         error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message) from last_exception
-    # Restore the previous JSON schema validation setting.
-    litellm.enable_json_schema_validation = enable_json_schema_validation
     return instance
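With the new strict keyword in place, callers opt in per extraction. A hypothetical usage sketch (the CityResponse model and the input string are invented for illustration; extract_with_llm and RAGLiteConfig are as defined in this commit):

from typing import ClassVar

from pydantic import BaseModel, ConfigDict, Field

from raglite import RAGLiteConfig
from raglite._extract import extract_with_llm

class CityResponse(BaseModel):
    """The city mentioned in the input."""

    model_config = ConfigDict(extra="forbid")  # Required by OpenAI's strict mode.
    city: str = Field(..., description="The name of the city.")
    system_prompt: ClassVar[str] = "Extract the city from the input."

config = RAGLiteConfig(llm="gpt-4o-mini")
response = extract_with_llm(CityResponse, "Ada lives in Ghent.", strict=True, config=config)
print(response.city)  # "Ghent"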

tests/test_extract.py

Lines changed: 15 additions & 8 deletions
@@ -3,7 +3,7 @@
 from typing import ClassVar

 import pytest
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

 from raglite import RAGLiteConfig
 from raglite._extract import extract_with_llm

@@ -13,29 +13,36 @@
     params=[
         pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"),
         pytest.param("gpt-4o-mini", id="openai"),
-    ],
+    ]
 )
-def llm(
-    request: pytest.FixtureRequest,
-) -> str:
+def llm(request: pytest.FixtureRequest) -> str:
     """Get an LLM to test RAGLite with."""
     llm: str = request.param
     return llm


-def test_extract(llm: str) -> None:
+@pytest.mark.parametrize(
+    "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")]
+)
+def test_extract(llm: str, strict: bool) -> None:  # noqa: FBT001
     """Test extracting structured data."""
     # Set the LLM.
     config = RAGLiteConfig(llm=llm)

-    # Extract structured data.
+    # Define the JSON schema of the response.
     class LoginResponse(BaseModel):
+        """The response to a login request."""
+
+        model_config = ConfigDict(extra="forbid" if strict else "allow")
         username: str = Field(..., description="The username.")
         password: str = Field(..., description="The password.")
         system_prompt: ClassVar[str] = "Extract the username and password from the input."

+    # Extract structured data.
     username, password = "cypher", "steak"
-    login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config)
+    login_response = extract_with_llm(
+        LoginResponse, f"{username} // {password}", strict=strict, config=config
+    )
     # Validate the response.
     assert isinstance(login_response, LoginResponse)
     assert login_response.username == username
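The two-way parametrization over llm and strict expands test_extract to four test cases. A quick way to list them without invoking any LLM (assuming the repository layout above) is pytest's collect-only mode, driven here from Python:

import pytest

# Expect IDs along the lines of:
#   test_extract[llama_cpp_python-strict=False]
#   test_extract[llama_cpp_python-strict=True]
#   test_extract[openai-strict=False]
#   test_extract[openai-strict=True]
pytest.main(["--collect-only", "-q", "tests/test_extract.py"])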
