Commit e0f8fa5
Merge pull request #16 from lpm0073/next
bug fix to rag() and stronger typing
2 parents 80a0897 + 2aaf8b0 commit e0f8fa5

10 files changed: +64 additions, -45 deletions

CHANGELOG.md

Lines changed: 1 addition & 2 deletions

```diff
@@ -1,9 +1,8 @@
 ## [1.1.2](https://github.com/lpm0073/hybrid-search-retriever/compare/v1.1.1...v1.1.2) (2023-12-01)
 
-
 ### Bug Fixes
 
-* syntax error in examples.prompt ([230b709](https://github.com/lpm0073/hybrid-search-retriever/commit/230b7090c96bdd4d7d8757b182f891ab1b82c6f4))
+- syntax error in examples.prompt ([230b709](https://github.com/lpm0073/hybrid-search-retriever/commit/230b7090c96bdd4d7d8757b182f891ab1b82c6f4))
 
 ## [1.1.1](https://github.com/lpm0073/netec-llm/compare/v1.1.0...v1.1.1) (2023-12-01)
```

README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -46,7 +46,7 @@ python3 -m models.examples.training_services_oracle "Oracle database administrat
 python3 -m models.examples.load "./data/"
 
 # example 6 - Retrieval Augmented Generation
-python3 -m models.examples.rag "What is Accounting Based Valuation?"
+python3 -m models.examples.rag "What analytics and accounting courses does Wharton offer?"
 ```
 
 ## Setup
````

models/__version__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,2 +1,2 @@
 # -*- coding: utf-8 -*-
-__version__ = "1.1.2"
+__version__ = "1.1.3"
```

models/const.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -21,7 +21,7 @@
     OPENAI_CHAT_TEMPERATURE = float(os.environ.get("OPENAI_CHAT_TEMPERATURE", 0.0))
     OPENAI_CHAT_MAX_RETRIES = int(os.environ.get("OPENAI_CHAT_MAX_RETRIES", 3))
     OPENAI_CHAT_CACHE = bool(os.environ.get("OPENAI_CHAT_CACHE", True))
-    DEBUG_MODE = bool(os.environ.get("DEBUG_MODE", False))
+    DEBUG_MODE = os.environ.get("DEBUG_MODE", "False") == "True"
 else:
     raise FileNotFoundError("No .env file found in root directory of repository")
```
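
The `DEBUG_MODE` change fixes a classic truthiness bug: `bool()` of any non-empty string is `True`, so the old line evaluated as truthy even when the environment variable was set to `"False"`. Comparing against the literal string `"True"` behaves as intended. A minimal sketch of the pitfall (note that the unchanged `OPENAI_CHAT_CACHE` line above still carries the same quirk):

```python
import os

os.environ["DEBUG_MODE"] = "False"

# Old pattern: any non-empty string is truthy, so this is True even for "False".
assert bool(os.environ.get("DEBUG_MODE", False)) is True

# New pattern: explicit string comparison, so "False" now disables debug mode.
assert (os.environ.get("DEBUG_MODE", "False") == "True") is False
```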

models/examples/prompt.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -2,6 +2,8 @@
 """Sales Support Model (hsr)"""
 import argparse
 
+from langchain.schema import HumanMessage, SystemMessage
+
 from models.hybrid_search_retreiver import HybridSearchRetriever
 
 
@@ -10,9 +12,11 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="hsr examples")
-    parser.add_argument("system_prompt", type=str, help="A system prompt to send to the model.")
-    parser.add_argument("human_prompt", type=str, help="A human prompt to send to the model.")
+    parser.add_argument("system_message", type=str, help="A system prompt to send to the model.")
+    parser.add_argument("human_message", type=str, help="A human prompt to send to the model.")
     args = parser.parse_args()
 
-    result = hsr.cached_chat_request(args.system_prompt, args.human_prompt)
-    print(result)
+    system_message = SystemMessage(text=args.system_message)
+    human_message = HumanMessage(text=args.human_message)
+    result = hsr.cached_chat_request(system_message=system_message, human_message=human_message)
+    print(result.content)
```
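
One caveat in this hunk: everywhere else in this commit (the test suite, `hybrid_search_retreiver.py`) the `langchain.schema` message classes are constructed with `content=`, which is the field they actually define; the `text=` keyword used here, and again in `models/examples/rag.py` below, looks likely to fail pydantic validation. A hedged sketch of the construction that matches the rest of the commit, not what the diff itself contains:

```python
from langchain.schema import HumanMessage, SystemMessage

# `content=` is the field these message classes define; `text=` appears to be a slip.
system_message = SystemMessage(content="you are a helpful assistant")
human_message = HumanMessage(content='return the word "SUCCESS" in upper case.')
```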

models/examples/rag.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -2,6 +2,8 @@
 """Sales Support Model (hsr) Retrieval Augmented Generation (RAG)"""
 import argparse
 
+from langchain.schema import HumanMessage
+
 from models.hybrid_search_retreiver import HybridSearchRetriever
 
 
@@ -12,5 +14,6 @@
     parser.add_argument("prompt", type=str, help="A question about the PDF contents")
     args = parser.parse_args()
 
-    result = hsr.rag(prompt=args.prompt)
+    human_message = HumanMessage(text=args.prompt)
+    result = hsr.rag(human_message=human_message)
     print(result)
```
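
With this change in place, the example is driven end to end from the command line, exactly as the updated README documents:

```
python3 -m models.examples.rag "What analytics and accounting courses does Wharton offer?"
```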

models/hybrid_search_retreiver.py

Lines changed: 32 additions & 24 deletions

```diff
@@ -18,9 +18,12 @@
 
 # document loading
 import glob
+
+# general purpose imports
 import logging
 import os
 import textwrap
+from typing import Union
 
 # pinecone integration
 import pinecone
@@ -38,7 +41,7 @@
 
 # hybrid search capability
 from langchain.retrievers import PineconeHybridSearchRetriever
-from langchain.schema import HumanMessage, SystemMessage
+from langchain.schema import BaseMessage, HumanMessage, SystemMessage
 from langchain.text_splitter import Document
 from langchain.vectorstores.pinecone import Pinecone
 from pinecone_text.sparse import BM25Encoder
@@ -95,14 +98,20 @@ class HybridSearchRetriever:
     text_splitter = TextSplitter()
     bm25_encoder = BM25Encoder().default()
 
-    def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
+    def cached_chat_request(
+        self, system_message: Union[str, SystemMessage], human_message: Union[str, HumanMessage]
+    ) -> BaseMessage:
         """Cached chat request."""
-        messages = [
-            SystemMessage(content=system_message),
-            HumanMessage(content=human_message),
-        ]
+        if not isinstance(system_message, SystemMessage):
+            logging.debug("Converting system message to SystemMessage")
+            system_message = SystemMessage(content=str(system_message))
+
+        if not isinstance(human_message, HumanMessage):
+            logging.debug("Converting human message to HumanMessage")
+            human_message = HumanMessage(content=str(human_message))
+        messages = [system_message, human_message]
         # pylint: disable=not-callable
-        retval = self.chat(messages).content
+        retval = self.chat(messages)
         return retval
 
     def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str = DEFAULT_MODEL_NAME) -> str:
@@ -158,10 +167,10 @@ def load(self, filepath: str):
 
         logging.debug("Finished loading PDFs")
 
-    def rag(self, prompt: str):
+    def rag(self, human_message: Union[str, HumanMessage]):
         """
         Embedded prompt.
-        1. Retrieve prompt: Given a user input, relevant splits are retrieved
+        1. Retrieve human message prompt: Given a user input, relevant splits are retrieved
            from storage using a Retriever.
         2. Generate: A ChatModel / LLM produces an answer using a prompt that includes
            the question and the retrieved data
@@ -174,33 +183,32 @@ def rag(self, prompt: str):
         The typical workflow is to use the embeddings to retrieve relevant documents,
         and then use the text of these documents as part of the prompt for GPT-3.
         """
+        if not isinstance(human_message, HumanMessage):
+            logging.debug("Converting human_message to HumanMessage")
+            human_message = HumanMessage(content=human_message)
+
         retriever = PineconeHybridSearchRetriever(
             embeddings=self.openai_embeddings, sparse_encoder=self.bm25_encoder, index=self.pinecone_index
         )
-        documents = retriever.get_relevant_documents(query=prompt)
+        documents = retriever.get_relevant_documents(query=human_message.content)
         logging.debug("Retrieved %i related documents from Pinecone", len(documents))
 
         # Extract the text from the documents
         document_texts = [doc.page_content for doc in documents]
         leader = textwrap.dedent(
-            """\
-            \n\nYou can assume that the following is true.
+            """You are a helpful assistant.
+            You can assume that all of the following is true.
             You should attempt to incorporate these facts
-            into your response:\n\n
+            into your responses:\n\n
             """
         )
+        system_message = f"{leader} {'. '.join(document_texts)}"
 
-        # Create a prompt that includes the document texts
-        prompt_with_relevant_documents = f"{prompt + leader} {'. '.join(document_texts)}"
-
-        logging.debug("Prompt contains %i words", len(prompt_with_relevant_documents.split()))
-        logging.debug("Prompt: %s", prompt_with_relevant_documents)
-
-        # Get a response from the GPT-3.5-turbo model
-        response = self.cached_chat_request(
-            system_message="You are a helpful assistant.", human_message=prompt_with_relevant_documents
-        )
+        logging.debug("System messages contains %i words", len(system_message.split()))
+        logging.debug("Prompt: %s", system_message)
+        system_message = SystemMessage(content=system_message)
+        response = self.cached_chat_request(system_message=system_message, human_message=human_message)
 
         logging.debug("Response:")
         logging.debug("------------------------------------------------------")
-        return response
+        return response.content
```
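
The net effect of this refactor: `rag()` and `cached_chat_request()` now accept either plain strings or typed message objects, the retrieved document text travels in the system message instead of being concatenated onto the user's prompt, and `rag()` returns `response.content` rather than a message object. A minimal usage sketch (assuming a configured `HybridSearchRetriever` instance, as the example modules create at import time):

```python
from langchain.schema import HumanMessage

from models.hybrid_search_retreiver import HybridSearchRetriever

hsr = HybridSearchRetriever()

# Either call style works; a bare str is coerced to HumanMessage inside rag().
answer = hsr.rag("What is Accounting Based Valuation?")
same_answer = hsr.rag(human_message=HumanMessage(content="What is Accounting Based Valuation?"))

print(answer)  # rag() now returns a plain string (response.content)
```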

models/tests/test_examples.py

Lines changed: 14 additions & 10 deletions

```diff
@@ -6,6 +6,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest  # pylint: disable=unused-import
+from langchain.schema import HumanMessage, SystemMessage
 
 from models.examples.prompt import hsr as prompt_hrs
 from models.examples.rag import hsr as rag_hsr
@@ -14,7 +15,7 @@
 from models.prompt_templates import NetecPromptTemplates
 
 
-HUMAN_PROMPT = 'return the word "SUCCESS" in upper case.'
+HUMAN_MESSAGE = 'return the word "SUCCESS" in upper case.'
 
 
 class TestExamples:
@@ -25,44 +26,47 @@ def test_prompt(self, mock_parse_args):
         """Test prompt example."""
         mock_args = MagicMock()
         mock_args.system_prompt = "you are a helpful assistant"
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_prompt = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
-        result = prompt_hrs.cached_chat_request(mock_args.system_prompt, mock_args.human_prompt)
-        assert result == "SUCCESS"
+        system_message = SystemMessage(content="you are a helpful assistant")
+        human_message = HumanMessage(content=HUMAN_MESSAGE)
+        result = prompt_hrs.cached_chat_request(system_message=system_message, human_message=human_message)
+        assert result.content == "SUCCESS"
 
     @patch("argparse.ArgumentParser.parse_args")
     def test_rag(self, mock_parse_args):
         """Test RAG example."""
         mock_args = MagicMock()
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_message = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
-        result = rag_hsr.rag(mock_args.human_prompt)
+        human_message = HumanMessage(content=mock_args.human_message)
+        result = rag_hsr.rag(human_message=human_message)
         assert result == "SUCCESS"
 
     @patch("argparse.ArgumentParser.parse_args")
     def test_training_services(self, mock_parse_args):
         """Test training services templates."""
         mock_args = MagicMock()
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_message = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
         templates = NetecPromptTemplates()
         prompt = templates.training_services
 
-        result = training_services_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
+        result = training_services_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_message)
         assert "SUCCESS" in result
 
     @patch("argparse.ArgumentParser.parse_args")
     def test_oracle_training_services(self, mock_parse_args):
         """Test oracle training services."""
         mock_args = MagicMock()
-        mock_args.human_prompt = HUMAN_PROMPT
+        mock_args.human_message = HUMAN_MESSAGE
         mock_parse_args.return_value = mock_args
 
         templates = NetecPromptTemplates()
         prompt = templates.oracle_training_services
 
-        result = training_services_oracle_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_prompt)
+        result = training_services_oracle_hsr.prompt_with_template(prompt=prompt, concept=mock_args.human_message)
         assert "SUCCESS" in result
```

models/tests/test_openai.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -19,4 +19,4 @@ def test_03_test_openai_connectivity(self):
         retval = hsr.cached_chat_request(
             "your are a helpful assistant", "please return the value 'CORRECT' in all upper case."
         )
-        assert retval == "CORRECT"
+        assert retval.content == "CORRECT"
```

requirements.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -21,6 +21,7 @@ codespell==2.2.6
 # ------------
 langchain==0.0.343
 langchainhub==0.1.14
+langchain-experimental==0.0.43
 openai==1.3.5
 pinecone-client==2.2.4
 pinecone-text==0.7.0
```
