Skip to content

Commit b8e1f5c

Browse files
authored
feat: llama model tool calling (#965)
1 parent 6e94015 commit b8e1f5c

File tree

8 files changed

+392
-65
lines changed

8 files changed

+392
-65
lines changed

.github/ISSUE_TEMPLATE/bug_report.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ body:
2727
attributes:
2828
label: What version of camel are you using?
2929
description: Run command `python3 -c 'print(__import__("camel").__version__)'` in your shell and paste the output here.
30-
placeholder: E.g., 0.2.1
30+
placeholder: E.g., 0.2.1a
3131
validations:
3232
required: true
3333

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ conda create --name camel python=3.10
119119
conda activate camel
120120
121121
# Clone github repo
122-
git clone -b v0.2.1 https://github.com/camel-ai/camel.git
122+
git clone -b v0.2.1a https://github.com/camel-ai/camel.git
123123
124124
# Change directory into project directory
125125
cd camel

camel/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
1414

15-
__version__ = '0.2.1'
15+
__version__ = '0.2.1a'
1616

1717
__all__ = [
1818
'__version__',

camel/agents/chat_agent.py

Lines changed: 244 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import json
1717
import logging
18+
import re
19+
import uuid
1820
from collections import defaultdict
1921
from typing import (
2022
TYPE_CHECKING,
@@ -28,6 +30,7 @@
2830
)
2931

3032
from openai.types.chat import ChatCompletionMessageToolCall
33+
from openai.types.chat.chat_completion_message_tool_call import Function
3134
from pydantic import BaseModel
3235

3336
from camel.agents.base import BaseAgent
@@ -190,7 +193,7 @@ def __init__(
190193
tool.get_openai_tool_schema() for tool in all_tools
191194
]
192195
self.model_backend.model_config_dict['tools'] = tool_schema_list
193-
196+
self.tool_schema_list = tool_schema_list
194197
self.model_config_dict = self.model_backend.model_config_dict
195198

196199
self.model_token_limit = token_limit or self.model_backend.token_limit
@@ -206,6 +209,56 @@ def __init__(
206209
self.response_terminators = response_terminators or []
207210
self.init_messages()
208211

212+
# ruff: noqa: E501
213+
def _generate_tool_prompt(self, tool_schema_list: List[Dict]) -> str:
214+
tool_prompts = []
215+
216+
for tool in tool_schema_list:
217+
tool_info = tool['function']
218+
tool_name = tool_info['name']
219+
tool_description = tool_info['description']
220+
tool_json = json.dumps(tool_info, indent=4)
221+
222+
prompt = f"Use the function '{tool_name}' to '{tool_description}':\n{tool_json}\n"
223+
tool_prompts.append(prompt)
224+
225+
tool_prompt_str = "\n".join(tool_prompts)
226+
227+
final_prompt = f'''
228+
# Tool prompt
229+
TOOL_PROMPT = f"""
230+
You have access to the following functions:
231+
232+
{tool_prompt_str}
233+
234+
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
235+
236+
<function=example_function_name>{{"example_name": "example_value"}}</function>
237+
238+
Reminder:
239+
- Function calls MUST follow the specified format, start with <function= and end with </function>
240+
- Required parameters MUST be specified
241+
- Only call one function at a time
242+
- Put the entire function call reply on one line
243+
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
244+
"""
245+
'''
246+
return final_prompt
247+
248+
def _parse_tool_response(self, response: str):
249+
function_regex = r"<function=(\w+)>(.*?)</function>"
250+
match = re.search(function_regex, response)
251+
252+
if match:
253+
function_name, args_string = match.groups()
254+
try:
255+
args = json.loads(args_string)
256+
return {"function": function_name, "arguments": args}
257+
except json.JSONDecodeError as error:
258+
print(f"Error parsing function arguments: {error}")
259+
return None
260+
return None
261+
209262
def reset(self):
210263
r"""Resets the :obj:`ChatAgent` to its initial state and returns the
211264
stored messages.
@@ -367,89 +420,221 @@ def step(
367420
a boolean indicating whether the chat session has terminated,
368421
and information about the chat session.
369422
"""
370-
self.update_memory(input_message, OpenAIBackendRole.USER)
423+
if (
424+
isinstance(self.model_type, ModelType)
425+
and "lama" in self.model_type.value
426+
or isinstance(self.model_type, str)
427+
and "lama" in self.model_type
428+
):
429+
if self.model_backend.model_config_dict['tools']:
430+
tool_prompt = self._generate_tool_prompt(self.tool_schema_list)
431+
432+
tool_sys_msg = BaseMessage.make_assistant_message(
433+
role_name="Assistant",
434+
content=tool_prompt,
435+
)
371436

372-
tool_call_records: List[FunctionCallingRecord] = []
373-
while True:
374-
# Check if token has exceeded
375-
try:
376-
openai_messages, num_tokens = self.memory.get_context()
377-
except RuntimeError as e:
378-
return self._step_token_exceed(
379-
e.args[1], tool_call_records, "max_tokens_exceeded"
437+
self.update_memory(tool_sys_msg, OpenAIBackendRole.SYSTEM)
438+
439+
self.update_memory(input_message, OpenAIBackendRole.USER)
440+
441+
tool_call_records: List[FunctionCallingRecord] = []
442+
while True:
443+
# Check if token has exceeded
444+
try:
445+
openai_messages, num_tokens = self.memory.get_context()
446+
except RuntimeError as e:
447+
return self._step_token_exceed(
448+
e.args[1], tool_call_records, "max_tokens_exceeded"
449+
)
450+
451+
(
452+
response,
453+
output_messages,
454+
finish_reasons,
455+
usage_dict,
456+
response_id,
457+
) = self._step_model_response(openai_messages, num_tokens)
458+
# If the model response is not a function call, meaning the
459+
# model has generated a message response, break the loop
460+
if (
461+
not self.is_tools_added()
462+
or not isinstance(response, ChatCompletion)
463+
or "</function>" not in response.choices[0].message.content # type: ignore[operator]
464+
):
465+
break
466+
467+
parsed_content = self._parse_tool_response(
468+
response.choices[0].message.content # type: ignore[arg-type]
380469
)
381470

382-
(
383-
response,
471+
response.choices[0].message.tool_calls = [
472+
ChatCompletionMessageToolCall(
473+
id=str(uuid.uuid4()),
474+
function=Function(
475+
arguments=str(parsed_content["arguments"]).replace(
476+
"'", '"'
477+
),
478+
name=str(parsed_content["function"]),
479+
),
480+
type="function",
481+
)
482+
]
483+
484+
# Check for external tool call
485+
tool_call_request = response.choices[0].message.tool_calls[0]
486+
if tool_call_request.function.name in self.external_tool_names:
487+
# if model calls an external tool, directly return the
488+
# request
489+
info = self._step_get_info(
490+
output_messages,
491+
finish_reasons,
492+
usage_dict,
493+
response_id,
494+
tool_call_records,
495+
num_tokens,
496+
tool_call_request,
497+
)
498+
return ChatAgentResponse(
499+
msgs=output_messages,
500+
terminated=self.terminated,
501+
info=info,
502+
)
503+
504+
# Normal function calling
505+
tool_call_records.append(
506+
self._step_tool_call_and_update(response)
507+
)
508+
509+
if (
510+
output_schema is not None
511+
and self.model_type.supports_tool_calling
512+
):
513+
(
514+
output_messages,
515+
finish_reasons,
516+
usage_dict,
517+
response_id,
518+
tool_call,
519+
num_tokens,
520+
) = self._structure_output_with_function(output_schema)
521+
tool_call_records.append(tool_call)
522+
523+
info = self._step_get_info(
384524
output_messages,
385525
finish_reasons,
386526
usage_dict,
387527
response_id,
388-
) = self._step_model_response(openai_messages, num_tokens)
528+
tool_call_records,
529+
num_tokens,
530+
)
389531

390-
# If the model response is not a function call, meaning the model
391-
# has generated a message response, break the loop
392-
if (
393-
not self.is_tools_added()
394-
or not isinstance(response, ChatCompletion)
395-
or response.choices[0].message.tool_calls is None
396-
):
397-
break
532+
if len(output_messages) == 1:
533+
# Auto record if the output result is a single message
534+
self.record_message(output_messages[0])
535+
else:
536+
logger.warning(
537+
"Multiple messages returned in `step()`, message won't be "
538+
"recorded automatically. Please call `record_message()` "
539+
"to record the selected message manually."
540+
)
398541

399-
# Check for external tool call
400-
tool_call_request = response.choices[0].message.tool_calls[0]
401-
if tool_call_request.function.name in self.external_tool_names:
402-
# if model calls an external tool, directly return the request
403-
info = self._step_get_info(
542+
return ChatAgentResponse(
543+
msgs=output_messages, terminated=self.terminated, info=info
544+
)
545+
546+
else:
547+
self.update_memory(input_message, OpenAIBackendRole.USER)
548+
549+
tool_call_records: List[FunctionCallingRecord] = [] # type: ignore[no-redef]
550+
while True:
551+
# Check if token has exceeded
552+
try:
553+
openai_messages, num_tokens = self.memory.get_context()
554+
except RuntimeError as e:
555+
return self._step_token_exceed(
556+
e.args[1], tool_call_records, "max_tokens_exceeded"
557+
)
558+
559+
(
560+
response,
404561
output_messages,
405562
finish_reasons,
406563
usage_dict,
407564
response_id,
408-
tool_call_records,
409-
num_tokens,
410-
tool_call_request,
411-
)
412-
return ChatAgentResponse(
413-
msgs=output_messages, terminated=self.terminated, info=info
565+
) = self._step_model_response(openai_messages, num_tokens)
566+
# If the model response is not a function call, meaning the
567+
# model has generated a message response, break the loop
568+
if (
569+
not self.is_tools_added()
570+
or not isinstance(response, ChatCompletion)
571+
or response.choices[0].message.tool_calls is None
572+
):
573+
break
574+
575+
# Check for external tool call
576+
tool_call_request = response.choices[0].message.tool_calls[0]
577+
578+
if tool_call_request.function.name in self.external_tool_names:
579+
# if model calls an external tool, directly return the
580+
# request
581+
info = self._step_get_info(
582+
output_messages,
583+
finish_reasons,
584+
usage_dict,
585+
response_id,
586+
tool_call_records,
587+
num_tokens,
588+
tool_call_request,
589+
)
590+
return ChatAgentResponse(
591+
msgs=output_messages,
592+
terminated=self.terminated,
593+
info=info,
594+
)
595+
596+
# Normal function calling
597+
tool_call_records.append(
598+
self._step_tool_call_and_update(response)
414599
)
415600

416-
# Normal function calling
417-
tool_call_records.append(self._step_tool_call_and_update(response))
601+
if (
602+
output_schema is not None
603+
and self.model_type.supports_tool_calling
604+
):
605+
(
606+
output_messages,
607+
finish_reasons,
608+
usage_dict,
609+
response_id,
610+
tool_call,
611+
num_tokens,
612+
) = self._structure_output_with_function(output_schema)
613+
tool_call_records.append(tool_call)
418614

419-
if output_schema is not None and self.model_type.supports_tool_calling:
420-
(
615+
info = self._step_get_info(
421616
output_messages,
422617
finish_reasons,
423618
usage_dict,
424619
response_id,
425-
tool_call,
620+
tool_call_records,
426621
num_tokens,
427-
) = self._structure_output_with_function(output_schema)
428-
tool_call_records.append(tool_call)
622+
)
429623

430-
info = self._step_get_info(
431-
output_messages,
432-
finish_reasons,
433-
usage_dict,
434-
response_id,
435-
tool_call_records,
436-
num_tokens,
437-
)
624+
if len(output_messages) == 1:
625+
# Auto record if the output result is a single message
626+
self.record_message(output_messages[0])
627+
else:
628+
logger.warning(
629+
"Multiple messages returned in `step()`, message won't be "
630+
"recorded automatically. Please call `record_message()` "
631+
"to record the selected message manually."
632+
)
438633

439-
if len(output_messages) == 1:
440-
# Auto record if the output result is a single message
441-
self.record_message(output_messages[0])
442-
else:
443-
logger.warning(
444-
"Multiple messages returned in `step()`, message won't be "
445-
"recorded automatically. Please call `record_message()` to "
446-
"record the selected message manually."
634+
return ChatAgentResponse(
635+
msgs=output_messages, terminated=self.terminated, info=info
447636
)
448637

449-
return ChatAgentResponse(
450-
msgs=output_messages, terminated=self.terminated, info=info
451-
)
452-
453638
async def step_async(
454639
self,
455640
input_message: BaseMessage,

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
project = 'CAMEL'
2828
copyright = '2024, CAMEL-AI.org'
2929
author = 'CAMEL-AI.org'
30-
release = '0.2.1'
30+
release = '0.2.1a'
3131

3232
html_favicon = (
3333
'https://raw.githubusercontent.com/camel-ai/camel/master/misc/favicon.png'

0 commit comments

Comments
 (0)