Skip to content

Commit 8bc4c73

Browse files
Geoffrey Grossman
Geoffrey Grossman
authored and
Geoffrey Grossman
committed
fix: openai schema format
1 parent e0d1bd1 commit 8bc4c73

File tree

1 file changed

+113
-70
lines changed

1 file changed

+113
-70
lines changed

flat_ai/flat_ai.py

Lines changed: 113 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -21,33 +21,42 @@
2121
Author: Your Friendly Neighborhood AI Wrangler
2222
"""
2323

24-
2524
import inspect
2625
import json
2726
import re
28-
from typing import Any, Callable, Dict, List, Optional, Type, Union, Iterable
27+
import time
28+
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Type, Union
29+
2930
import openai
3031
from pydantic import BaseModel, Field
31-
from typing import Literal
32-
import time
32+
3333
from flat_ai.trace_llm import MyOpenAI
3434

3535
openai.OpenAI = MyOpenAI
36+
37+
3638
class Boolean(BaseModel):
    # Structured-output wrapper used for yes/no style answers from the LLM.
    # The Field description steers the model when this schema is requested.
    result: bool = Field(description="The true/false result based on the question and context")

3944
class FlatAI:
40-
def __init__(
    self,
    client: Optional[openai.OpenAI] = None,
    model: str = "gpt-4",
    retries: int = 3,
    base_url: str = "https://api.openai.com/v1",
    api_key: Optional[str] = None,
):
    """Initialize the wrapper with either a ready client or an API key.

    Args:
        client: Pre-built OpenAI client; takes precedence over api_key.
        model: Model name used for all completions.
        retries: Attempts made by the retry helper before giving up.
        base_url: Endpoint used when constructing a client from api_key.
        api_key: Key used to build a client when none is supplied.

    Raises:
        ValueError: If neither client nor api_key is provided.
    """
    if client:
        self.client = client
    else:
        # No ready-made client: an API key is then mandatory.
        if not api_key:
            raise ValueError("Must provide either client or api_key")
        self.client = openai.OpenAI(base_url=base_url, api_key=api_key)

    self.model = model
    self.retries = retries
    self._context = {}
@@ -63,13 +72,15 @@ def _retry_on_error(self, func: Callable, *args, **kwargs) -> Any:
6372
if attempt < self.retries - 1:
6473
time.sleep(1 * (attempt + 1)) # Exponential backoff
6574
continue
66-
raise Exception(f"Operation failed after {self.retries} attempts. Last error: {str(last_exception)}")
75+
raise Exception(
76+
f"Operation failed after {self.retries} attempts. Last error: {str(last_exception)}"
77+
)
6778

6879
def set_context(self, **kwargs):
    """Set the context for future LLM interactions.

    Replaces any previously stored context wholesale; use add_context
    to merge instead of overwrite.
    """
    self._context = dict(kwargs)
7182

72-
def add_context(self, **kwargs):
    """Add additional context while preserving existing context.

    Keys supplied again overwrite their previous values.
    """
    for name, value in kwargs.items():
        self._context[name] = value
7586

@@ -85,15 +96,17 @@ def delete_from_context(self, *keys):
8596
def _build_messages(self, *message_parts, **kwargs) -> List[Dict[str, str]]:
8697
"""Build message list with context as system message if present"""
8798
messages = []
88-
99+
89100
if self._context:
90101
context_dict = {}
91102
for key, value in self._context.items():
92103
if isinstance(value, BaseModel):
93104
context_dict[key] = json.loads(value.model_dump_json())
94105
else:
95106
context_dict[key] = str(value)
96-
messages.append({"role": "system", "content": json.dumps(context_dict, indent=2)})
107+
messages.append(
108+
{"role": "system", "content": json.dumps(context_dict, indent=2)}
109+
)
97110

98111
if kwargs:
99112
extra_context_dict = {}
@@ -102,34 +115,41 @@ def _build_messages(self, *message_parts, **kwargs) -> List[Dict[str, str]]:
102115
extra_context_dict[key] = json.loads(value.model_dump_json())
103116
else:
104117
extra_context_dict[key] = str(value)
105-
messages.append({"role": "system", "content": json.dumps(extra_context_dict, indent=2)})
118+
messages.append(
119+
{"role": "system", "content": json.dumps(extra_context_dict, indent=2)}
120+
)
106121

107122
messages.extend(message_parts)
108123
return messages
109124

110125
def is_true(self, _question: str, **kwargs) -> bool:
    """Ask a yes/no question and get a boolean response.

    Args:
        _question: The yes/no question to pose to the model.
        **kwargs: Extra context forwarded to generate_object.

    Returns:
        The model's boolean verdict.
    """

    class IsItTrue(BaseModel):
        # Structured response schema: forces the model to commit to a boolean.
        is_it_true: bool

    # BUG FIX: the docstring used to sit *after* the nested class, making it
    # a no-op string statement instead of the function's __doc__.
    ret = self.generate_object(IsItTrue, question=_question, **kwargs)
    return ret.is_it_true

117133
def classify(self, options: Dict[str, str], **kwargs) -> str:
    """Get a key from provided options based on context.

    Args:
        options: Mapping of classification key -> human-readable description.
        **kwargs: Extra context forwarded to the LLM call.

    Returns:
        The selected key from ``options``.

    Raises:
        ValueError: If ``options`` is empty.
    """

    class Classification(BaseModel):
        choice: str = Field(
            description="The selected classification key", enum=list(options.keys())
        )

    def _execute():
        if not options:
            raise ValueError("Options dictionary cannot be empty")

        # BUG FIX: **kwargs was accepted but silently dropped; forward it so
        # callers' extra context actually reaches the model.
        result = self.generate_object(Classification, options=options, **kwargs)
        return result.choice

    return self._retry_on_error(_execute)
130149

131150
def generate_object(self, schema_class: Type[BaseModel | Any], **kwargs) -> Any:
132151
"""Generate an object matching the provided schema"""
152+
133153
def _execute():
134154
# Handle typing generics (List, Dict, etc)
135155
if hasattr(schema_class, "__origin__"):
@@ -138,7 +158,7 @@ def _execute():
138158
if issubclass(item_type, BaseModel):
139159
schema = {
140160
"type": "array",
141-
"items": item_type.model_json_schema()
161+
"items": item_type.model_json_schema(),
142162
}
143163
else:
144164
schema = {
@@ -148,9 +168,9 @@ def _execute():
148168
str: "string",
149169
int: "integer",
150170
float: "number",
151-
bool: "boolean"
171+
bool: "boolean",
152172
}.get(item_type, "string")
153-
}
173+
},
154174
}
155175
else:
156176
raise ValueError(f"Unsupported generic type: {schema_class}")
@@ -163,7 +183,7 @@ def _execute():
163183
str: "string",
164184
int: "integer",
165185
float: "number",
166-
bool: "boolean"
186+
bool: "boolean",
167187
}.get(schema_class, "string")
168188
}
169189
if schema["type"] == "array":
@@ -175,98 +195,130 @@ def _execute():
175195
raise ValueError(f"Unsupported schema type: {schema_class}")
176196

177197
messages = self._build_messages(
178-
{"role": "user", "content": "Based on the provided context and information, generate a complete and accurate object that precisely matches the schema. Use all relevant details to populate the fields with meaningful, appropriate values that best represent the data."},
179-
**kwargs
198+
{
199+
"role": "user",
200+
"content": "Based on the provided context and information, generate a complete and accurate object that precisely matches the schema. Use all relevant details to populate the fields with meaningful, appropriate values that best represent the data.",
201+
},
202+
**kwargs,
180203
)
181204

182-
response = self.client.chat.completions.create(
183-
model=self.model,
184-
response_format={"type": "json_object", "schema": schema},
185-
messages=messages
186-
)
187-
205+
if isinstance(self.client, MyOpenAI):
206+
response = self.client.chat.completions.create(
207+
model=self.model,
208+
response_format={
209+
"type": "json_schema",
210+
"json_schema": {
211+
"name": schema_class.__name__,
212+
"schema": schema,
213+
},
214+
},
215+
messages=messages,
216+
)
217+
else:
218+
response = self.client.chat.completions.create(
219+
model=self.model,
220+
response_format={"type": "json_object", "schema": schema},
221+
messages=messages,
222+
)
223+
188224
result = json.loads(response.choices[0].message.content)
189-
225+
190226
# Handle list of Pydantic models
191-
if (hasattr(schema_class, "__origin__") and
192-
schema_class.__origin__ is list and
193-
issubclass(schema_class.__args__[0], BaseModel)):
194-
return [schema_class.__args__[0].model_validate(item) for item in result]
227+
if (
228+
hasattr(schema_class, "__origin__")
229+
and schema_class.__origin__ is list
230+
and issubclass(schema_class.__args__[0], BaseModel)
231+
):
232+
return [
233+
schema_class.__args__[0].model_validate(item) for item in result
234+
]
195235
# Handle single Pydantic model
196236
elif isinstance(schema_class, type) and issubclass(schema_class, BaseModel):
197237
return schema_class.model_validate(result)
198-
238+
199239
return result
240+
200241
return self._retry_on_error(_execute)
201242

202243
def call_function(self, func: Callable, **kwargs) -> Any:
    """Call a function with AI-determined arguments.

    Delegates to pick_a_function with a single candidate, so the model
    only decides the arguments, then invokes the function with them.
    """
    chosen, chosen_args = self.pick_a_function("", [func], **kwargs)
    return chosen(**chosen_args)
206247

207-
def pick_a_function(
    self, instructions: str, functions: List[Callable], **kwargs
) -> tuple[Callable, Dict]:
    """Pick appropriate function and arguments based on instructions.

    Args:
        instructions: System-prompt guidance for the selection.
        functions: Candidate callables the model may choose from.
        **kwargs: Extra context forwarded into the message list.

    Returns:
        Tuple of (chosen_function, argument_dict) ready to invoke.

    Raises:
        ValueError: If the model names a function not in ``functions``.
    """

    def _execute():
        tools = [create_openai_function_description(func) for func in functions]

        messages = self._build_messages(
            {"role": "system", "content": instructions},
            {
                "role": "user",
                "content": "Based on all the provided context and information, analyze and select the most appropriate function from the available options. Then, determine and specify the optimal parameters for that function to achieve the intended outcome.",
            },
            **kwargs,
        )

        response = self.client.chat.completions.create(
            model=self.model, messages=messages, tools=tools
        )

        tool_call = response.choices[0].message.tool_calls[0]
        # BUG FIX: next() without a default raised an opaque StopIteration
        # when the model hallucinated a function name; fail clearly instead.
        chosen_func = next(
            (f for f in functions if f.__name__ == tool_call.function.name),
            None,
        )
        if chosen_func is None:
            raise ValueError(
                f"Model selected unknown function: {tool_call.function.name}"
            )

        args = json.loads(tool_call.function.arguments, strict=False)
        # Some models return JSON arrays as quoted strings inside the
        # arguments payload; convert those back into real lists.
        for key, value in args.items():
            if (
                isinstance(value, str)
                and value.startswith("[")
                and value.endswith("]")
            ):
                try:
                    args[key] = json.loads(value, strict=False)
                except json.JSONDecodeError:
                    pass

        return chosen_func, args

    return self._retry_on_error(_execute)
238290

239291
def get_string(self, prompt: str, **kwargs) -> str:
    """Get a simple string response from the LLM."""

    def _execute():
        chat_messages = self._build_messages(
            {"role": "user", "content": prompt}, **kwargs
        )
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=chat_messages,
        )
        # Non-streaming call: return the first choice's message text.
        return completion.choices[0].message.content

    return self._retry_on_error(_execute)
252304

253305
def get_stream(self, prompt: str, **kwargs) -> Iterable[str]:
    """Get a streaming response from the LLM.

    Yields content deltas as they arrive, skipping empty chunks.
    """

    def _execute():
        chat_messages = self._build_messages(
            {"role": "user", "content": prompt}, **kwargs
        )
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=chat_messages,
            stream=True,
        )
        # NOTE(review): _execute is a generator, so _retry_on_error only
        # guards generator *creation*; errors raised while iterating the
        # stream are not retried — confirm this is intended.
        for chunk in stream:
            piece = chunk.choices[0].delta.content
            if piece is not None:
                yield piece

    return self._retry_on_error(_execute)
269320

321+
270322
def create_openai_function_description(func: Callable) -> Dict[str, Any]:
271323
"""
272324
Takes a function and returns an OpenAI function description.
@@ -283,11 +335,7 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
283335
function_description = {
284336
"name": func.__name__,
285337
"description": docstring.split("\n")[0] if docstring else "",
286-
"parameters": {
287-
"type": "object",
288-
"properties": {},
289-
"required": []
290-
}
338+
"parameters": {"type": "object", "properties": {}, "required": []},
291339
}
292340

