 from typing import Any, TypeVar

-import litellm
 from litellm import completion, get_supported_openai_params  # type: ignore[attr-defined]
 from pydantic import BaseModel, ValidationError
@@ -14,6 +13,7 @@
 def extract_with_llm(
     return_type: type[T],
     user_prompt: str | list[str],
+    strict: bool = False,  # noqa: FBT001,FBT002
     config: RAGLiteConfig | None = None,
     **kwargs: Any,
 ) -> T:
@@ -33,29 +33,31 @@ class MyNameResponse(BaseModel):
     """
     # Load the default config if not provided.
     config = config or RAGLiteConfig()
-    # Update the system prompt with the JSON schema of the return type to help the LLM.
-    system_prompt = "\n".join(
-        (
-            return_type.system_prompt.strip(),  # type: ignore[attr-defined]
-            "Format your response according to this JSON schema:",
-            str(return_type.model_json_schema()),
-        )
+    # Check if the LLM supports the response format.
+    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
+    llm_supports_response_format = "response_format" in (
+        get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []
     )
-    # Constrain the reponse format to the JSON schema if it's supported by the LLM [1].
+    # Update the system prompt with the JSON schema of the return type to help the LLM.
+    system_prompt = getattr(return_type, "system_prompt", "").strip()
+    if not llm_supports_response_format or llm_provider == "llama-cpp-python":
+        system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
+    # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
+    # is disabled by default because it only supports a subset of JSON schema features [2].
     # [1] https://docs.litellm.ai/docs/completion/json_mode
+    # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
     # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
-    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
     response_format: dict[str, Any] | None = (
         {
             "type": "json_schema",
             "json_schema": {
                 "name": return_type.__name__,
                 "description": return_type.__doc__ or "",
                 "schema": return_type.model_json_schema(),
+                "strict": strict,
             },
         }
-        if "response_format"
-        in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [])
+        if llm_supports_response_format
         else None
     )
     # Concatenate the user prompt if it is a list of strings.
@@ -64,9 +66,6 @@ class MyNameResponse(BaseModel):
             f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
             for i, chunk in enumerate(user_prompt)
         )
-    # Enable JSON schema validation.
-    enable_json_schema_validation = litellm.enable_json_schema_validation
-    litellm.enable_json_schema_validation = True
     # Extract structured data from the unstructured input.
     for _ in range(config.llm_max_tries):
         response = completion(
@@ -89,6 +88,4 @@ class MyNameResponse(BaseModel):
     else:
         error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message) from last_exception
-    # Restore the previous JSON schema validation setting.
-    litellm.enable_json_schema_validation = enable_json_schema_validation
     return instance
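
For context, a minimal usage sketch of the updated signature. It assumes the MyNameResponse model named in the function's docstring example, that the model's optional system_prompt hint is declared as a ClassVar so the new getattr(return_type, "system_prompt", "") call can read it, and that a working LLM is configured; the prompt text below is illustrative only.

from typing import ClassVar

from pydantic import BaseModel


class MyNameResponse(BaseModel):
    """A person's name extracted from text."""

    # Optional class-level hint read by extract_with_llm via getattr(return_type, "system_prompt", "").
    system_prompt: ClassVar[str] = "The user will introduce themselves. Extract their name."

    name: str


# extract_with_llm is the function defined above (import it from wherever this module is packaged).
# strict=False (the default) sends the full JSON schema; strict=True opts into strict structured
# outputs, which only support a subset of JSON schema features.
my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.", strict=False)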