6
6
from unittest .mock import MagicMock , patch
7
7
8
8
import pytest # pylint: disable=unused-import
9
+ from langchain .schema import HumanMessage , SystemMessage
9
10
10
11
from models .examples .prompt import hsr as prompt_hrs
11
12
from models .examples .rag import hsr as rag_hsr
14
15
from models .prompt_templates import NetecPromptTemplates
15
16
16
17
17
- HUMAN_PROMPT = 'return the word "SUCCESS" in upper case.'
18
+ HUMAN_MESSAGE = 'return the word "SUCCESS" in upper case.'
18
19
19
20
20
21
class TestExamples :
@@ -25,44 +26,47 @@ def test_prompt(self, mock_parse_args):
25
26
"""Test prompt example."""
26
27
mock_args = MagicMock ()
27
28
mock_args .system_prompt = "you are a helpful assistant"
28
- mock_args .human_prompt = HUMAN_PROMPT
29
+ mock_args .human_prompt = HUMAN_MESSAGE
29
30
mock_parse_args .return_value = mock_args
30
31
31
- result = prompt_hrs .cached_chat_request (mock_args .system_prompt , mock_args .human_prompt )
32
- assert result == "SUCCESS"
32
+ system_message = SystemMessage (content = "you are a helpful assistant" )
33
+ human_message = HumanMessage (content = HUMAN_MESSAGE )
34
+ result = prompt_hrs .cached_chat_request (system_message = system_message , human_message = human_message )
35
+ assert result .content == "SUCCESS"
33
36
34
37
@patch ("argparse.ArgumentParser.parse_args" )
35
38
def test_rag (self , mock_parse_args ):
36
39
"""Test RAG example."""
37
40
mock_args = MagicMock ()
38
- mock_args .human_prompt = HUMAN_PROMPT
41
+ mock_args .human_message = HUMAN_MESSAGE
39
42
mock_parse_args .return_value = mock_args
40
43
41
- result = rag_hsr .rag (mock_args .human_prompt )
44
+ human_message = HumanMessage (content = mock_args .human_message )
45
+ result = rag_hsr .rag (human_message = human_message )
42
46
assert result == "SUCCESS"
43
47
44
48
@patch ("argparse.ArgumentParser.parse_args" )
45
49
def test_training_services (self , mock_parse_args ):
46
50
"""Test training services templates."""
47
51
mock_args = MagicMock ()
48
- mock_args .human_prompt = HUMAN_PROMPT
52
+ mock_args .human_message = HUMAN_MESSAGE
49
53
mock_parse_args .return_value = mock_args
50
54
51
55
templates = NetecPromptTemplates ()
52
56
prompt = templates .training_services
53
57
54
- result = training_services_hsr .prompt_with_template (prompt = prompt , concept = mock_args .human_prompt )
58
+ result = training_services_hsr .prompt_with_template (prompt = prompt , concept = mock_args .human_message )
55
59
assert "SUCCESS" in result
56
60
57
61
@patch ("argparse.ArgumentParser.parse_args" )
58
62
def test_oracle_training_services (self , mock_parse_args ):
59
63
"""Test oracle training services."""
60
64
mock_args = MagicMock ()
61
- mock_args .human_prompt = HUMAN_PROMPT
65
+ mock_args .human_message = HUMAN_MESSAGE
62
66
mock_parse_args .return_value = mock_args
63
67
64
68
templates = NetecPromptTemplates ()
65
69
prompt = templates .oracle_training_services
66
70
67
- result = training_services_oracle_hsr .prompt_with_template (prompt = prompt , concept = mock_args .human_prompt )
71
+ result = training_services_oracle_hsr .prompt_with_template (prompt = prompt , concept = mock_args .human_message )
68
72
assert "SUCCESS" in result
0 commit comments