@@ -69,55 +69,77 @@ def test_chat_agent(model):
6969 dict (assistant_role = "doctor" ),
7070 role_tuple = ("doctor" , RoleType .ASSISTANT ),
7171 )
72- assistant = ChatAgent (system_msg , model = model )
72+ assistant_with_sys_msg = ChatAgent (system_msg , model = model )
73+ assistant_without_sys_msg = ChatAgent (model = model )
7374
74- assert str (assistant ) == (
75+ assert str (assistant_with_sys_msg ) == (
7576 "ChatAgent(doctor, " f"RoleType.ASSISTANT, { ModelType .GPT_4O_MINI } )"
7677 )
78+ assert str (assistant_without_sys_msg ) == (
79+ "ChatAgent(assistant, " f"RoleType.ASSISTANT, { ModelType .GPT_4O_MINI } )"
80+ )
81+
82+ for assistant in [assistant_with_sys_msg , assistant_without_sys_msg ]:
83+ assistant .reset ()
7784
78- assistant .reset ()
7985 user_msg = BaseMessage (
8086 role_name = "Patient" ,
8187 role_type = RoleType .USER ,
8288 meta_dict = dict (),
8389 content = "Hello!" ,
8490 )
85- assistant_response = assistant .step (user_msg )
8691
87- assert isinstance (assistant_response .msgs , list )
88- assert len (assistant_response .msgs ) > 0
89- assert isinstance (assistant_response .terminated , bool )
90- assert assistant_response .terminated is False
91- assert isinstance (assistant_response .info , dict )
92- assert assistant_response .info ['id' ] is not None
92+ for assistant in [assistant_with_sys_msg , assistant_without_sys_msg ]:
93+ response = assistant .step (user_msg )
94+ assert isinstance (response .msgs , list )
95+ assert len (response .msgs ) > 0
96+ assert isinstance (response .terminated , bool )
97+ assert response .terminated is False
98+ assert isinstance (response .info , dict )
99+ assert response .info ['id' ] is not None
93100
94101
102+ @pytest .mark .model_backend
95103def test_chat_agent_stored_messages ():
96104 system_msg = BaseMessage (
97105 role_name = "assistant" ,
98106 role_type = RoleType .ASSISTANT ,
99107 meta_dict = None ,
100108 content = "You are a help assistant." ,
101109 )
102- assistant = ChatAgent (system_msg )
110+
111+ assistant_with_sys_msg = ChatAgent (system_msg )
112+ assistant_without_sys_msg = ChatAgent ()
103113
104114 expected_context = [system_msg .to_openai_system_message ()]
105- context , _ = assistant .memory .get_context ()
106- assert context == expected_context
115+
116+ context_with_sys_msg , _ = assistant_with_sys_msg .memory .get_context ()
117+ assert context_with_sys_msg == expected_context
118+ context_without_sys_msg , _ = assistant_without_sys_msg .memory .get_context ()
119+ assert context_without_sys_msg == []
107120
108121 user_msg = BaseMessage (
109122 role_name = "User" ,
110123 role_type = RoleType .USER ,
111124 meta_dict = dict (),
112125 content = "Tell me a joke." ,
113126 )
114- assistant .update_memory (user_msg , OpenAIBackendRole .USER )
115- expected_context = [
127+
128+ for assistant in [assistant_with_sys_msg , assistant_without_sys_msg ]:
129+ assistant .update_memory (user_msg , OpenAIBackendRole .USER )
130+
131+ expected_context_with_sys_msg = [
116132 system_msg .to_openai_system_message (),
117133 user_msg .to_openai_user_message (),
118134 ]
119- context , _ = assistant .memory .get_context ()
120- assert context == expected_context
135+ expected_context_without_sys_msg = [
136+ user_msg .to_openai_user_message (),
137+ ]
138+
139+ context_with_sys_msg , _ = assistant_with_sys_msg .memory .get_context ()
140+ assert context_with_sys_msg == expected_context_with_sys_msg
141+ context_without_sys_msg , _ = assistant_without_sys_msg .memory .get_context ()
142+ assert context_without_sys_msg == expected_context_without_sys_msg
121143
122144
123145@pytest .mark .model_backend
@@ -273,17 +295,27 @@ def test_chat_agent_multiple_return_messages(n):
273295 meta_dict = None ,
274296 content = "You are a helpful assistant." ,
275297 )
276- assistant = ChatAgent (system_msg , model = model )
277- assistant .reset ()
298+ assistant_with_sys_msg = ChatAgent (system_msg , model = model )
299+ assistant_without_sys_msg = ChatAgent (model = model )
300+
301+ assistant_with_sys_msg .reset ()
302+ assistant_without_sys_msg .reset ()
303+
278304 user_msg = BaseMessage (
279305 role_name = "User" ,
280306 role_type = RoleType .USER ,
281307 meta_dict = dict (),
282308 content = "Tell me a joke." ,
283309 )
284- assistant_response = assistant .step (user_msg )
285- assert assistant_response .msgs is not None
286- assert len (assistant_response .msgs ) == n
310+ assistant_with_sys_msg_response = assistant_with_sys_msg .step (user_msg )
311+ assistant_without_sys_msg_response = assistant_without_sys_msg .step (
312+ user_msg
313+ )
314+
315+ assert assistant_with_sys_msg_response .msgs is not None
316+ assert len (assistant_with_sys_msg_response .msgs ) == n
317+ assert assistant_without_sys_msg_response .msgs is not None
318+ assert len (assistant_without_sys_msg_response .msgs ) == n
287319
288320
289321@pytest .mark .model_backend
@@ -396,21 +428,41 @@ def test_set_multiple_output_language():
396428 meta_dict = None ,
397429 content = "You are a help assistant." ,
398430 )
399- agent = ChatAgent (system_message = system_message )
431+ agent_with_sys_msg = ChatAgent (system_message = system_message )
432+ agent_without_sys_msg = ChatAgent ()
400433
401434 # Verify that the length of the system message is kept constant even when
402435 # multiple set_output_language operations are called
403- agent .set_output_language ("Chinese" )
404- agent .set_output_language ("English" )
405- agent .set_output_language ("French" )
406- updated_system_message = BaseMessage (
436+ agent_with_sys_msg .set_output_language ("Chinese" )
437+ agent_with_sys_msg .set_output_language ("English" )
438+ agent_with_sys_msg .set_output_language ("French" )
439+ agent_without_sys_msg .set_output_language ("Chinese" )
440+ agent_without_sys_msg .set_output_language ("English" )
441+ agent_without_sys_msg .set_output_language ("French" )
442+
443+ updated_system_message_with_content = BaseMessage (
407444 role_name = "assistant" ,
408445 role_type = RoleType .ASSISTANT ,
409446 meta_dict = None ,
410447 content = "You are a help assistant."
411448 "\n Regardless of the input language, you must output text in French." ,
412449 )
413- assert agent .system_message .content == updated_system_message .content
450+ updated_system_message_without_content = BaseMessage (
451+ role_name = "assistant" ,
452+ role_type = RoleType .ASSISTANT ,
453+ meta_dict = None ,
454+ content = "\n Regardless of the input language, you must output text "
455+ "in French." ,
456+ )
457+
458+ assert (
459+ agent_with_sys_msg .system_message .content
460+ == updated_system_message_with_content .content
461+ )
462+ assert (
463+ agent_without_sys_msg .system_message .content
464+ == updated_system_message_without_content .content
465+ )
414466
415467
416468@pytest .mark .model_backend
0 commit comments