Skip to content

Commit 8504fbe

Browse files
committed
refactor: use SystemMessage and HumanMessage everywhere
1 parent f9d6d6d commit 8504fbe

File tree

4 files changed

+27
-16
lines changed

4 files changed

+27
-16
lines changed

models/examples/prompt.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"""Sales Support Model (hsr)"""
33
import argparse
44

5+
from langchain.schema import HumanMessage, SystemMessage
6+
57
from models.hybrid_search_retreiver import HybridSearchRetriever
68

79

@@ -10,9 +12,11 @@
1012

1113
if __name__ == "__main__":
1214
parser = argparse.ArgumentParser(description="hsr examples")
13-
parser.add_argument("system_prompt", type=str, help="A system prompt to send to the model.")
14-
parser.add_argument("human_prompt", type=str, help="A human prompt to send to the model.")
15+
parser.add_argument("system_message", type=str, help="A system prompt to send to the model.")
16+
parser.add_argument("human_message", type=str, help="A human prompt to send to the model.")
1517
args = parser.parse_args()
1618

17-
result = hsr.cached_chat_request(args.system_prompt, args.human_prompt)
18-
print(result)
19+
system_message = SystemMessage(text=args.system_message)
20+
human_message = HumanMessage(text=args.human_message)
21+
result = hsr.cached_chat_request(system_message=system_message, human_message=human_message)
22+
print(result.content)

models/examples/rag.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"""Sales Support Model (hsr) Retrieval Augmented Generation (RAG)"""
33
import argparse
44

5+
from langchain.schema import HumanMessage
6+
57
from models.hybrid_search_retreiver import HybridSearchRetriever
68

79

@@ -12,5 +14,6 @@
1214
parser.add_argument("prompt", type=str, help="A question about the PDF contents")
1315
args = parser.parse_args()
1416

15-
result = hsr.rag(prompt=args.prompt)
17+
human_message = HumanMessage(text=args.prompt)
18+
result = hsr.rag(human_message=human_message)
1619
print(result)

models/tests/test_examples.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from unittest.mock import MagicMock, patch
77

88
import pytest # pylint: disable=unused-import
9+
from langchain.schema import HumanMessage, SystemMessage
910

1011
from models.examples.prompt import hsr as prompt_hrs
1112
from models.examples.rag import hsr as rag_hsr
@@ -14,7 +15,7 @@
1415
from models.prompt_templates import NetecPromptTemplates
1516

1617

17-
HUMAN_PROMPT = 'return the word "SUCCESS" in upper case.'
18+
HUMAN_MESSAGE = 'return the word "SUCCESS" in upper case.'
1819

1920

2021
class TestExamples:
@@ -25,44 +26,47 @@ def test_prompt(self, mock_parse_args):
2526
"""Test prompt example."""
2627
mock_args = MagicMock()
2728
mock_args.system_prompt = "you are a helpful assistant"
28-
mock_args.human_prompt = HUMAN_PROMPT
29+
mock_args.human_prompt = HUMAN_MESSAGE
2930
mock_parse_args.return_value = mock_args
3031

31-
result = prompt_hrs.cached_chat_request(mock_args.system_prompt, mock_args.human_prompt)
32-
assert result == "SUCCESS"
32+
system_message = SystemMessage(content="you are a helpful assistant")
33+
human_message = HumanMessage(content=HUMAN_MESSAGE)
34+
result = prompt_hrs.cached_chat_request(system_message=system_message, human_message=human_message)
35+
assert result.content == "SUCCESS"
3336

3437
@patch("argparse.ArgumentParser.parse_args")
3538
def test_rag(self, mock_parse_args):
3639
"""Test RAG example."""
3740
mock_args = MagicMock()
38-
mock_args.human_prompt = HUMAN_PROMPT
41+
mock_args.human_message = HUMAN_MESSAGE
3942
mock_parse_args.return_value = mock_args
4043

41-
result = rag_hsr.rag(mock_args.human_prompt)
44+
human_message = HumanMessage(content=mock_args.human_message)
45+
result = rag_hsr.rag(human_message=human_message)
4246
assert result == "SUCCESS"
4347

4448
@patch("argparse.ArgumentParser.parse_args")
4549
def test_training_services(self, mock_parse_args):
4650
"""Test training services templates."""
4751
mock_args = MagicMock()
48-
mock_args.human_prompt = HUMAN_PROMPT
52+
mock_args.human_message = HUMAN_MESSAGE
4953
mock_parse_args.return_value = mock_args
5054

5155
templates = NetecPromptTemplates()
5256
prompt = templates.training_services
5357

54-
result = training_services_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
58+
result = training_services_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_message)
5559
assert "SUCCESS" in result
5660

5761
@patch("argparse.ArgumentParser.parse_args")
5862
def test_oracle_training_services(self, mock_parse_args):
5963
"""Test oracle training services."""
6064
mock_args = MagicMock()
61-
mock_args.human_prompt = HUMAN_PROMPT
65+
mock_args.human_message = HUMAN_MESSAGE
6266
mock_parse_args.return_value = mock_args
6367

6468
templates = NetecPromptTemplates()
6569
prompt = templates.oracle_training_services
6670

67-
result = training_services_oracle_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
71+
result = training_services_oracle_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_message)
6872
assert "SUCCESS" in result

models/tests/test_openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ def test_03_test_openai_connectivity(self):
1919
retval = hsr.cached_chat_request(
2020
"your are a helpful assistant", "please return the value 'CORRECT' in all upper case."
2121
)
22-
assert retval == "CORRECT"
22+
assert retval.content == "CORRECT"

0 commit comments

Comments
 (0)