1515
1616import json
1717import logging
18+ import re
19+ import uuid
1820from collections import defaultdict
1921from typing import (
2022 TYPE_CHECKING ,
2830)
2931
3032from openai .types .chat import ChatCompletionMessageToolCall
33+ from openai .types .chat .chat_completion_message_tool_call import Function
3134from pydantic import BaseModel
3235
3336from camel .agents .base import BaseAgent
@@ -190,7 +193,7 @@ def __init__(
190193 tool .get_openai_tool_schema () for tool in all_tools
191194 ]
192195 self .model_backend .model_config_dict ['tools' ] = tool_schema_list
193-
196+ self . tool_schema_list = tool_schema_list
194197 self .model_config_dict = self .model_backend .model_config_dict
195198
196199 self .model_token_limit = token_limit or self .model_backend .token_limit
@@ -206,6 +209,56 @@ def __init__(
206209 self .response_terminators = response_terminators or []
207210 self .init_messages ()
208211
212+ # ruff: noqa: E501
213+ def _generate_tool_prompt (self , tool_schema_list : List [Dict ]) -> str :
214+ tool_prompts = []
215+
216+ for tool in tool_schema_list :
217+ tool_info = tool ['function' ]
218+ tool_name = tool_info ['name' ]
219+ tool_description = tool_info ['description' ]
220+ tool_json = json .dumps (tool_info , indent = 4 )
221+
222+ prompt = f"Use the function '{ tool_name } ' to '{ tool_description } ':\n { tool_json } \n "
223+ tool_prompts .append (prompt )
224+
225+ tool_prompt_str = "\n " .join (tool_prompts )
226+
227+ final_prompt = f'''
228+ # Tool prompt
229+ TOOL_PROMPT = f"""
230+ You have access to the following functions:
231+
232+ { tool_prompt_str }
233+
234+ If you choose to call a function ONLY reply in the following format with no prefix or suffix:
235+
236+ <function=example_function_name>{{"example_name": "example_value"}}</function>
237+
238+ Reminder:
239+ - Function calls MUST follow the specified format, start with <function= and end with </function>
240+ - Required parameters MUST be specified
241+ - Only call one function at a time
242+ - Put the entire function call reply on one line
243+ - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
244+ """
245+ '''
246+ return final_prompt
247+
248+ def _parse_tool_response (self , response : str ):
249+ function_regex = r"<function=(\w+)>(.*?)</function>"
250+ match = re .search (function_regex , response )
251+
252+ if match :
253+ function_name , args_string = match .groups ()
254+ try :
255+ args = json .loads (args_string )
256+ return {"function" : function_name , "arguments" : args }
257+ except json .JSONDecodeError as error :
258+ print (f"Error parsing function arguments: { error } " )
259+ return None
260+ return None
261+
209262 def reset (self ):
210263 r"""Resets the :obj:`ChatAgent` to its initial state and returns the
211264 stored messages.
@@ -367,89 +420,221 @@ def step(
367420 a boolean indicating whether the chat session has terminated,
368421 and information about the chat session.
369422 """
370- self .update_memory (input_message , OpenAIBackendRole .USER )
423+ if (
424+ isinstance (self .model_type , ModelType )
425+ and "lama" in self .model_type .value
426+ or isinstance (self .model_type , str )
427+ and "lama" in self .model_type
428+ ):
429+ if self .model_backend .model_config_dict ['tools' ]:
430+ tool_prompt = self ._generate_tool_prompt (self .tool_schema_list )
431+
432+ tool_sys_msg = BaseMessage .make_assistant_message (
433+ role_name = "Assistant" ,
434+ content = tool_prompt ,
435+ )
371436
372- tool_call_records : List [FunctionCallingRecord ] = []
373- while True :
374- # Check if token has exceeded
375- try :
376- openai_messages , num_tokens = self .memory .get_context ()
377- except RuntimeError as e :
378- return self ._step_token_exceed (
379- e .args [1 ], tool_call_records , "max_tokens_exceeded"
437+ self .update_memory (tool_sys_msg , OpenAIBackendRole .SYSTEM )
438+
439+ self .update_memory (input_message , OpenAIBackendRole .USER )
440+
441+ tool_call_records : List [FunctionCallingRecord ] = []
442+ while True :
443+ # Check if token has exceeded
444+ try :
445+ openai_messages , num_tokens = self .memory .get_context ()
446+ except RuntimeError as e :
447+ return self ._step_token_exceed (
448+ e .args [1 ], tool_call_records , "max_tokens_exceeded"
449+ )
450+
451+ (
452+ response ,
453+ output_messages ,
454+ finish_reasons ,
455+ usage_dict ,
456+ response_id ,
457+ ) = self ._step_model_response (openai_messages , num_tokens )
458+ # If the model response is not a function call, meaning the
459+ # model has generated a message response, break the loop
460+ if (
461+ not self .is_tools_added ()
462+ or not isinstance (response , ChatCompletion )
463+ or "</function>" not in response .choices [0 ].message .content # type: ignore[operator]
464+ ):
465+ break
466+
467+ parsed_content = self ._parse_tool_response (
468+ response .choices [0 ].message .content # type: ignore[arg-type]
380469 )
381470
382- (
383- response ,
471+ response .choices [0 ].message .tool_calls = [
472+ ChatCompletionMessageToolCall (
473+ id = str (uuid .uuid4 ()),
474+ function = Function (
475+ arguments = str (parsed_content ["arguments" ]).replace (
476+ "'" , '"'
477+ ),
478+ name = str (parsed_content ["function" ]),
479+ ),
480+ type = "function" ,
481+ )
482+ ]
483+
484+ # Check for external tool call
485+ tool_call_request = response .choices [0 ].message .tool_calls [0 ]
486+ if tool_call_request .function .name in self .external_tool_names :
487+ # if model calls an external tool, directly return the
488+ # request
489+ info = self ._step_get_info (
490+ output_messages ,
491+ finish_reasons ,
492+ usage_dict ,
493+ response_id ,
494+ tool_call_records ,
495+ num_tokens ,
496+ tool_call_request ,
497+ )
498+ return ChatAgentResponse (
499+ msgs = output_messages ,
500+ terminated = self .terminated ,
501+ info = info ,
502+ )
503+
504+ # Normal function calling
505+ tool_call_records .append (
506+ self ._step_tool_call_and_update (response )
507+ )
508+
509+ if (
510+ output_schema is not None
511+ and self .model_type .supports_tool_calling
512+ ):
513+ (
514+ output_messages ,
515+ finish_reasons ,
516+ usage_dict ,
517+ response_id ,
518+ tool_call ,
519+ num_tokens ,
520+ ) = self ._structure_output_with_function (output_schema )
521+ tool_call_records .append (tool_call )
522+
523+ info = self ._step_get_info (
384524 output_messages ,
385525 finish_reasons ,
386526 usage_dict ,
387527 response_id ,
388- ) = self ._step_model_response (openai_messages , num_tokens )
528+ tool_call_records ,
529+ num_tokens ,
530+ )
389531
390- # If the model response is not a function call, meaning the model
391- # has generated a message response, break the loop
392- if (
393- not self .is_tools_added ()
394- or not isinstance (response , ChatCompletion )
395- or response .choices [0 ].message .tool_calls is None
396- ):
397- break
532+ if len (output_messages ) == 1 :
533+ # Auto record if the output result is a single message
534+ self .record_message (output_messages [0 ])
535+ else :
536+ logger .warning (
537+ "Multiple messages returned in `step()`, message won't be "
538+ "recorded automatically. Please call `record_message()` "
539+ "to record the selected message manually."
540+ )
398541
399- # Check for external tool call
400- tool_call_request = response .choices [0 ].message .tool_calls [0 ]
401- if tool_call_request .function .name in self .external_tool_names :
402- # if model calls an external tool, directly return the request
403- info = self ._step_get_info (
542+ return ChatAgentResponse (
543+ msgs = output_messages , terminated = self .terminated , info = info
544+ )
545+
546+ else :
547+ self .update_memory (input_message , OpenAIBackendRole .USER )
548+
549+ tool_call_records : List [FunctionCallingRecord ] = [] # type: ignore[no-redef]
550+ while True :
551+ # Check if token has exceeded
552+ try :
553+ openai_messages , num_tokens = self .memory .get_context ()
554+ except RuntimeError as e :
555+ return self ._step_token_exceed (
556+ e .args [1 ], tool_call_records , "max_tokens_exceeded"
557+ )
558+
559+ (
560+ response ,
404561 output_messages ,
405562 finish_reasons ,
406563 usage_dict ,
407564 response_id ,
408- tool_call_records ,
409- num_tokens ,
410- tool_call_request ,
411- )
412- return ChatAgentResponse (
413- msgs = output_messages , terminated = self .terminated , info = info
565+ ) = self ._step_model_response (openai_messages , num_tokens )
566+ # If the model response is not a function call, meaning the
567+ # model has generated a message response, break the loop
568+ if (
569+ not self .is_tools_added ()
570+ or not isinstance (response , ChatCompletion )
571+ or response .choices [0 ].message .tool_calls is None
572+ ):
573+ break
574+
575+ # Check for external tool call
576+ tool_call_request = response .choices [0 ].message .tool_calls [0 ]
577+
578+ if tool_call_request .function .name in self .external_tool_names :
579+ # if model calls an external tool, directly return the
580+ # request
581+ info = self ._step_get_info (
582+ output_messages ,
583+ finish_reasons ,
584+ usage_dict ,
585+ response_id ,
586+ tool_call_records ,
587+ num_tokens ,
588+ tool_call_request ,
589+ )
590+ return ChatAgentResponse (
591+ msgs = output_messages ,
592+ terminated = self .terminated ,
593+ info = info ,
594+ )
595+
596+ # Normal function calling
597+ tool_call_records .append (
598+ self ._step_tool_call_and_update (response )
414599 )
415600
416- # Normal function calling
417- tool_call_records .append (self ._step_tool_call_and_update (response ))
601+ if (
602+ output_schema is not None
603+ and self .model_type .supports_tool_calling
604+ ):
605+ (
606+ output_messages ,
607+ finish_reasons ,
608+ usage_dict ,
609+ response_id ,
610+ tool_call ,
611+ num_tokens ,
612+ ) = self ._structure_output_with_function (output_schema )
613+ tool_call_records .append (tool_call )
418614
419- if output_schema is not None and self .model_type .supports_tool_calling :
420- (
615+ info = self ._step_get_info (
421616 output_messages ,
422617 finish_reasons ,
423618 usage_dict ,
424619 response_id ,
425- tool_call ,
620+ tool_call_records ,
426621 num_tokens ,
427- ) = self ._structure_output_with_function (output_schema )
428- tool_call_records .append (tool_call )
622+ )
429623
430- info = self ._step_get_info (
431- output_messages ,
432- finish_reasons ,
433- usage_dict ,
434- response_id ,
435- tool_call_records ,
436- num_tokens ,
437- )
624+ if len (output_messages ) == 1 :
625+ # Auto record if the output result is a single message
626+ self .record_message (output_messages [0 ])
627+ else :
628+ logger .warning (
629+ "Multiple messages returned in `step()`, message won't be "
630+ "recorded automatically. Please call `record_message()` "
631+ "to record the selected message manually."
632+ )
438633
439- if len (output_messages ) == 1 :
440- # Auto record if the output result is a single message
441- self .record_message (output_messages [0 ])
442- else :
443- logger .warning (
444- "Multiple messages returned in `step()`, message won't be "
445- "recorded automatically. Please call `record_message()` to "
446- "record the selected message manually."
634+ return ChatAgentResponse (
635+ msgs = output_messages , terminated = self .terminated , info = info
447636 )
448637
449- return ChatAgentResponse (
450- msgs = output_messages , terminated = self .terminated , info = info
451- )
452-
453638 async def step_async (
454639 self ,
455640 input_message : BaseMessage ,
0 commit comments