Skip to content

Commit 1b6856d

Browse files
authored
feat: make system_message as optional (#1038)
1 parent 0b6734e commit 1b6856d

File tree

4 files changed

+127
-51
lines changed

4 files changed

+127
-51
lines changed

camel/agents/chat_agent.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ class ChatAgent(BaseAgent):
115115
r"""Class for managing conversations of CAMEL Chat Agents.
116116
117117
Args:
118-
system_message (BaseMessage): The system message for the chat agent.
118+
system_message (BaseMessage, optional): The system message for the
119+
chat agent.
119120
model (BaseModelBackend, optional): The model backend to use for
120121
generating responses. (default: :obj:`OpenAIModel` with
121122
`GPT_4O_MINI`)
@@ -144,7 +145,7 @@ class ChatAgent(BaseAgent):
144145

145146
def __init__(
146147
self,
147-
system_message: BaseMessage,
148+
system_message: Optional[BaseMessage] = None,
148149
model: Optional[BaseModelBackend] = None,
149150
memory: Optional[AgentMemory] = None,
150151
message_window_size: Optional[int] = None,
@@ -154,10 +155,14 @@ def __init__(
154155
external_tools: Optional[List[FunctionTool]] = None,
155156
response_terminators: Optional[List[ResponseTerminator]] = None,
156157
) -> None:
157-
self.orig_sys_message: BaseMessage = system_message
158-
self.system_message = system_message
159-
self.role_name: str = system_message.role_name
160-
self.role_type: RoleType = system_message.role_type
158+
self.orig_sys_message: Optional[BaseMessage] = system_message
159+
self._system_message: Optional[BaseMessage] = system_message
160+
self.role_name: str = (
161+
getattr(system_message, 'role_name', None) or "assistant"
162+
)
163+
self.role_type: RoleType = (
164+
getattr(system_message, 'role_type', None) or RoleType.ASSISTANT
165+
)
161166
self.model_backend: BaseModelBackend = (
162167
model
163168
if model is not None
@@ -272,11 +277,12 @@ def reset(self):
272277
terminator.reset()
273278

274279
@property
275-
def system_message(self) -> BaseMessage:
280+
def system_message(self) -> Optional[BaseMessage]:
276281
r"""The getter method for the property :obj:`system_message`.
277282
278283
Returns:
279-
BaseMessage: The system message of this agent.
284+
Optional[BaseMessage]: The system message of this agent if set,
285+
else :obj:`None`.
280286
"""
281287
return self._system_message
282288

@@ -327,12 +333,22 @@ def set_output_language(self, output_language: str) -> BaseMessage:
327333
BaseMessage: The updated system message object.
328334
"""
329335
self.output_language = output_language
330-
content = self.orig_sys_message.content + (
336+
language_prompt = (
331337
"\nRegardless of the input language, "
332338
f"you must output text in {output_language}."
333339
)
334-
self.system_message = self.system_message.create_new_instance(content)
335-
return self.system_message
340+
if self.orig_sys_message is not None:
341+
content = self.orig_sys_message.content + language_prompt
342+
self._system_message = self.orig_sys_message.create_new_instance(
343+
content
344+
)
345+
return self._system_message
346+
else:
347+
self._system_message = BaseMessage.make_assistant_message(
348+
role_name="Assistant",
349+
content=language_prompt,
350+
)
351+
return self._system_message
336352

337353
def get_info(
338354
self,
@@ -377,12 +393,15 @@ def init_messages(self) -> None:
377393
r"""Initializes the stored messages list with the initial system
378394
message.
379395
"""
380-
system_record = MemoryRecord(
381-
message=self.system_message,
382-
role_at_backend=OpenAIBackendRole.SYSTEM,
383-
)
384-
self.memory.clear()
385-
self.memory.write_record(system_record)
396+
if self.orig_sys_message is not None:
397+
system_record = MemoryRecord(
398+
message=self.orig_sys_message,
399+
role_at_backend=OpenAIBackendRole.SYSTEM,
400+
)
401+
self.memory.clear()
402+
self.memory.write_record(system_record)
403+
else:
404+
self.memory.clear()
386405

387406
def record_message(self, message: BaseMessage) -> None:
388407
r"""Records the externally provided message into the agent memory as if

camel/societies/babyagi_playing.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __init__(
106106
)
107107

108108
self.assistant_agent: ChatAgent
109-
self.assistant_sys_msg: BaseMessage
109+
self.assistant_sys_msg: Optional[BaseMessage]
110110
self.task_creation_agent: TaskCreationAgent
111111
self.task_prioritization_agent: TaskPrioritizationAgent
112112
self.init_agents(
@@ -202,7 +202,8 @@ def init_agents(
202202

203203
self.task_creation_agent = TaskCreationAgent(
204204
objective=self.specified_task_prompt,
205-
role_name=self.assistant_sys_msg.role_name,
205+
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
206+
or "assistant",
206207
output_language=output_language,
207208
message_window_size=message_window_size,
208209
**(task_creation_agent_kwargs or {}),
@@ -238,7 +239,9 @@ def step(self) -> ChatAgentResponse:
238239

239240
task_name = self.subtasks.popleft()
240241
assistant_msg_msg = BaseMessage.make_user_message(
241-
role_name=self.assistant_sys_msg.role_name, content=f"{task_name}"
242+
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
243+
or "assistant",
244+
content=f"{task_name}",
242245
)
243246

244247
assistant_response = self.assistant_agent.step(assistant_msg_msg)

camel/societies/role_playing.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def __init__(
149149

150150
self.assistant_agent: ChatAgent
151151
self.user_agent: ChatAgent
152-
self.assistant_sys_msg: BaseMessage
153-
self.user_sys_msg: BaseMessage
152+
self.assistant_sys_msg: Optional[BaseMessage]
153+
self.user_sys_msg: Optional[BaseMessage]
154154
self._init_agents(
155155
init_assistant_sys_msg,
156156
init_user_sys_msg,
@@ -454,9 +454,11 @@ def init_chat(self, init_msg_content: Optional[str] = None) -> BaseMessage:
454454
)
455455
if init_msg_content is None:
456456
init_msg_content = default_init_msg_content
457+
457458
# Initialize a message sent by the assistant
458459
init_msg = BaseMessage.make_assistant_message(
459-
role_name=self.assistant_sys_msg.role_name,
460+
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
461+
or "assistant",
460462
content=init_msg_content,
461463
)
462464

test/agents/test_chat_agent.py

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -69,55 +69,77 @@ def test_chat_agent(model):
6969
dict(assistant_role="doctor"),
7070
role_tuple=("doctor", RoleType.ASSISTANT),
7171
)
72-
assistant = ChatAgent(system_msg, model=model)
72+
assistant_with_sys_msg = ChatAgent(system_msg, model=model)
73+
assistant_without_sys_msg = ChatAgent(model=model)
7374

74-
assert str(assistant) == (
75+
assert str(assistant_with_sys_msg) == (
7576
"ChatAgent(doctor, " f"RoleType.ASSISTANT, {ModelType.GPT_4O_MINI})"
7677
)
78+
assert str(assistant_without_sys_msg) == (
79+
"ChatAgent(assistant, " f"RoleType.ASSISTANT, {ModelType.GPT_4O_MINI})"
80+
)
81+
82+
for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]:
83+
assistant.reset()
7784

78-
assistant.reset()
7985
user_msg = BaseMessage(
8086
role_name="Patient",
8187
role_type=RoleType.USER,
8288
meta_dict=dict(),
8389
content="Hello!",
8490
)
85-
assistant_response = assistant.step(user_msg)
8691

87-
assert isinstance(assistant_response.msgs, list)
88-
assert len(assistant_response.msgs) > 0
89-
assert isinstance(assistant_response.terminated, bool)
90-
assert assistant_response.terminated is False
91-
assert isinstance(assistant_response.info, dict)
92-
assert assistant_response.info['id'] is not None
92+
for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]:
93+
response = assistant.step(user_msg)
94+
assert isinstance(response.msgs, list)
95+
assert len(response.msgs) > 0
96+
assert isinstance(response.terminated, bool)
97+
assert response.terminated is False
98+
assert isinstance(response.info, dict)
99+
assert response.info['id'] is not None
93100

94101

102+
@pytest.mark.model_backend
95103
def test_chat_agent_stored_messages():
96104
system_msg = BaseMessage(
97105
role_name="assistant",
98106
role_type=RoleType.ASSISTANT,
99107
meta_dict=None,
100108
content="You are a help assistant.",
101109
)
102-
assistant = ChatAgent(system_msg)
110+
111+
assistant_with_sys_msg = ChatAgent(system_msg)
112+
assistant_without_sys_msg = ChatAgent()
103113

104114
expected_context = [system_msg.to_openai_system_message()]
105-
context, _ = assistant.memory.get_context()
106-
assert context == expected_context
115+
116+
context_with_sys_msg, _ = assistant_with_sys_msg.memory.get_context()
117+
assert context_with_sys_msg == expected_context
118+
context_without_sys_msg, _ = assistant_without_sys_msg.memory.get_context()
119+
assert context_without_sys_msg == []
107120

108121
user_msg = BaseMessage(
109122
role_name="User",
110123
role_type=RoleType.USER,
111124
meta_dict=dict(),
112125
content="Tell me a joke.",
113126
)
114-
assistant.update_memory(user_msg, OpenAIBackendRole.USER)
115-
expected_context = [
127+
128+
for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]:
129+
assistant.update_memory(user_msg, OpenAIBackendRole.USER)
130+
131+
expected_context_with_sys_msg = [
116132
system_msg.to_openai_system_message(),
117133
user_msg.to_openai_user_message(),
118134
]
119-
context, _ = assistant.memory.get_context()
120-
assert context == expected_context
135+
expected_context_without_sys_msg = [
136+
user_msg.to_openai_user_message(),
137+
]
138+
139+
context_with_sys_msg, _ = assistant_with_sys_msg.memory.get_context()
140+
assert context_with_sys_msg == expected_context_with_sys_msg
141+
context_without_sys_msg, _ = assistant_without_sys_msg.memory.get_context()
142+
assert context_without_sys_msg == expected_context_without_sys_msg
121143

122144

123145
@pytest.mark.model_backend
@@ -273,17 +295,27 @@ def test_chat_agent_multiple_return_messages(n):
273295
meta_dict=None,
274296
content="You are a helpful assistant.",
275297
)
276-
assistant = ChatAgent(system_msg, model=model)
277-
assistant.reset()
298+
assistant_with_sys_msg = ChatAgent(system_msg, model=model)
299+
assistant_without_sys_msg = ChatAgent(model=model)
300+
301+
assistant_with_sys_msg.reset()
302+
assistant_without_sys_msg.reset()
303+
278304
user_msg = BaseMessage(
279305
role_name="User",
280306
role_type=RoleType.USER,
281307
meta_dict=dict(),
282308
content="Tell me a joke.",
283309
)
284-
assistant_response = assistant.step(user_msg)
285-
assert assistant_response.msgs is not None
286-
assert len(assistant_response.msgs) == n
310+
assistant_with_sys_msg_response = assistant_with_sys_msg.step(user_msg)
311+
assistant_without_sys_msg_response = assistant_without_sys_msg.step(
312+
user_msg
313+
)
314+
315+
assert assistant_with_sys_msg_response.msgs is not None
316+
assert len(assistant_with_sys_msg_response.msgs) == n
317+
assert assistant_without_sys_msg_response.msgs is not None
318+
assert len(assistant_without_sys_msg_response.msgs) == n
287319

288320

289321
@pytest.mark.model_backend
@@ -396,21 +428,41 @@ def test_set_multiple_output_language():
396428
meta_dict=None,
397429
content="You are a help assistant.",
398430
)
399-
agent = ChatAgent(system_message=system_message)
431+
agent_with_sys_msg = ChatAgent(system_message=system_message)
432+
agent_without_sys_msg = ChatAgent()
400433

401434
# Verify that the length of the system message is kept constant even when
402435
# multiple set_output_language operations are called
403-
agent.set_output_language("Chinese")
404-
agent.set_output_language("English")
405-
agent.set_output_language("French")
406-
updated_system_message = BaseMessage(
436+
agent_with_sys_msg.set_output_language("Chinese")
437+
agent_with_sys_msg.set_output_language("English")
438+
agent_with_sys_msg.set_output_language("French")
439+
agent_without_sys_msg.set_output_language("Chinese")
440+
agent_without_sys_msg.set_output_language("English")
441+
agent_without_sys_msg.set_output_language("French")
442+
443+
updated_system_message_with_content = BaseMessage(
407444
role_name="assistant",
408445
role_type=RoleType.ASSISTANT,
409446
meta_dict=None,
410447
content="You are a help assistant."
411448
"\nRegardless of the input language, you must output text in French.",
412449
)
413-
assert agent.system_message.content == updated_system_message.content
450+
updated_system_message_without_content = BaseMessage(
451+
role_name="assistant",
452+
role_type=RoleType.ASSISTANT,
453+
meta_dict=None,
454+
content="\nRegardless of the input language, you must output text "
455+
"in French.",
456+
)
457+
458+
assert (
459+
agent_with_sys_msg.system_message.content
460+
== updated_system_message_with_content.content
461+
)
462+
assert (
463+
agent_without_sys_msg.system_message.content
464+
== updated_system_message_without_content.content
465+
)
414466

415467

416468
@pytest.mark.model_backend

0 commit comments

Comments
 (0)