21
21
Author: Your Friendly Neighborhood AI Wrangler
22
22
"""
23
23
24
-
25
24
import inspect
26
25
import json
27
26
import re
28
- from typing import Any , Callable , Dict , List , Optional , Type , Union , Iterable
27
+ import time
28
+ from typing import Any , Callable , Dict , Iterable , List , Literal , Optional , Type , Union
29
+
29
30
import openai
30
31
from pydantic import BaseModel , Field
31
- from typing import Literal
32
- import time
32
+
33
33
from flat_ai .trace_llm import MyOpenAI
34
34
35
35
openai .OpenAI = MyOpenAI
36
+
37
+
36
38
class Boolean(BaseModel):
    """Structured-output wrapper that forces the model to answer true/false."""

    result: bool = Field(
        ...,
        description="The true/false result based on the question and context",
    )
42
+
38
43
39
44
class FlatAI :
40
- def __init__ (self , client : Optional [openai .OpenAI ] = None , model : str = "gpt-4" , retries : int = 3 , base_url : str = "https://api.openai.com/v1" , api_key : Optional [str ] = None ):
45
+ def __init__ (
46
+ self ,
47
+ client : Optional [openai .OpenAI ] = None ,
48
+ model : str = "gpt-4" ,
49
+ retries : int = 3 ,
50
+ base_url : str = "https://api.openai.com/v1" ,
51
+ api_key : Optional [str ] = None ,
52
+ ):
41
53
if client :
42
54
self .client = client
43
55
elif api_key :
44
- self .client = openai .OpenAI (
45
- base_url = base_url ,
46
- api_key = api_key
47
- )
56
+ self .client = openai .OpenAI (base_url = base_url , api_key = api_key )
48
57
else :
49
58
raise ValueError ("Must provide either client or api_key" )
50
-
59
+
51
60
self .model = model
52
61
self .retries = retries
53
62
self ._context = {}
@@ -63,13 +72,15 @@ def _retry_on_error(self, func: Callable, *args, **kwargs) -> Any:
63
72
if attempt < self .retries - 1 :
64
73
time .sleep (1 * (attempt + 1 )) # Exponential backoff
65
74
continue
66
- raise Exception (f"Operation failed after { self .retries } attempts. Last error: { str (last_exception )} " )
75
+ raise Exception (
76
+ f"Operation failed after { self .retries } attempts. Last error: { str (last_exception )} "
77
+ )
67
78
68
79
def set_context (self , ** kwargs ):
69
80
"""Set the context for future LLM interactions"""
70
81
self ._context = kwargs
71
82
72
- def add_context (self , ** kwargs ):
83
+ def add_context (self , ** kwargs ):
73
84
"""Add additional context while preserving existing context"""
74
85
self ._context .update (kwargs )
75
86
@@ -85,15 +96,17 @@ def delete_from_context(self, *keys):
85
96
def _build_messages (self , * message_parts , ** kwargs ) -> List [Dict [str , str ]]:
86
97
"""Build message list with context as system message if present"""
87
98
messages = []
88
-
99
+
89
100
if self ._context :
90
101
context_dict = {}
91
102
for key , value in self ._context .items ():
92
103
if isinstance (value , BaseModel ):
93
104
context_dict [key ] = json .loads (value .model_dump_json ())
94
105
else :
95
106
context_dict [key ] = str (value )
96
- messages .append ({"role" : "system" , "content" : json .dumps (context_dict , indent = 2 )})
107
+ messages .append (
108
+ {"role" : "system" , "content" : json .dumps (context_dict , indent = 2 )}
109
+ )
97
110
98
111
if kwargs :
99
112
extra_context_dict = {}
@@ -102,34 +115,41 @@ def _build_messages(self, *message_parts, **kwargs) -> List[Dict[str, str]]:
102
115
extra_context_dict [key ] = json .loads (value .model_dump_json ())
103
116
else :
104
117
extra_context_dict [key ] = str (value )
105
- messages .append ({"role" : "system" , "content" : json .dumps (extra_context_dict , indent = 2 )})
118
+ messages .append (
119
+ {"role" : "system" , "content" : json .dumps (extra_context_dict , indent = 2 )}
120
+ )
106
121
107
122
messages .extend (message_parts )
108
123
return messages
109
124
110
125
def is_true (self , _question : str , ** kwargs ) -> bool :
111
126
class IsItTrue (BaseModel ):
112
127
is_it_true : bool
128
+
113
129
"""Ask a yes/no question and get a boolean response"""
114
130
ret = self .generate_object (IsItTrue , question = _question , ** kwargs )
115
131
return ret .is_it_true
116
132
117
133
def classify (self , options : Dict [str , str ], ** kwargs ) -> str :
118
134
"""Get a key from provided options based on context"""
135
+
119
136
class Classification (BaseModel ):
120
- choice : str = Field (description = "The selected classification key" , enum = list (options .keys ()))
137
+ choice : str = Field (
138
+ description = "The selected classification key" , enum = list (options .keys ())
139
+ )
121
140
122
141
def _execute ():
123
142
if not options :
124
143
raise ValueError ("Options dictionary cannot be empty" )
125
-
144
+
126
145
result = self .generate_object (Classification , options = options )
127
146
return result .choice
128
-
147
+
129
148
return self ._retry_on_error (_execute )
130
149
131
150
def generate_object (self , schema_class : Type [BaseModel | Any ], ** kwargs ) -> Any :
132
151
"""Generate an object matching the provided schema"""
152
+
133
153
def _execute ():
134
154
# Handle typing generics (List, Dict, etc)
135
155
if hasattr (schema_class , "__origin__" ):
@@ -138,7 +158,7 @@ def _execute():
138
158
if issubclass (item_type , BaseModel ):
139
159
schema = {
140
160
"type" : "array" ,
141
- "items" : item_type .model_json_schema ()
161
+ "items" : item_type .model_json_schema (),
142
162
}
143
163
else :
144
164
schema = {
@@ -148,9 +168,9 @@ def _execute():
148
168
str : "string" ,
149
169
int : "integer" ,
150
170
float : "number" ,
151
- bool : "boolean"
171
+ bool : "boolean" ,
152
172
}.get (item_type , "string" )
153
- }
173
+ },
154
174
}
155
175
else :
156
176
raise ValueError (f"Unsupported generic type: { schema_class } " )
@@ -163,7 +183,7 @@ def _execute():
163
183
str : "string" ,
164
184
int : "integer" ,
165
185
float : "number" ,
166
- bool : "boolean"
186
+ bool : "boolean" ,
167
187
}.get (schema_class , "string" )
168
188
}
169
189
if schema ["type" ] == "array" :
@@ -175,98 +195,130 @@ def _execute():
175
195
raise ValueError (f"Unsupported schema type: { schema_class } " )
176
196
177
197
messages = self ._build_messages (
178
- {"role" : "user" , "content" : "Based on the provided context and information, generate a complete and accurate object that precisely matches the schema. Use all relevant details to populate the fields with meaningful, appropriate values that best represent the data." },
179
- ** kwargs
198
+ {
199
+ "role" : "user" ,
200
+ "content" : "Based on the provided context and information, generate a complete and accurate object that precisely matches the schema. Use all relevant details to populate the fields with meaningful, appropriate values that best represent the data." ,
201
+ },
202
+ ** kwargs ,
180
203
)
181
204
182
- response = self .client .chat .completions .create (
183
- model = self .model ,
184
- response_format = {"type" : "json_object" , "schema" : schema },
185
- messages = messages
186
- )
187
-
205
+ if isinstance (self .client , MyOpenAI ):
206
+ response = self .client .chat .completions .create (
207
+ model = self .model ,
208
+ response_format = {
209
+ "type" : "json_schema" ,
210
+ "json_schema" : {
211
+ "name" : schema_class .__name__ ,
212
+ "schema" : schema ,
213
+ },
214
+ },
215
+ messages = messages ,
216
+ )
217
+ else :
218
+ response = self .client .chat .completions .create (
219
+ model = self .model ,
220
+ response_format = {"type" : "json_object" , "schema" : schema },
221
+ messages = messages ,
222
+ )
223
+
188
224
result = json .loads (response .choices [0 ].message .content )
189
-
225
+
190
226
# Handle list of Pydantic models
191
- if (hasattr (schema_class , "__origin__" ) and
192
- schema_class .__origin__ is list and
193
- issubclass (schema_class .__args__ [0 ], BaseModel )):
194
- return [schema_class .__args__ [0 ].model_validate (item ) for item in result ]
227
+ if (
228
+ hasattr (schema_class , "__origin__" )
229
+ and schema_class .__origin__ is list
230
+ and issubclass (schema_class .__args__ [0 ], BaseModel )
231
+ ):
232
+ return [
233
+ schema_class .__args__ [0 ].model_validate (item ) for item in result
234
+ ]
195
235
# Handle single Pydantic model
196
236
elif isinstance (schema_class , type ) and issubclass (schema_class , BaseModel ):
197
237
return schema_class .model_validate (result )
198
-
238
+
199
239
return result
240
+
200
241
return self ._retry_on_error (_execute )
201
242
202
243
def call_function (self , func : Callable , ** kwargs ) -> Any :
203
244
"""Call a function with AI-determined arguments"""
204
245
func , args = self .pick_a_function ("" , [func ], ** kwargs )
205
246
return func (** args )
206
247
207
- def pick_a_function (self , instructions : str , functions : List [Callable ], ** kwargs ) -> tuple [Callable , Dict ]:
248
+ def pick_a_function (
249
+ self , instructions : str , functions : List [Callable ], ** kwargs
250
+ ) -> tuple [Callable , Dict ]:
208
251
"""Pick appropriate function and arguments based on instructions"""
252
+
209
253
def _execute ():
210
254
tools = [create_openai_function_description (func ) for func in functions ]
211
255
212
256
messages = self ._build_messages (
213
257
{"role" : "system" , "content" : instructions },
214
- {"role" : "user" , "content" : "Based on all the provided context and information, analyze and select the most appropriate function from the available options. Then, determine and specify the optimal parameters for that function to achieve the intended outcome." },
215
- ** kwargs
258
+ {
259
+ "role" : "user" ,
260
+ "content" : "Based on all the provided context and information, analyze and select the most appropriate function from the available options. Then, determine and specify the optimal parameters for that function to achieve the intended outcome." ,
261
+ },
262
+ ** kwargs ,
216
263
)
217
264
218
265
response = self .client .chat .completions .create (
219
- model = self .model ,
220
- messages = messages ,
221
- tools = tools
266
+ model = self .model , messages = messages , tools = tools
222
267
)
223
268
224
269
tool_call = response .choices [0 ].message .tool_calls [0 ]
225
- chosen_func = next (f for f in functions if f .__name__ == tool_call .function .name )
226
-
270
+ chosen_func = next (
271
+ f for f in functions if f .__name__ == tool_call .function .name
272
+ )
273
+
227
274
args = json .loads (tool_call .function .arguments , strict = False )
228
275
# Convert string lists back to actual lists
229
276
for key , value in args .items ():
230
- if isinstance (value , str ) and value .startswith ('[' ) and value .endswith (']' ):
277
+ if (
278
+ isinstance (value , str )
279
+ and value .startswith ("[" )
280
+ and value .endswith ("]" )
281
+ ):
231
282
try :
232
283
args [key ] = json .loads (value , strict = False )
233
284
except json .JSONDecodeError :
234
285
pass
235
286
236
287
return chosen_func , args
288
+
237
289
return self ._retry_on_error (_execute )
238
290
239
291
def get_string (self , prompt : str , ** kwargs ) -> str :
240
292
"""Get a simple string response from the LLM"""
293
+
241
294
def _execute ():
242
295
messages = self ._build_messages (
243
- {"role" : "user" , "content" : prompt },
244
- ** kwargs
296
+ {"role" : "user" , "content" : prompt }, ** kwargs
245
297
)
246
298
response = self .client .chat .completions .create (
247
- model = self .model ,
248
- messages = messages
299
+ model = self .model , messages = messages
249
300
)
250
301
return response .choices [0 ].message .content
302
+
251
303
return self ._retry_on_error (_execute )
252
304
253
305
def get_stream (self , prompt : str , ** kwargs ) -> Iterable [str ]:
254
306
"""Get a streaming response from the LLM"""
307
+
255
308
def _execute ():
256
309
messages = self ._build_messages (
257
- {"role" : "user" , "content" : prompt },
258
- ** kwargs
310
+ {"role" : "user" , "content" : prompt }, ** kwargs
259
311
)
260
312
response = self .client .chat .completions .create (
261
- model = self .model ,
262
- messages = messages ,
263
- stream = True
313
+ model = self .model , messages = messages , stream = True
264
314
)
265
315
for chunk in response :
266
316
if chunk .choices [0 ].delta .content is not None :
267
317
yield chunk .choices [0 ].delta .content
318
+
268
319
return self ._retry_on_error (_execute )
269
320
321
+
270
322
def create_openai_function_description (func : Callable ) -> Dict [str , Any ]:
271
323
"""
272
324
Takes a function and returns an OpenAI function description.
@@ -283,11 +335,7 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
283
335
function_description = {
284
336
"name" : func .__name__ ,
285
337
"description" : docstring .split ("\n " )[0 ] if docstring else "" ,
286
- "parameters" : {
287
- "type" : "object" ,
288
- "properties" : {},
289
- "required" : []
290
- }
338
+ "parameters" : {"type" : "object" , "properties" : {}, "required" : []},
291
339
}
292
340
293
341
for param_name , param in signature .parameters .items ():
@@ -297,10 +345,10 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
297
345
298
346
# Try to get type from type annotation first
299
347
if param .annotation != inspect .Parameter .empty :
300
- if hasattr (param .annotation , ' __origin__' ):
348
+ if hasattr (param .annotation , " __origin__" ):
301
349
if param .annotation .__origin__ == list :
302
350
param_info ["type" ] = "array"
303
- if hasattr (param .annotation , ' __args__' ):
351
+ if hasattr (param .annotation , " __args__" ):
304
352
inner_type = param .annotation .__args__ [0 ]
305
353
if inner_type == str :
306
354
param_info ["items" ] = {"type" : "string" }
@@ -333,7 +381,9 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
333
381
# Extract parameter description from docstring
334
382
if docstring :
335
383
param_pattern = re .compile (rf"{ param_name } (\s*\([^)]*\))?:\s*(.*)" )
336
- param_matches = [param_pattern .match (line .strip ()) for line in docstring .split ("\n " )]
384
+ param_matches = [
385
+ param_pattern .match (line .strip ()) for line in docstring .split ("\n " )
386
+ ]
337
387
param_lines = [match .group (2 ) for match in param_matches if match ]
338
388
if param_lines :
339
389
param_desc = param_lines [0 ].strip ()
@@ -342,10 +392,3 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
342
392
function_description ["parameters" ]["properties" ][param_name ] = param_info
343
393
344
394
return {"type" : "function" , "function" : function_description }
345
-
346
-
347
-
348
-
349
-
350
-
351
-
0 commit comments