293341
for param_name, param in signature.parameters.items():
@@ -297,10 +345,10 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
297345

298346
# Try to get type from type annotation first
299347
if param.annotation != inspect.Parameter.empty:
300-
if hasattr(param.annotation, '__origin__'):
348+
if hasattr(param.annotation, "__origin__"):
301349
if param.annotation.__origin__ == list:
302350
param_info["type"] = "array"
303-
if hasattr(param.annotation, '__args__'):
351+
if hasattr(param.annotation, "__args__"):
304352
inner_type = param.annotation.__args__[0]
305353
if inner_type == str:
306354
param_info["items"] = {"type": "string"}
@@ -333,7 +381,9 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
333381
# Extract parameter description from docstring
334382
if docstring:
335383
param_pattern = re.compile(rf"{param_name}(\s*\([^)]*\))?:\s*(.*)")
336-
param_matches = [param_pattern.match(line.strip()) for line in docstring.split("\n")]
384+
param_matches = [
385+
param_pattern.match(line.strip()) for line in docstring.split("\n")
386+
]
337387
param_lines = [match.group(2) for match in param_matches if match]
338388
if param_lines:
339389
param_desc = param_lines[0].strip()
@@ -342,10 +392,3 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
342392
function_description["parameters"]["properties"][param_name] = param_info
343393

344394
return {"type": "function", "function": function_description}
345-
346-
347-
348-
349-
350-
351-

0 commit comments

Comments
 (0)