Skip to content

Commit eb004bc

Browse files
committed
lists working for all providers
1 parent 5f4e4d6 commit eb004bc

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ __pycache__/
55
*.logs
66
# C extensions
77
*.so
8-
8+
test_local.py
99
# Distribution / packaging
1010
.Python
1111
build/

flat_ai/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__title__ = 'flat-ai'
22
__package_name__ = 'flat_ai'
3-
__version__ = '0.3.5'
3+
__version__ = '0.3.6'
44
__description__ = 'F.L.A.T. (Frameworkless LLM Agent Thing) for building AI Agents'
55
__email__ = 'hello@mindsdb.com'
66
__author__ = 'Yours truly Jorge Torres and an LLM'

flat_ai/flat_ai.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,14 @@ def generate_object(self, schema_class: Type[BaseModel | List[BaseModel]], **kwa
153153
"""Generate an object matching the provided schema"""
154154

155155
# LIST OF Pydantic models
156-
class ObjectList(BaseModel):
157-
items: List[schema_class]
156+
class ObjectArray(BaseModel):
157+
items: List[schema_class.__args__[0]] if hasattr(schema_class, "__origin__") and schema_class.__origin__ == list else List[schema_class]
158+
158159
# if its a list of Pydantic models, we need to create a new schema for the array
159-
if get_origin(schema_class) is list:
160+
if hasattr(schema_class, "__origin__") and schema_class.__origin__ == list:
160161
is_list = True
161162
schema_name = schema_class.__args__[0].__name__+"Array"
162-
schema = ObjectList.model_json_schema()
163+
schema = ObjectArray.model_json_schema()
163164
# Handle Pydantic models
164165
else:
165166
is_list = False
@@ -204,9 +205,8 @@ def _execute():
204205

205206
# Handle list of Pydantic models
206207
if is_list:
207-
return [
208-
ObjectList.model_validate(result).items for item in result
209-
]
208+
items = ObjectArray.model_validate(result).items
209+
return items
210210
# Handle single Pydantic model
211211
else:
212212
return schema_class.model_validate(result)

0 commit comments

Comments
 (0)