Skip to content

Commit 20d7fc4

Browse files
committed
chore: add unit tests for command line prompts
1 parent 22c8a9b commit 20d7fc4

File tree

12 files changed

+99
-31
lines changed

12 files changed

+99
-31
lines changed

models/const.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# pylint: disable=too-few-public-methods
3-
"""Sales Support Model (SSM) for the LangChain project."""
3+
"""Sales Support Model (hsr) for the LangChain project."""
44

55
import os
66

models/examples/load.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# -*- coding: utf-8 -*-
2-
"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG)"""
2+
"""Sales Support Model (hsr) Retrieval Augmented Generation (RAG)"""
33
import argparse
44

55
from models.hybrid_search_retreiver import HybridSearchRetriever
66

77

8-
ssm = HybridSearchRetriever()
8+
hsr = HybridSearchRetriever()
99

1010
if __name__ == "__main__":
1111
parser = argparse.ArgumentParser(description="RAG example")
1212
parser.add_argument("filepath", type=str, help="Location of PDF documents")
1313
args = parser.parse_args()
1414

15-
ssm.load(filepath=args.filepath)
15+
hsr.load(filepath=args.filepath)

models/examples/prompt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
# -*- coding: utf-8 -*-
2-
"""Sales Support Model (SSM)"""
2+
"""Sales Support Model (hsr)"""
33
import argparse
44

55
from models.hybrid_search_retreiver import HybridSearchRetriever
66

77

8-
ssm = HybridSearchRetriever()
8+
hsr = HybridSearchRetriever()
99

1010

1111
if __name__ == "__main__":
12-
parser = argparse.ArgumentParser(description="SSM examples")
12+
parser = argparse.ArgumentParser(description="hsr examples")
1313
parser.add_argument("system_prompt", type=str, help="A system prompt to send to the model.")
1414
parser.add_argument("human_prompt", type=str, help="A human prompt to send to the model.")
1515
args = parser.parse_args()
1616

17-
result = ssm.cached_chat_request(args.system_prompt, args.human_prompt)
17+
result = hsr.cached_chat_request(args.system_prompt, args.human_prompt)
1818
print(result)

models/examples/rag.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# -*- coding: utf-8 -*-
2-
"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG)"""
2+
"""Sales Support Model (hsr) Retrieval Augmented Generation (RAG)"""
33
import argparse
44

55
from models.hybrid_search_retreiver import HybridSearchRetriever
66

77

8-
ssm = HybridSearchRetriever()
8+
hsr = HybridSearchRetriever()
99

1010
if __name__ == "__main__":
1111
parser = argparse.ArgumentParser(description="RAG example")
1212
parser.add_argument("prompt", type=str, help="A question about the PDF contents")
1313
args = parser.parse_args()
1414

15-
result = ssm.rag(prompt=args.prompt)
15+
result = hsr.rag(prompt=args.prompt)
1616
print(result)

models/examples/training_services.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
# -*- coding: utf-8 -*-
2-
"""Sales Support Model (SSM) for the LangChain project."""
2+
"""Sales Support Model (hsr) for the LangChain project."""
33
import argparse
44

55
from models.hybrid_search_retreiver import HybridSearchRetriever
66
from models.prompt_templates import NetecPromptTemplates
77

88

9-
ssm = HybridSearchRetriever()
9+
hsr = HybridSearchRetriever()
1010
templates = NetecPromptTemplates()
1111

1212
if __name__ == "__main__":
13-
parser = argparse.ArgumentParser(description="SSM examples")
13+
parser = argparse.ArgumentParser(description="hsr examples")
1414
parser.add_argument("concept", type=str, help="A kind of training that Netec provides.")
1515
args = parser.parse_args()
1616

1717
prompt = templates.training_services
18-
result = ssm.prompt_with_template(prompt=prompt, concept=args.concept)
18+
result = hsr.prompt_with_template(prompt=prompt, concept=args.concept)
1919
print(result)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
# -*- coding: utf-8 -*-
2-
"""Sales Support Model (SSM) for the LangChain project."""
2+
"""Sales Support Model (hsr) for the LangChain project."""
33
import argparse
44

55
from models.hybrid_search_retreiver import HybridSearchRetriever
66
from models.prompt_templates import NetecPromptTemplates
77

88

9-
ssm = HybridSearchRetriever()
9+
hsr = HybridSearchRetriever()
1010
templates = NetecPromptTemplates()
1111

1212
if __name__ == "__main__":
13-
parser = argparse.ArgumentParser(description="SSM Oracle examples")
13+
parser = argparse.ArgumentParser(description="hsr Oracle examples")
1414
parser.add_argument("concept", type=str, help="An Oracle certification exam prep")
1515
args = parser.parse_args()
1616

