Skip to content

Commit 8de793d

Browse files
committed
feat: perfect load(). revert rag() to openai only calls
1 parent 781cf06 commit 8de793d

File tree

5 files changed

+80
-54
lines changed

5 files changed

+80
-54
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ python3 -m models.examples.training_services "Microsoft certified Azure AI engin
3535
# example 4 - prompted assistant
3636
python3 -m models.examples.training_services_oracle "Oracle database administrator"
3737

38-
# example 5 - RAG
39-
python3 -m models.examples.rag "./data/" "What is Accounting Based Valuation?"
38+
# example 5 - Load PDF documents
39+
python3 -m models.examples.load "./data/"
40+
41+
# example 6 - Retrieval Augmented Generation
42+
python3 -m models.examples.rag "What is Accounting Based Valuation?"
4043
```
4144

4245
## Requirements

models/examples/load.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
"""Sales Support Model (SSM) example: load PDF documents into the vector store.

Command-line entry point:
    python3 -m models.examples.load "./data/"
"""
import argparse

from ..ssm import SalesSupportModel

# Module-level instance, matching the other example scripts in this package.
ssm = SalesSupportModel()

if __name__ == "__main__":
    # NOTE: the original docstring/description said "RAG example", copy-pasted
    # from rag.py — this script only loads and embeds PDF documents.
    parser = argparse.ArgumentParser(description="Load PDF documents example")
    parser.add_argument("filepath", type=str, help="Location of PDF documents")
    args = parser.parse_args()

    ssm.load(filepath=args.filepath)

models/examples/rag.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99

1010
if __name__ == "__main__":
    # Parse the single positional argument: the user's question about the PDFs.
    arg_parser = argparse.ArgumentParser(description="RAG example")
    arg_parser.add_argument("prompt", type=str, help="A question about the PDF contents")
    parsed = arg_parser.parse_args()

    # Run retrieval-augmented generation and show the model's answer.
    print(ssm.rag(prompt=parsed.prompt))

models/ssm.py

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88

99
import glob
1010
import os
11-
from typing import ClassVar, List
11+
from typing import List # ClassVar
1212

1313
import pinecone
14-
from langchain import hub
1514
from langchain.cache import InMemoryCache
1615

1716
# prompting and chat
@@ -27,16 +26,17 @@
2726
from langchain.globals import set_llm_cache
2827
from langchain.llms.openai import OpenAI
2928
from langchain.prompts import PromptTemplate
30-
from langchain.schema import HumanMessage, StrOutputParser, SystemMessage
31-
from langchain.schema.runnable import RunnablePassthrough
29+
from langchain.schema import HumanMessage, SystemMessage
3230
from langchain.text_splitter import Document, RecursiveCharacterTextSplitter
3331
from langchain.vectorstores.pinecone import Pinecone
34-
from pydantic import BaseModel, ConfigDict, Field # ValidationError
3532

3633
# this project
3734
from models.const import Credentials
3835

3936

37+
# from pydantic import BaseModel, ConfigDict, Field
38+
39+
4040
###############################################################################
4141
# initializations
4242
###############################################################################
@@ -45,28 +45,23 @@
4545
set_llm_cache(InMemoryCache())
4646

4747

48-
class SalesSupportModel(BaseModel):
48+
class SalesSupportModel:
4949
"""Sales Support Model (SSM)."""
5050

51-
Config: ClassVar = ConfigDict(arbitrary_types_allowed=True)
52-
5351
# prompting wrapper
54-
chat: ChatOpenAI = Field(
55-
default_factory=lambda: ChatOpenAI(
56-
api_key=Credentials.OPENAI_API_KEY,
57-
organization=Credentials.OPENAI_API_ORGANIZATION,
58-
cache=True,
59-
max_retries=3,
60-
model="gpt-3.5-turbo",
61-
temperature=0.0,
62-
)
52+
chat = ChatOpenAI(
53+
api_key=Credentials.OPENAI_API_KEY,
54+
organization=Credentials.OPENAI_API_ORGANIZATION,
55+
cache=True,
56+
max_retries=3,
57+
model="gpt-3.5-turbo",
58+
temperature=0.0,
6359
)
6460

6561
# embeddings
66-
texts_splitter_results: List[Document] = Field(None, description="Text splitter results")
67-
pinecone_search: Pinecone = Field(None, description="Pinecone search")
68-
openai_embedding: OpenAIEmbeddings = Field(OpenAIEmbeddings())
69-
query_result: List[float] = Field(None, description="Vector database query result")
62+
texts_splitter_results: List[Document]
63+
openai_embedding = OpenAIEmbeddings()
64+
query_result: List[float]
7065

7166
def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
7267
"""Cached chat request."""
@@ -103,13 +98,13 @@ def embed(self, text: str) -> List[float]:
10398
# pylint: disable=no-member
10499
self.openai_embedding.embed_query(embedding)
105100

106-
self.pinecone_search = Pinecone.from_documents(
107-
texts_splitter_results,
101+
Pinecone.from_documents(
102+
documents=texts_splitter_results,
108103
embedding=self.openai_embedding,
109104
index_name=Credentials.PINECONE_INDEX_NAME,
110105
)
111106

112-
def rag(self, filepath: str, prompt: str):
107+
def load(self, filepath: str):
    """
    Load PDF documents from *filepath* into the Pinecone vector store.

    1. Load PDF document text data with PyPDFLoader
    2. Split and embed each page's text via self.embed()
    3. Store the resulting vectors in Pinecone

    :param filepath: directory containing the ``*.pdf`` files to ingest.
    """
    pdf_files = glob.glob(os.path.join(filepath, "*.pdf"))
    total = len(pdf_files)  # hoisted: loop-invariant, was recomputed every iteration
    for file_number, pdf_file in enumerate(pdf_files, start=1):
        print(f"Loading PDF {file_number} of {total}: ")
        loader = PyPDFLoader(file_path=pdf_file)
        docs = loader.load()
        for page_number, doc in enumerate(docs, start=1):
            # Simple same-line progress bar: one dash per page processed.
            print(page_number * "-", end="\r")
            self.embed(doc.page_content)
    print("Finished loading PDFs")
131130

132-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
133-
splits = text_splitter.split_documents(docs)
134-
vectorstore = Pinecone.from_documents(documents=splits, embedding=self.openai_embedding)
135-
retriever = vectorstore.as_retriever()
136-
prompt = hub.pull("rlm/rag-prompt")
137-
138-
rag_chain = (
139-
{"context": retriever | self.format_docs, "question": RunnablePassthrough()}
140-
| prompt
141-
| self.chat
142-
| StrOutputParser()
143-
)
144-
145-
return rag_chain.invoke(prompt)
146-
147-
def embedded_prompt(self, prompt: str) -> List[Document]:
131+
def rag(self, prompt: str):
    """
    Retrieval Augmented Generation.

    1. Retrieve: given the user *prompt*, fetch relevant splits from the
       existing Pinecone index (populated by load()) using a Retriever.
    2. Generate: the chat model produces an answer from a prompt that
       includes the question plus the retrieved document text.

    :param prompt: the user's question about the PDF contents.
    :return: the chat model's response message.
    """
    # Connect to the index that load() populated; do not rebuild it here.
    pinecone_search = Pinecone.from_existing_index(
        Credentials.PINECONE_INDEX_NAME,
        embedding=self.openai_embedding,
    )
    retriever = pinecone_search.as_retriever()

    # Use the retriever to get relevant documents
    documents = retriever.get_relevant_documents(query=prompt)
    print(f"Retrieved {len(documents)} related documents from Pinecone")

    # Generate a prompt from the retrieved documents
    prompt += " ".join(doc.page_content for doc in documents)
    print(f"Prompt contains {len(prompt.split())} words")
    print("Prompt:", prompt)
    # Fixed: the original `print(doc for doc in documents)` printed a generator
    # object's repr, not the documents; the unused inner format_docs() helper
    # (dead code behind a pylint disable) is removed.

    # Get a response from the GPT-3.5-turbo model
    response = self.cached_chat_request(system_message="You are a helpful assistant.", human_message=prompt)

    return response

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ codespell==2.2.6
1919

2020
# production
2121
# ------------
22-
python-dotenv==1.0.0
23-
pydantic==2.5.2
2422
langchain==0.0.343
23+
langchainhub==0.1.14
2524
openai==1.3.5
2625
pinecone-client==2.2.4
26+
pydantic==2.5.2
2727
pypdf==3.17.1
28+
python-dotenv==1.0.0
2829
tiktoken==0.5.1

0 commit comments

Comments
 (0)