Commit 94479d9

Merge pull request #6 from lpm0073/next
add a dedicated rag() method
2 parents a3b1c17 + 62cd18f commit 94479d9

7 files changed: +161 −50

README.md

Lines changed: 13 additions & 1 deletion
@@ -1,6 +1,12 @@
 # Netec Large Language Models

-A Python [LangChain](https://www.langchain.com/) - [Pinecone](https://docs.pinecone.io/docs/python-client) proof of concept LLM to manage sales support inquiries on the Netec course catalogue.
+A Python [LangChain](https://www.langchain.com/) - [Pinecone](https://docs.pinecone.io/docs/python-client) proof-of-concept Retrieval Augmented Generation (RAG) model using sales support PDF documents.
+
+See:
+
+- [LangChain RAG](https://python.langchain.com/docs/use_cases/question_answering/)
+- [LangChain Document Loaders](https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf)
+- [LangChain Caching](https://python.langchain.com/docs/modules/model_io/llms/llm_caching)

 ## Installation

@@ -28,6 +34,12 @@ python3 -m models.examples.training_services "Microsoft certified Azure AI engin

 # example 4 - prompted assistant
 python3 -m models.examples.training_services_oracle "Oracle database administrator"
+
+# example 5 - Load PDF documents
+python3 -m models.examples.load "./data/"
+
+# example 6 - Retrieval Augmented Generation
+python3 -m models.examples.rag "What is Accounting Based Valuation?"
 ```

 ## Requirements

models/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 # -*- coding: utf-8 -*-
-__version__ = "1.0.0"
+__version__ = "1.1.0"

models/examples/load.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG)"""
+import argparse
+
+from ..ssm import SalesSupportModel
+
+
+ssm = SalesSupportModel()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="RAG example")
+    parser.add_argument("filepath", type=str, help="Location of PDF documents")
+    args = parser.parse_args()
+
+    ssm.load(filepath=args.filepath)

models/examples/rag.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG)"""
+import argparse
+
+from ..ssm import SalesSupportModel
+
+
+ssm = SalesSupportModel()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="RAG example")
+    parser.add_argument("prompt", type=str, help="A question about the PDF contents")
+    args = parser.parse_args()
+
+    result = ssm.rag(prompt=args.prompt)
+    print(result)
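
Note that the two example modules are meant to run in order: `rag()` retrieves from the Pinecone index that `load()` populates. A minimal programmatic equivalent of README examples 5 and 6 (assuming a `./data/` directory of PDFs and valid OpenAI and Pinecone credentials in the environment) might look like:

```python
from models.ssm import SalesSupportModel

ssm = SalesSupportModel()

# index the PDF corpus first; rag() retrieves from the index this populates
ssm.load(filepath="./data/")

# then ask a question grounded in the indexed documents
result = ssm.rag(prompt="What is Accounting Based Valuation?")
print(result)
```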

models/prompt_templates.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def oracle_training_services(self) -> PromptTemplate:
         template = (
             self.sales_role
             + """
-            Note that Netec is the exclusive provide of Oracle training services
+            Note that Netec is the exclusive provider in Latin America of Oracle training services
             for the 6 levels of Oracle Certification credentials: Oracle Certified Junior Associate (OCJA),
             Oracle Certified Associate (OCA), Oracle Certified Professional (OCP),
             Oracle Certified Master (OCM), Oracle Certified Expert (OCE) and

models/ssm.py

Lines changed: 109 additions & 45 deletions
@@ -1,55 +1,73 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=too-few-public-methods
-"""Sales Support Model (SSM) for the LangChain project."""
+"""
+Sales Support Model (SSM) for the LangChain project.
+See: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
+     https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf
+"""

-from typing import ClassVar, List
+import glob
+import os
+from typing import List  # ClassVar

 import pinecone
+from langchain.cache import InMemoryCache
+
+# prompting and chat
 from langchain.chat_models import ChatOpenAI
+
+# document loading
+from langchain.document_loaders import PyPDFLoader
+
+# embedding
 from langchain.embeddings import OpenAIEmbeddings
+
+# vector database
+from langchain.globals import set_llm_cache
 from langchain.llms.openai import OpenAI
 from langchain.prompts import PromptTemplate
-from langchain.schema import HumanMessage, SystemMessage  # AIMessage (not used)
+from langchain.schema import HumanMessage, SystemMessage
 from langchain.text_splitter import Document, RecursiveCharacterTextSplitter
 from langchain.vectorstores.pinecone import Pinecone
-from pydantic import BaseModel, ConfigDict, Field  # ValidationError

+# this project
 from models.const import Credentials


+# from pydantic import BaseModel, ConfigDict, Field
+
+
+###############################################################################
+# initializations
+###############################################################################
 DEFAULT_MODEL_NAME = "text-davinci-003"
 pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
+set_llm_cache(InMemoryCache())


-class SalesSupportModel(BaseModel):
+class SalesSupportModel:
     """Sales Support Model (SSM)."""

-    Config: ClassVar = ConfigDict(arbitrary_types_allowed=True)
-
     # prompting wrapper
-    chat: ChatOpenAI = Field(
-        default_factory=lambda: ChatOpenAI(
-            api_key=Credentials.OPENAI_API_KEY,
-            organization=Credentials.OPENAI_API_ORGANIZATION,
-            max_retries=3,
-            model="gpt-3.5-turbo",
-            temperature=0.3,
-        )
+    chat = ChatOpenAI(
+        api_key=Credentials.OPENAI_API_KEY,
+        organization=Credentials.OPENAI_API_ORGANIZATION,
+        cache=True,
+        max_retries=3,
+        model="gpt-3.5-turbo",
+        temperature=0.0,
     )

     # embeddings
-    text_splitter: RecursiveCharacterTextSplitter = Field(
-        default_factory=lambda: RecursiveCharacterTextSplitter(
-            chunk_size=100,
-            chunk_overlap=0,
-        )
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=100,
+        chunk_overlap=0,
+    )
+    openai_embedding = OpenAIEmbeddings()
+    pinecone_search = Pinecone.from_existing_index(
+        Credentials.PINECONE_INDEX_NAME,
+        embedding=openai_embedding,
     )
-
-    texts_splitter_results: List[Document] = Field(None, description="Text splitter results")
-    pinecone_search: Pinecone = Field(None, description="Pinecone search")
-    pinecone_index_name: str = Field(default="netec-ssm", description="Pinecone index name")
-    openai_embedding: OpenAIEmbeddings = Field(default_factory=lambda: OpenAIEmbeddings(model="ada"))
-    query_result: List[float] = Field(None, description="Vector database query result")

     def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
         """Cached chat request."""

@@ -68,24 +86,70 @@ def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str

     def split_text(self, text: str) -> List[Document]:
         """Split text."""
-        # pylint: disable=no-member
-        retval = self.text_splitter.create_documents([text])
-        return retval
-
-    def embed(self, text: str) -> List[float]:
-        """Embed."""
-        texts_splitter_results = self.split_text(text)
-        embedding = texts_splitter_results[0].page_content
-        # pylint: disable=no-member
-        self.openai_embedding.embed_query(embedding)
-
-        self.pinecone_search = Pinecone.from_documents(
-            texts_splitter_results,
-            embedding=self.openai_embedding,
-            index_name=self.pinecone_index_name,
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=100,
+            chunk_overlap=0,
         )
+        retval = text_splitter.create_documents([text])
+        return retval

-    def embedded_prompt(self, prompt: str) -> List[Document]:
-        """Embedded prompt."""
-        result = self.pinecone_search.similarity_search(prompt)
-        return result
+    def load(self, filepath: str):
+        """
+        Embed PDF.
+        1. Load PDF document text data
+        2. Split into pages
+        3. Embed each page
+        4. Store in Pinecone
+        """
+
+        pdf_files = glob.glob(os.path.join(filepath, "*.pdf"))
+        i = 0
+        for pdf_file in pdf_files:
+            i += 1
+            j = len(pdf_files)
+            print(f"Loading PDF {i} of {j}: ", pdf_file)
+            loader = PyPDFLoader(file_path=pdf_file)
+            docs = loader.load()
+            k = 0
+            for doc in docs:
+                k += 1
+                print(k * "-", end="\r")
+                texts_splitter_results = self.text_splitter.create_documents([doc.page_content])
+                self.pinecone_search.from_existing_index(
+                    index_name=Credentials.PINECONE_INDEX_NAME,
+                    embedding=self.openai_embedding,
+                    text_key=texts_splitter_results,
+                )
+
+        print("Finished loading PDFs")
+
+    def rag(self, prompt: str):
+        """
+        Embedded prompt.
+        1. Retrieve prompt: Given a user input, relevant splits are retrieved
+           from storage using a Retriever.
+        2. Generate: A ChatModel / LLM produces an answer using a prompt that includes
+           the question and the retrieved data
+        """
+
+        # pylint: disable=unused-variable
+        def format_docs(docs):
+            """Format docs."""
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        retriever = self.pinecone_search.as_retriever()
+
+        # Use the retriever to get relevant documents
+        documents = retriever.get_relevant_documents(query=prompt)
+        print(f"Retrieved {len(documents)} related documents from Pinecone")
+
+        # Generate a prompt from the retrieved documents
+        prompt += " ".join(doc.page_content for doc in documents)
+        print(f"Prompt contains {len(prompt.split())} words")
+        print("Prompt:", prompt)
+        print(doc for doc in documents)
+
+        # Get a response from the GPT-3.5-turbo model
+        response = self.cached_chat_request(system_message="You are a helpful assistant.", human_message=prompt)
+
+        return response
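
A review note on `load()`: `Pinecone.from_existing_index()` only constructs a vector-store handle over an already-populated index; it does not upsert the freshly split documents, and its `text_key` parameter expects a metadata field name (a string), not a list of `Document` objects. A sketch of what this step presumably intends, reusing `Pinecone.from_documents()` as the removed `embed()` method did (an assumption about intent, not what the commit ships):

```python
def load(self, filepath: str):
    """Load PDFs, split, embed, and upsert into Pinecone (hypothetical rework)."""
    pdf_files = glob.glob(os.path.join(filepath, "*.pdf"))
    for pdf_file in pdf_files:
        loader = PyPDFLoader(file_path=pdf_file)
        for doc in loader.load():
            splits = self.text_splitter.create_documents([doc.page_content])
            # from_documents() embeds the splits and upserts them into the index
            Pinecone.from_documents(
                documents=splits,
                embedding=self.openai_embedding,
                index_name=Credentials.PINECONE_INDEX_NAME,
            )
```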
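Likewise in `rag()`, the helper `format_docs()` is defined but never called, the retrieved page contents are appended directly onto the user's question with no separator, and `print(doc for doc in documents)` prints a generator object rather than the documents. A sketch of the generation step that keeps question and context distinct (again an assumption about intent, not the committed behavior):

```python
def rag(self, prompt: str):
    """Retrieve relevant splits, then generate an answer (hypothetical rework)."""
    retriever = self.pinecone_search.as_retriever()
    documents = retriever.get_relevant_documents(query=prompt)
    # join the retrieved page contents, i.e. what format_docs() was written for
    context = "\n\n".join(doc.page_content for doc in documents)
    human_message = f"Context:\n{context}\n\nQuestion: {prompt}"
    return self.cached_chat_request(
        system_message="You are a helpful assistant.",
        human_message=human_message,
    )
```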

requirements.txt

Lines changed: 6 additions & 2 deletions
@@ -19,7 +19,11 @@ codespell==2.2.6

 # production
 # ------------
-python-dotenv==1.0.0
-pydantic==2.5.2
 langchain==0.0.343
+langchainhub==0.1.14
+openai==1.3.5
 pinecone-client==2.2.4
+pydantic==2.5.2
+pypdf==3.17.1
+python-dotenv==1.0.0
+tiktoken==0.5.1
