Skip to content

Commit e80dd0d

Browse files
committed
support for together and fireworks
1 parent c352e19 commit e80dd0d

File tree

2 files changed

+31
-49
lines changed

2 files changed

+31
-49
lines changed

flat_ai/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__title__ = 'flat-ai'
22
__package_name__ = 'flat_ai'
3-
__version__ = '0.3.1'
3+
__version__ = '0.3.2'
44
__description__ = 'F.L.A.T. (Frameworkless LLM Agent Thing) for building AI Agents'
55
__email__ = 'hello@mindsdb.com'
66
__author__ = 'Yours truly Jorge Torres and an LLM'

flat_ai/flat_ai.py

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class FlatAI:
4545
def __init__(
4646
self,
4747
client: Optional[openai.OpenAI] = None,
48-
model: str = "gpt-4",
48+
model: str = "gpt-4o",
4949
retries: int = 3,
5050
base_url: str = "https://api.openai.com/v1",
5151
api_key: Optional[str] = None,
@@ -147,50 +147,21 @@ def _execute():
147147

148148
return self._retry_on_error(_execute)
149149

150-
def generate_object(self, schema_class: Type[BaseModel | Any], **kwargs) -> Any:
150+
def generate_object(self, schema_class: Type[BaseModel | List[BaseModel]], **kwargs) -> Any:
151151
"""Generate an object matching the provided schema"""
152152

153153
def _execute():
154-
# Handle typing generics (List, Dict, etc)
155-
if hasattr(schema_class, "__origin__"):
156-
if schema_class.__origin__ is list:
157-
item_type = schema_class.__args__[0]
158-
if issubclass(item_type, BaseModel):
159-
schema = {
160-
"type": "array",
161-
"items": item_type.model_json_schema(),
162-
}
163-
else:
164-
schema = {
165-
"type": "array",
166-
"items": {
167-
"type": {
168-
str: "string",
169-
int: "integer",
170-
float: "number",
171-
bool: "boolean",
172-
}.get(item_type, "string")
173-
},
174-
}
175-
else:
176-
raise ValueError(f"Unsupported generic type: {schema_class}")
177-
# Handle basic Python types
178-
elif schema_class in (list, dict, str, int, float, bool):
154+
# if its a list of Pydantic models, we need to create a new schema for the array
155+
if hasattr(schema_class, "__origin__") and schema_class.__origin__ is list:
179156
schema = {
180-
"type": {
181-
list: "array",
182-
dict: "object",
183-
str: "string",
184-
int: "integer",
185-
float: "number",
186-
bool: "boolean",
187-
}.get(schema_class, "string")
157+
"type": "array",
158+
"items": schema_class.__args__[0].model_json_schema()
188159
}
189-
if schema["type"] == "array":
190-
schema["items"] = {"type": "string"} # Default to string items
160+
schema_name = schema_class.__args__[0].__name__+"Array"
191161
# Handle Pydantic models
192162
elif isinstance(schema_class, type) and issubclass(schema_class, BaseModel):
193163
schema = schema_class.model_json_schema()
164+
schema_name = schema_class.__name__
194165
else:
195166
raise ValueError(f"Unsupported schema type: {schema_class}")
196167

@@ -202,18 +173,29 @@ def _execute():
202173
**kwargs,
203174
)
204175

205-
206-
response = self.client.chat.completions.create(
207-
model=self.model,
208-
response_format={
209-
"type": "json_schema",
210-
"json_schema": {
211-
"name": schema_class.__name__,
212-
"schema": schema,
176+
# if Fireworks or Together use a different response format
177+
if self.client.base_url in ["https://api.fireworks.ai/inference/v1", "https://api.together.xyz/v1"]:
178+
response = self.client.chat.completions.create(
179+
model=self.model,
180+
response_format={
181+
"type": "json_object",
182+
"schema": schema
213183
},
214-
},
215-
messages=messages,
216-
)
184+
messages=messages,
185+
)
186+
else:
187+
response = self.client.chat.completions.create(
188+
model=self.model,
189+
response_format={
190+
"type": "json_schema",
191+
"json_schema": {
192+
"name": schema_name,
193+
"schema": schema
194+
},
195+
},
196+
messages=messages,
197+
)
198+
217199

218200
result = json.loads(response.choices[0].message.content)
219201

0 commit comments

Comments
 (0)