@@ -45,7 +45,7 @@ class FlatAI:
45
45
def __init__ (
46
46
self ,
47
47
client : Optional [openai .OpenAI ] = None ,
48
- model : str = "gpt-4 " ,
48
+ model : str = "gpt-4o " ,
49
49
retries : int = 3 ,
50
50
base_url : str = "https://api.openai.com/v1" ,
51
51
api_key : Optional [str ] = None ,
@@ -147,50 +147,21 @@ def _execute():
147
147
148
148
return self ._retry_on_error (_execute )
149
149
150
- def generate_object (self , schema_class : Type [BaseModel | Any ], ** kwargs ) -> Any :
150
+ def generate_object (self , schema_class : Type [BaseModel | List [ BaseModel ] ], ** kwargs ) -> Any :
151
151
"""Generate an object matching the provided schema"""
152
152
153
153
def _execute ():
154
- # Handle typing generics (List, Dict, etc)
155
- if hasattr (schema_class , "__origin__" ):
156
- if schema_class .__origin__ is list :
157
- item_type = schema_class .__args__ [0 ]
158
- if issubclass (item_type , BaseModel ):
159
- schema = {
160
- "type" : "array" ,
161
- "items" : item_type .model_json_schema (),
162
- }
163
- else :
164
- schema = {
165
- "type" : "array" ,
166
- "items" : {
167
- "type" : {
168
- str : "string" ,
169
- int : "integer" ,
170
- float : "number" ,
171
- bool : "boolean" ,
172
- }.get (item_type , "string" )
173
- },
174
- }
175
- else :
176
- raise ValueError (f"Unsupported generic type: { schema_class } " )
177
- # Handle basic Python types
178
- elif schema_class in (list , dict , str , int , float , bool ):
154
+ # if its a list of Pydantic models, we need to create a new schema for the array
155
+ if hasattr (schema_class , "__origin__" ) and schema_class .__origin__ is list :
179
156
schema = {
180
- "type" : {
181
- list : "array" ,
182
- dict : "object" ,
183
- str : "string" ,
184
- int : "integer" ,
185
- float : "number" ,
186
- bool : "boolean" ,
187
- }.get (schema_class , "string" )
157
+ "type" : "array" ,
158
+ "items" : schema_class .__args__ [0 ].model_json_schema ()
188
159
}
189
- if schema ["type" ] == "array" :
190
- schema ["items" ] = {"type" : "string" } # Default to string items
160
+ schema_name = schema_class .__args__ [0 ].__name__ + "Array"
191
161
# Handle Pydantic models
192
162
elif isinstance (schema_class , type ) and issubclass (schema_class , BaseModel ):
193
163
schema = schema_class .model_json_schema ()
164
+ schema_name = schema_class .__name__
194
165
else :
195
166
raise ValueError (f"Unsupported schema type: { schema_class } " )
196
167
@@ -202,18 +173,29 @@ def _execute():
202
173
** kwargs ,
203
174
)
204
175
205
-
206
- response = self .client .chat .completions .create (
207
- model = self .model ,
208
- response_format = {
209
- "type" : "json_schema" ,
210
- "json_schema" : {
211
- "name" : schema_class .__name__ ,
212
- "schema" : schema ,
176
+ # if Fireworks or Together use a different response format
177
+ if self .client .base_url in ["https://api.fireworks.ai/inference/v1" , "https://api.together.xyz/v1" ]:
178
+ response = self .client .chat .completions .create (
179
+ model = self .model ,
180
+ response_format = {
181
+ "type" : "json_object" ,
182
+ "schema" : schema
213
183
},
214
- },
215
- messages = messages ,
216
- )
184
+ messages = messages ,
185
+ )
186
+ else :
187
+ response = self .client .chat .completions .create (
188
+ model = self .model ,
189
+ response_format = {
190
+ "type" : "json_schema" ,
191
+ "json_schema" : {
192
+ "name" : schema_name ,
193
+ "schema" : schema
194
+ },
195
+ },
196
+ messages = messages ,
197
+ )
198
+
217
199
218
200
result = json .loads (response .choices [0 ].message .content )
219
201
0 commit comments