Skip to content

Commit 5f4e4d6

Browse files
committed
Better list management
1 parent 4cbf46b commit 5f4e4d6

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

flat_ai/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__title__ = 'flat-ai'
22
__package_name__ = 'flat_ai'
3-
__version__ = '0.3.4'
3+
__version__ = '0.3.5'
44
__description__ = 'F.L.A.T. (Frameworkless LLM Agent Thing) for building AI Agents'
55
__email__ = 'hello@mindsdb.com'
66
__author__ = 'Yours truly Jorge Torres and an LLM'

flat_ai/flat_ai.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import json
2626
import re
2727
import time
28-
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Type, Union
28+
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, get_origin
2929

3030
import openai
3131
from pydantic import BaseModel, Field
@@ -148,31 +148,34 @@ def _execute():
148148

149149
return self._retry_on_error(_execute)
150150

151+
151152
def generate_object(self, schema_class: Type[BaseModel | List[BaseModel]], **kwargs) -> Any:
152153
"""Generate an object matching the provided schema"""
153-
154+
155+
# LIST OF Pydantic models
156+
class ObjectList(BaseModel):
157+
items: List[schema_class]
158+
# if its a list of Pydantic models, we need to create a new schema for the array
159+
if get_origin(schema_class) is list:
160+
is_list = True
161+
schema_name = schema_class.__args__[0].__name__+"Array"
162+
schema = ObjectList.model_json_schema()
163+
# Handle Pydantic models
164+
else:
165+
is_list = False
166+
schema = schema_class.model_json_schema()
167+
schema_name = schema_class.__name__
168+
169+
messages = self._build_messages(
170+
{
171+
"role": "user",
172+
"content": "Based on the provided context and information, generate a complete and accurate object that precisely matches the schema. Use all relevant details to populate the fields with meaningful, appropriate values that best represent the data.",
173+
},
174+
**kwargs,
175+
)
176+
154177
def _execute():
155-
# if its a list of Pydantic models, we need to create a new schema for the array
156-
if hasattr(schema_class, "__origin__") and schema_class.__origin__ is list:
157-
schema = {
158-
"type": "array",
159-
"items": schema_class.__args__[0].model_json_schema()
160-
}
161-
schema_name = schema_class.__args__[0].__name__+"Array"
162-
# Handle Pydantic models
163-
else:
164-
schema = schema_class.model_json_schema()
165-
schema_name = schema_class.__name__
166178

167-
168-
messages = self._build_messages(
169-
{
170-
"role": "user",
171-
"content": "Based on the provided context and information, generate a complete and accurate object that precisely matches the schema. Use all relevant details to populate the fields with meaningful, appropriate values that best represent the data.",
172-
},
173-
**kwargs,
174-
)
175-
176179
# if Fireworks or Together use a different response format
177180
if self.base_url in ["https://api.fireworks.ai/inference/v1", "https://api.together.xyz/v1"]:
178181
response = self.client.chat.completions.create(
@@ -200,20 +203,14 @@ def _execute():
200203
result = json.loads(response.choices[0].message.content)
201204

202205
# Handle list of Pydantic models
203-
if (
204-
hasattr(schema_class, "__origin__")
205-
and schema_class.__origin__ is list
206-
and issubclass(schema_class.__args__[0], BaseModel)
207-
):
206+
if is_list:
208207
return [
209-
schema_class.__args__[0].model_validate(item) for item in result
208+
ObjectList.model_validate(result).items for item in result
210209
]
211210
# Handle single Pydantic model
212-
elif isinstance(schema_class, type) and issubclass(schema_class, BaseModel):
211+
else:
213212
return schema_class.model_validate(result)
214213

215-
return result
216-
217214
return self._retry_on_error(_execute)
218215

219216
def call_function(self, func: Callable, **kwargs) -> Any:
@@ -225,19 +222,20 @@ def pick_a_function(
225222
self, instructions: str, functions: List[Callable], **kwargs
226223
) -> tuple[Callable, Dict]:
227224
"""Pick appropriate function and arguments based on instructions"""
225+
226+
tools = [create_openai_function_description(func) for func in functions]
227+
228+
messages = self._build_messages(
229+
{"role": "system", "content": instructions},
230+
{
231+
"role": "user",
232+
"content": "Based on all the provided context and information, analyze and select the most appropriate function from the available options. Then, determine and specify the optimal parameters for that function to achieve the intended outcome.",
233+
},
234+
**kwargs,
235+
)
228236

229237
def _execute():
230-
tools = [create_openai_function_description(func) for func in functions]
231-
232-
messages = self._build_messages(
233-
{"role": "system", "content": instructions},
234-
{
235-
"role": "user",
236-
"content": "Based on all the provided context and information, analyze and select the most appropriate function from the available options. Then, determine and specify the optimal parameters for that function to achieve the intended outcome.",
237-
},
238-
**kwargs,
239-
)
240-
238+
241239
response = self.client.chat.completions.create(
242240
model=self.model, messages=messages, tools=tools
243241
)
@@ -368,3 +366,5 @@ def create_openai_function_description(func: Callable) -> Dict[str, Any]:
368366
function_description["parameters"]["properties"][param_name] = param_info
369367

370368
return {"type": "function", "function": function_description}
369+
370+

0 commit comments

Comments
 (0)