15
15
16
16
import json
17
17
import logging
18
+ import re
19
+ import uuid
18
20
from collections import defaultdict
19
21
from typing import (
20
22
TYPE_CHECKING ,
28
30
)
29
31
30
32
from openai .types .chat import ChatCompletionMessageToolCall
33
+ from openai .types .chat .chat_completion_message_tool_call import Function
31
34
from pydantic import BaseModel
32
35
33
36
from camel .agents .base import BaseAgent
@@ -190,7 +193,7 @@ def __init__(
190
193
tool .get_openai_tool_schema () for tool in all_tools
191
194
]
192
195
self .model_backend .model_config_dict ['tools' ] = tool_schema_list
193
-
196
+ self . tool_schema_list = tool_schema_list
194
197
self .model_config_dict = self .model_backend .model_config_dict
195
198
196
199
self .model_token_limit = token_limit or self .model_backend .token_limit
@@ -206,6 +209,56 @@ def __init__(
206
209
self .response_terminators = response_terminators or []
207
210
self .init_messages ()
208
211
212
+ # ruff: noqa: E501
213
+ def _generate_tool_prompt (self , tool_schema_list : List [Dict ]) -> str :
214
+ tool_prompts = []
215
+
216
+ for tool in tool_schema_list :
217
+ tool_info = tool ['function' ]
218
+ tool_name = tool_info ['name' ]
219
+ tool_description = tool_info ['description' ]
220
+ tool_json = json .dumps (tool_info , indent = 4 )
221
+
222
+ prompt = f"Use the function '{ tool_name } ' to '{ tool_description } ':\n { tool_json } \n "
223
+ tool_prompts .append (prompt )
224
+
225
+ tool_prompt_str = "\n " .join (tool_prompts )
226
+
227
+ final_prompt = f'''
228
+ # Tool prompt
229
+ TOOL_PROMPT = f"""
230
+ You have access to the following functions:
231
+
232
+ { tool_prompt_str }
233
+
234
+ If you choose to call a function ONLY reply in the following format with no prefix or suffix:
235
+
236
+ <function=example_function_name>{{"example_name": "example_value"}}</function>
237
+
238
+ Reminder:
239
+ - Function calls MUST follow the specified format, start with <function= and end with </function>
240
+ - Required parameters MUST be specified
241
+ - Only call one function at a time
242
+ - Put the entire function call reply on one line
243
+ - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
244
+ """
245
+ '''
246
+ return final_prompt
247
+
248
+ def _parse_tool_response (self , response : str ):
249
+ function_regex = r"<function=(\w+)>(.*?)</function>"
250
+ match = re .search (function_regex , response )
251
+
252
+ if match :
253
+ function_name , args_string = match .groups ()
254
+ try :
255
+ args = json .loads (args_string )
256
+ return {"function" : function_name , "arguments" : args }
257
+ except json .JSONDecodeError as error :
258
+ print (f"Error parsing function arguments: { error } " )
259
+ return None
260
+ return None
261
+
209
262
def reset (self ):
210
263
r"""Resets the :obj:`ChatAgent` to its initial state and returns the
211
264
stored messages.
@@ -367,89 +420,221 @@ def step(
367
420
a boolean indicating whether the chat session has terminated,
368
421
and information about the chat session.
369
422
"""
370
- self .update_memory (input_message , OpenAIBackendRole .USER )
423
+ if (
424
+ isinstance (self .model_type , ModelType )
425
+ and "lama" in self .model_type .value
426
+ or isinstance (self .model_type , str )
427
+ and "lama" in self .model_type
428
+ ):
429
+ if self .model_backend .model_config_dict ['tools' ]:
430
+ tool_prompt = self ._generate_tool_prompt (self .tool_schema_list )
431
+
432
+ tool_sys_msg = BaseMessage .make_assistant_message (
433
+ role_name = "Assistant" ,
434
+ content = tool_prompt ,
435
+ )
371
436
372
- tool_call_records : List [FunctionCallingRecord ] = []
373
- while True :
374
- # Check if token has exceeded
375
- try :
376
- openai_messages , num_tokens = self .memory .get_context ()
377
- except RuntimeError as e :
378
- return self ._step_token_exceed (
379
- e .args [1 ], tool_call_records , "max_tokens_exceeded"
437
+ self .update_memory (tool_sys_msg , OpenAIBackendRole .SYSTEM )
438
+
439
+ self .update_memory (input_message , OpenAIBackendRole .USER )
440
+
441
+ tool_call_records : List [FunctionCallingRecord ] = []
442
+ while True :
443
+ # Check if token has exceeded
444
+ try :
445
+ openai_messages , num_tokens = self .memory .get_context ()
446
+ except RuntimeError as e :
447
+ return self ._step_token_exceed (
448
+ e .args [1 ], tool_call_records , "max_tokens_exceeded"
449
+ )
450
+
451
+ (
452
+ response ,
453
+ output_messages ,
454
+ finish_reasons ,
455
+ usage_dict ,
456
+ response_id ,
457
+ ) = self ._step_model_response (openai_messages , num_tokens )
458
+ # If the model response is not a function call, meaning the
459
+ # model has generated a message response, break the loop
460
+ if (
461
+ not self .is_tools_added ()
462
+ or not isinstance (response , ChatCompletion )
463
+ or "</function>" not in response .choices [0 ].message .content # type: ignore[operator]
464
+ ):
465
+ break
466
+
467
+ parsed_content = self ._parse_tool_response (
468
+ response .choices [0 ].message .content # type: ignore[arg-type]
380
469
)
381
470
382
- (
383
- response ,
471
+ response .choices [0 ].message .tool_calls = [
472
+ ChatCompletionMessageToolCall (
473
+ id = str (uuid .uuid4 ()),
474
+ function = Function (
475
+ arguments = str (parsed_content ["arguments" ]).replace (
476
+ "'" , '"'
477
+ ),
478
+ name = str (parsed_content ["function" ]),
479
+ ),
480
+ type = "function" ,
481
+ )
482
+ ]
483
+
484
+ # Check for external tool call
485
+ tool_call_request = response .choices [0 ].message .tool_calls [0 ]
486
+ if tool_call_request .function .name in self .external_tool_names :
487
+ # if model calls an external tool, directly return the
488
+ # request
489
+ info = self ._step_get_info (
490
+ output_messages ,
491
+ finish_reasons ,
492
+ usage_dict ,
493
+ response_id ,
494
+ tool_call_records ,
495
+ num_tokens ,
496
+ tool_call_request ,
497
+ )
498
+ return ChatAgentResponse (
499
+ msgs = output_messages ,
500
+ terminated = self .terminated ,
501
+ info = info ,
502
+ )
503
+
504
+ # Normal function calling
505
+ tool_call_records .append (
506
+ self ._step_tool_call_and_update (response )
507
+ )
508
+
509
+ if (
510
+ output_schema is not None
511
+ and self .model_type .supports_tool_calling
512
+ ):
513
+ (
514
+ output_messages ,
515
+ finish_reasons ,
516
+ usage_dict ,
517
+ response_id ,
518
+ tool_call ,
519
+ num_tokens ,
520
+ ) = self ._structure_output_with_function (output_schema )
521
+ tool_call_records .append (tool_call )
522
+
523
+ info = self ._step_get_info (
384
524
output_messages ,
385
525
finish_reasons ,
386
526
usage_dict ,
387
527
response_id ,
388
- ) = self ._step_model_response (openai_messages , num_tokens )
528
+ tool_call_records ,
529
+ num_tokens ,
530
+ )
389
531
390
- # If the model response is not a function call, meaning the model
391
- # has generated a message response, break the loop
392
- if (
393
- not self .is_tools_added ()
394
- or not isinstance (response , ChatCompletion )
395
- or response .choices [0 ].message .tool_calls is None
396
- ):
397
- break
532
+ if len (output_messages ) == 1 :
533
+ # Auto record if the output result is a single message
534
+ self .record_message (output_messages [0 ])
535
+ else :
536
+ logger .warning (
537
+ "Multiple messages returned in `step()`, message won't be "
538
+ "recorded automatically. Please call `record_message()` "
539
+ "to record the selected message manually."
540
+ )
398
541
399
- # Check for external tool call
400
- tool_call_request = response .choices [0 ].message .tool_calls [0 ]
401
- if tool_call_request .function .name in self .external_tool_names :
402
- # if model calls an external tool, directly return the request
403
- info = self ._step_get_info (
542
+ return ChatAgentResponse (
543
+ msgs = output_messages , terminated = self .terminated , info = info
544
+ )
545
+
546
+ else :
547
+ self .update_memory (input_message , OpenAIBackendRole .USER )
548
+
549
+ tool_call_records : List [FunctionCallingRecord ] = [] # type: ignore[no-redef]
550
+ while True :
551
+ # Check if token has exceeded
552
+ try :
553
+ openai_messages , num_tokens = self .memory .get_context ()
554
+ except RuntimeError as e :
555
+ return self ._step_token_exceed (
556
+ e .args [1 ], tool_call_records , "max_tokens_exceeded"
557
+ )
558
+
559
+ (
560
+ response ,
404
561
output_messages ,
405
562
finish_reasons ,
406
563
usage_dict ,
407
564
response_id ,
408
- tool_call_records ,
409
- num_tokens ,
410
- tool_call_request ,
411
- )
412
- return ChatAgentResponse (
413
- msgs = output_messages , terminated = self .terminated , info = info
565
+ ) = self ._step_model_response (openai_messages , num_tokens )
566
+ # If the model response is not a function call, meaning the
567
+ # model has generated a message response, break the loop
568
+ if (
569
+ not self .is_tools_added ()
570
+ or not isinstance (response , ChatCompletion )
571
+ or response .choices [0 ].message .tool_calls is None
572
+ ):
573
+ break
574
+
575
+ # Check for external tool call
576
+ tool_call_request = response .choices [0 ].message .tool_calls [0 ]
577
+
578
+ if tool_call_request .function .name in self .external_tool_names :
579
+ # if model calls an external tool, directly return the
580
+ # request
581
+ info = self ._step_get_info (
582
+ output_messages ,
583
+ finish_reasons ,
584
+ usage_dict ,
585
+ response_id ,
586
+ tool_call_records ,
587
+ num_tokens ,
588
+ tool_call_request ,
589
+ )
590
+ return ChatAgentResponse (
591
+ msgs = output_messages ,
592
+ terminated = self .terminated ,
593
+ info = info ,
594
+ )
595
+
596
+ # Normal function calling
597
+ tool_call_records .append (
598
+ self ._step_tool_call_and_update (response )
414
599
)
415
600
416
- # Normal function calling
417
- tool_call_records .append (self ._step_tool_call_and_update (response ))
601
+ if (
602
+ output_schema is not None
603
+ and self .model_type .supports_tool_calling
604
+ ):
605
+ (
606
+ output_messages ,
607
+ finish_reasons ,
608
+ usage_dict ,
609
+ response_id ,
610
+ tool_call ,
611
+ num_tokens ,
612
+ ) = self ._structure_output_with_function (output_schema )
613
+ tool_call_records .append (tool_call )
418
614
419
- if output_schema is not None and self .model_type .supports_tool_calling :
420
- (
615
+ info = self ._step_get_info (
421
616
output_messages ,
422
617
finish_reasons ,
423
618
usage_dict ,
424
619
response_id ,
425
- tool_call ,
620
+ tool_call_records ,
426
621
num_tokens ,
427
- ) = self ._structure_output_with_function (output_schema )
428
- tool_call_records .append (tool_call )
622
+ )
429
623
430
- info = self ._step_get_info (
431
- output_messages ,
432
- finish_reasons ,
433
- usage_dict ,
434
- response_id ,
435
- tool_call_records ,
436
- num_tokens ,
437
- )
624
+ if len (output_messages ) == 1 :
625
+ # Auto record if the output result is a single message
626
+ self .record_message (output_messages [0 ])
627
+ else :
628
+ logger .warning (
629
+ "Multiple messages returned in `step()`, message won't be "
630
+ "recorded automatically. Please call `record_message()` "
631
+ "to record the selected message manually."
632
+ )
438
633
439
- if len (output_messages ) == 1 :
440
- # Auto record if the output result is a single message
441
- self .record_message (output_messages [0 ])
442
- else :
443
- logger .warning (
444
- "Multiple messages returned in `step()`, message won't be "
445
- "recorded automatically. Please call `record_message()` to "
446
- "record the selected message manually."
634
+ return ChatAgentResponse (
635
+ msgs = output_messages , terminated = self .terminated , info = info
447
636
)
448
637
449
- return ChatAgentResponse (
450
- msgs = output_messages , terminated = self .terminated , info = info
451
- )
452
-
453
638
async def step_async (
454
639
self ,
455
640
input_message : BaseMessage ,
0 commit comments