Skip to content

Commit 2a376a2

Browse files
committed
making llama models retry if function call fails, because sometimes it doesnt actually return anything when it should
1 parent a885898 commit 2a376a2

File tree

2 files changed

+36
-25
lines changed

2 files changed

+36
-25
lines changed

flat_ai/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__title__ = 'flat-ai'
22
__package_name__ = 'flat_ai'
3-
__version__ = '0.3.7'
3+
__version__ = '0.3.8'
44
__description__ = 'F.L.A.T. (Frameworkless LLM Agent Thing) for building AI Agents'
55
__email__ = 'hello@mindsdb.com'
66
__author__ = 'Yours truly Jorge Torres and an LLM'

flat_ai/flat_ai.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -251,32 +251,43 @@ def _execute():
251251
model=self.model, messages=messages, tools=tools, tool_choice = tool_choice
252252
)
253253

254-
tool_call = response.choices[0].message.tool_calls[0]
255-
if tool_choice == "auto" and tool_call.function.name == "None":
256-
return functions_to_call
254+
# Process all tool calls if there are any
255+
tool_calls = response.choices[0].message.tool_calls or []
256+
content = response.choices[0].message.content
257257

258-
chosen_func = next(
259-
f for f in functions if f.__name__ == tool_call.function.name
260-
)
258+
if tool_calls == [] and tool_choice == "required" and content:
259+
raise Exception("No tools found") # This is an error that can be handled by the caller
261260

262-
args = json.loads(tool_call.function.arguments, strict=False)
263-
# Convert string lists back to actual lists
264-
for key, value in args.items():
265-
if (
266-
isinstance(value, str)
267-
and value.startswith("[")
268-
and value.endswith("]")
269-
):
270-
try:
271-
args[key] = json.loads(value, strict=False)
272-
except json.JSONDecodeError:
273-
pass
274-
275-
function_call = FunctionCall(
276-
function=chosen_func,
277-
arguments=args
278-
)
279-
functions_to_call.functions.append(function_call)
261+
# If _multiple_functions is False, only process the first tool call
262+
if not _multiple_functions and len(tool_calls) > 0:
263+
tool_calls = [tool_calls[0]]
264+
265+
for tool_call in tool_calls:
266+
if tool_choice == "auto" and tool_call.function.name == "None":
267+
continue
268+
269+
chosen_func = next(
270+
f for f in functions if f.__name__ == tool_call.function.name
271+
)
272+
273+
args = json.loads(tool_call.function.arguments, strict=False)
274+
# Convert string lists back to actual lists
275+
for key, value in args.items():
276+
if (
277+
isinstance(value, str)
278+
and value.startswith("[")
279+
and value.endswith("]")
280+
):
281+
try:
282+
args[key] = json.loads(value, strict=False)
283+
except json.JSONDecodeError:
284+
pass
285+
286+
function_call = FunctionCall(
287+
function=chosen_func,
288+
arguments=args
289+
)
290+
functions_to_call.functions.append(function_call)
280291

281292
return functions_to_call
282293

0 commit comments

Comments
 (0)