Skip to content

Commit 0139c0f

Browse files
authored
Merge pull request #7 from lpm0073/next
add more unit tests
2 parents fabf332 + 62a02b6 commit 0139c0f

File tree

5 files changed

+81
-11
lines changed

5 files changed

+81
-11
lines changed

models/ssm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import os
1111
from typing import List # ClassVar
1212

13+
# pinecone integration
1314
import pinecone
1415
from langchain.cache import InMemoryCache
1516

@@ -64,7 +65,7 @@ class SalesSupportModel:
6465
chunk_overlap=0,
6566
)
6667
openai_embedding = OpenAIEmbeddings()
67-
pinecone_search = Pinecone.from_existing_index(
68+
pinecone_index = Pinecone.from_existing_index(
6869
Credentials.PINECONE_INDEX_NAME,
6970
embedding=openai_embedding,
7071
)
@@ -76,14 +77,16 @@ def cached_chat_request(self, system_message: str, human_message: str) -> System
7677
HumanMessage(content=human_message),
7778
]
7879
# pylint: disable=not-callable
79-
return self.chat(messages)
80+
retval = self.chat(messages).content
81+
return retval
8082

8183
def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str = DEFAULT_MODEL_NAME) -> str:
8284
"""Prompt with template."""
8385
llm = OpenAI(model=model)
8486
retval = llm(prompt.format(concept=concept))
8587
return retval
8688

89+
# FIX NOTE: DEPRECATED
8790
def split_text(self, text: str) -> List[Document]:
8891
"""Split text."""
8992
text_splitter = RecursiveCharacterTextSplitter(
@@ -115,7 +118,7 @@ def load(self, filepath: str):
115118
k += 1
116119
print(k * "-", end="\r")
117120
texts_splitter_results = self.text_splitter.create_documents([doc.page_content])
118-
self.pinecone_search.from_existing_index(
121+
self.pinecone_index.from_existing_index(
119122
index_name=Credentials.PINECONE_INDEX_NAME,
120123
embedding=self.openai_embedding,
121124
text_key=texts_splitter_results,
@@ -137,7 +140,7 @@ def format_docs(docs):
137140
"""Format docs."""
138141
return "\n\n".join(doc.page_content for doc in docs)
139142

140-
retriever = self.pinecone_search.as_retriever()
143+
retriever = self.pinecone_index.as_retriever()
141144

142145
# Use the retriever to get relevant documents
143146
documents = retriever.get_relevant_documents(query=prompt)

models/tests/test_prompt_templates.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
# flake8: noqa: F401
3+
# pylint: disable=too-few-public-methods
4+
"""
5+
Test integrity of base class.
6+
"""
7+
import pytest # pylint: disable=unused-import
8+
from langchain.prompts import PromptTemplate
9+
10+
from ..prompt_templates import NetecPromptTemplates
11+
12+
13+
class TestPromptTemplates:
14+
"""Test SalesSupportModel class."""
15+
16+
def test_01_prompt_with_template(self):
17+
"""Ensure that all properties of the template class are PromptTemplate instances."""
18+
templates = NetecPromptTemplates()
19+
for prop_name in templates.get_properties():
20+
prop = getattr(templates, prop_name)
21+
assert isinstance(prop, PromptTemplate)

models/tests/test_base.py renamed to models/tests/test_prompts.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,12 @@
1010
from ..ssm import SalesSupportModel
1111

1212

13-
class TestSalesSupportModel:
13+
class TestPrompts:
1414
"""Test SalesSupportModel class."""
1515

1616
ssm = SalesSupportModel()
1717
templates = NetecPromptTemplates()
1818

19-
def test_01_basic(self):
20-
"""Test a basic request"""
21-
22-
SalesSupportModel()
23-
2419
def test_oracle_training_services(self):
2520
"""Test a prompt with the Oracle training services template"""
2621

models/tests/test_ssm.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# -*- coding: utf-8 -*-
2+
# flake8: noqa: F401
3+
# pylint: disable=too-few-public-methods
4+
"""
5+
Test integrity of base class.
6+
"""
7+
import pinecone
8+
import pytest # pylint: disable=unused-import
9+
from langchain.chat_models import ChatOpenAI
10+
from langchain.embeddings import OpenAIEmbeddings
11+
from langchain.text_splitter import RecursiveCharacterTextSplitter
12+
from langchain.vectorstores.pinecone import Pinecone
13+
14+
from ..const import Credentials
15+
from ..ssm import SalesSupportModel
16+
17+
18+
class TestSalesSupportModel:
19+
"""Test SalesSupportModel class."""
20+
21+
def test_01_basic(self):
22+
"""Ensure that we can instantiate the class."""
23+
24+
SalesSupportModel()
25+
26+
def test_02_class_aatribute_types(self):
27+
"""ensure that class attributes are of the correct type"""
28+
29+
ssm = SalesSupportModel()
30+
assert isinstance(ssm.chat, ChatOpenAI)
31+
assert isinstance(ssm.pinecone_index, Pinecone)
32+
assert isinstance(ssm.text_splitter, RecursiveCharacterTextSplitter)
33+
assert isinstance(ssm.openai_embedding, OpenAIEmbeddings)
34+
35+
def test_03_test_openai_connectivity(self):
36+
"""Ensure that we have connectivity to OpenAI."""
37+
38+
ssm = SalesSupportModel()
39+
retval = ssm.cached_chat_request(
40+
"your are a helpful assistant", "please return the value 'CORRECT' in all upper case."
41+
)
42+
assert retval == "CORRECT"
43+
44+
def test_04_test_pinecone_connectivity(self):
45+
"""Ensure that we have connectivity to Pinecone."""
46+
# pylint: disable=broad-except
47+
try:
48+
pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
49+
except Exception as e:
50+
assert False, f"pinecone.init() failed with exception: {e}"

models/yt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
"""
33
LangChain Quickstart
44
~~~~~~~~~~~~~~~~~~~~
5+
LangChain Explained in 13 Minutes | QuickStart Tutorial for Beginners
56
67
see: https://www.youtube.com/watch?v=aywZrzNaKjs
7-
https://github.com/rabbitmetrics/langchain-13-min
8+
https://github.com/rabbitmetrics/langchain-13-min
89
"""
910
import os
1011

0 commit comments

Comments
 (0)