1717
prompt = templates.oracle_training_services
18-
result = ssm.prompt_with_template(prompt=prompt, concept=args.concept)
18+
result = hsr.prompt_with_template(prompt=prompt, concept=args.concept)
1919
print(result)

models/prompt_templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
# pylint: disable=too-few-public-methods
3-
"""Sales Support Model (SSM) prompt templates"""
3+
"""Sales Support Model (hsr) prompt templates"""
44

55
from langchain.prompts import PromptTemplate
66

models/tests/test_examples.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# -*- coding: utf-8 -*-
2+
# flake8: noqa: F401
3+
"""
4+
Test command line example prompts.
5+
"""
6+
from unittest.mock import MagicMock, patch
7+
8+
import pytest # pylint: disable=unused-import
9+
10+
from models.examples.prompt import hsr as prompt_hrs
11+
from models.examples.rag import hsr as rag_hsr
12+
from models.examples.training_services import hsr as training_services_hsr
13+
from models.examples.training_services_oracle import hsr as training_services_oracle_hsr
14+
from models.prompt_templates import NetecPromptTemplates
15+
16+
17+
HUMAN_PROMPT = 'return the word "SUCCESS" in upper case.'
18+
19+
20+
class TestExamples:
21+
"""Test command line examples."""
22+
23+
@patch("argparse.ArgumentParser.parse_args")
24+
def test_prompt(self, mock_parse_args):
25+
"""Test prompt example."""
26+
mock_args = MagicMock()
27+
mock_args.system_prompt = "you are a helpful assistant"
28+
mock_args.human_prompt = HUMAN_PROMPT
29+
mock_parse_args.return_value = mock_args
30+
31+
result = prompt_hrs.cached_chat_request(mock_args.system_prompt, mock_args.human_prompt)
32+
assert result == "SUCCESS"
33+
34+
@patch("argparse.ArgumentParser.parse_args")
35+
def test_rag(self, mock_parse_args):
36+
"""Test RAG example."""
37+
mock_args = MagicMock()
38+
mock_args.human_prompt = HUMAN_PROMPT
39+
mock_parse_args.return_value = mock_args
40+
41+
result = rag_hsr.rag(mock_args.human_prompt)
42+
assert result == "SUCCESS"
43+
44+
@patch("argparse.ArgumentParser.parse_args")
45+
def test_training_services(self, mock_parse_args):
46+
"""Test training services templates."""
47+
mock_args = MagicMock()
48+
mock_args.human_prompt = HUMAN_PROMPT
49+
mock_parse_args.return_value = mock_args
50+
51+
templates = NetecPromptTemplates()
52+
prompt = templates.training_services
53+
54+
result = training_services_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
55+
assert "SUCCESS" in result
56+
57+
@patch("argparse.ArgumentParser.parse_args")
58+
def test_oracle_training_services(self, mock_parse_args):
59+
"""Test oracle training services."""
60+
mock_args = MagicMock()
61+
mock_args.human_prompt = HUMAN_PROMPT
62+
mock_parse_args.return_value = mock_args
63+
64+
templates = NetecPromptTemplates()
65+
prompt = templates.oracle_training_services
66+
67+
result = training_services_oracle_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
68+
assert "SUCCESS" in result

models/tests/test_hsr.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def test_01_basic(self):
2626
def test_02_class_aatribute_types(self):
2727
"""ensure that class attributes are of the correct type"""
2828

29-
ssm = HybridSearchRetriever()
30-
assert isinstance(ssm.chat, ChatOpenAI)
31-
assert isinstance(ssm.pinecone_index, Index)
32-
assert isinstance(ssm.text_splitter, TextSplitter)
33-
assert isinstance(ssm.openai_embeddings, OpenAIEmbeddings)
29+
hsr = HybridSearchRetriever()
30+
assert isinstance(hsr.chat, ChatOpenAI)
31+
assert isinstance(hsr.pinecone_index, Index)
32+
assert isinstance(hsr.text_splitter, TextSplitter)
33+
assert isinstance(hsr.openai_embeddings, OpenAIEmbeddings)

models/tests/test_openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ class TestOpenAI:
1515
def test_03_test_openai_connectivity(self):
1616
"""Ensure that we have connectivity to OpenAI."""
1717

18-
ssm = HybridSearchRetriever()
19-
retval = ssm.cached_chat_request(
18+
hsr = HybridSearchRetriever()
19+
retval = hsr.cached_chat_request(
2020
"your are a helpful assistant", "please return the value 'CORRECT' in all upper case."
2121
)
2222
assert retval == "CORRECT"

0 commit comments

Comments
 (0)