Commit 2335d22

feat: ssm.rag() w load, split, embed, store
1 parent a3b1c17 commit 2335d22

File tree

5 files changed: +114 -20 lines changed

- README.md
- models/examples/rag.py
- models/prompt_templates.py
- models/ssm.py
- requirements.txt

README.md

Lines changed: 10 additions & 1 deletion

@@ -1,6 +1,12 @@
 # Netec Large Language Models

-A Python [LangChain](https://www.langchain.com/) - [Pinecone](https://docs.pinecone.io/docs/python-client) proof of concept LLM to manage sales support inquiries on the Netec course catalogue.
+A Python [LangChain](https://www.langchain.com/) - [Pinecone](https://docs.pinecone.io/docs/python-client) proof-of-concept Retrieval Augmented Generation (RAG) model using sales support PDF documents.
+
+See:
+
+- [LangChain RAG](https://python.langchain.com/docs/use_cases/question_answering/)
+- [LangChain Document Loaders](https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf)
+- [LangChain Caching](https://python.langchain.com/docs/modules/model_io/llms/llm_caching)

 ## Installation

@@ -28,6 +34,9 @@ python3 -m models.examples.training_services "Microsoft certified Azure AI engin

 # example 4 - prompted assistant
 python3 -m models.examples.training_services_oracle "Oracle database administrator"
+
+# example 5 - RAG
+python3 -m models.examples.rag "./data/" "What is Accounting Based Valuation?"
 ```

 ## Requirements
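
The CLI invocation in example 5 maps directly onto the new `SalesSupportModel.rag()` method. As a rough programmatic sketch (illustrative only, not part of this commit; assumes OpenAI and Pinecone credentials are configured as in `models/const.py`):

```python
# Hypothetical programmatic equivalent of README example 5.
# Assumes valid OpenAI/Pinecone credentials loaded via models/const.py.
from models.ssm import SalesSupportModel

ssm = SalesSupportModel()
answer = ssm.rag(filepath="./data/", prompt="What is Accounting Based Valuation?")
print(answer)
```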

models/examples/rag.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# -*- coding: utf-8 -*-
+"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG)"""
+import argparse
+
+from ..ssm import SalesSupportModel
+
+
+ssm = SalesSupportModel()
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="RAG example")
+    parser.add_argument("filepath", type=str, help="Location of PDF documents")
+    parser.add_argument("prompt", type=str, help="A question about the PDF contents")
+    args = parser.parse_args()
+
+    result = ssm.rag(filepath=args.filepath, prompt=args.prompt)
+    print(result)
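
Note that the relative import (`from ..ssm import SalesSupportModel`) means this script must be run as a module from the repository root (`python3 -m models.examples.rag`, as in README example 5); executing the file directly would raise an ImportError.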

models/prompt_templates.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def oracle_training_services(self) -> PromptTemplate:
         template = (
             self.sales_role
             + """
-            Note that Netec is the exclusive provide of Oracle training services
+            Note that Netec is the exclusive provider in Latin America of Oracle training services
             for the 6 levels of Oracle Certification credentials: Oracle Certified Junior Associate (OCJA),
             Oracle Certified Associate (OCA), Oracle Certified Professional (OCP),
             Oracle Certified Master (OCM), Oracle Certified Expert (OCE) and

models/ssm.py

Lines changed: 83 additions & 18 deletions

@@ -1,24 +1,48 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=too-few-public-methods
-"""Sales Support Model (SSM) for the LangChain project."""
-
+"""
+Sales Support Model (SSM) for the LangChain project.
+See: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
+https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf
+"""
+
+import glob
+import os
 from typing import ClassVar, List

 import pinecone
+from langchain import hub
+from langchain.cache import InMemoryCache
+
+# prompting and chat
 from langchain.chat_models import ChatOpenAI
+
+# document loading
+from langchain.document_loaders import PyPDFLoader
+
+# embedding
 from langchain.embeddings import OpenAIEmbeddings
+
+# vector database
+from langchain.globals import set_llm_cache
 from langchain.llms.openai import OpenAI
 from langchain.prompts import PromptTemplate
-from langchain.schema import HumanMessage, SystemMessage  # AIMessage (not used)
+from langchain.schema import HumanMessage, StrOutputParser, SystemMessage
+from langchain.schema.runnable import RunnablePassthrough
 from langchain.text_splitter import Document, RecursiveCharacterTextSplitter
 from langchain.vectorstores.pinecone import Pinecone
 from pydantic import BaseModel, ConfigDict, Field  # ValidationError

+# this project
 from models.const import Credentials


+###############################################################################
+# initializations
+###############################################################################
 DEFAULT_MODEL_NAME = "text-davinci-003"
 pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
+set_llm_cache(InMemoryCache())


 class SalesSupportModel(BaseModel):
@@ -31,24 +55,17 @@ class SalesSupportModel(BaseModel):
         default_factory=lambda: ChatOpenAI(
             api_key=Credentials.OPENAI_API_KEY,
             organization=Credentials.OPENAI_API_ORGANIZATION,
+            cache=True,
             max_retries=3,
             model="gpt-3.5-turbo",
-            temperature=0.3,
+            temperature=0.0,
         )
     )

     # embeddings
-    text_splitter: RecursiveCharacterTextSplitter = Field(
-        default_factory=lambda: RecursiveCharacterTextSplitter(
-            chunk_size=100,
-            chunk_overlap=0,
-        )
-    )
-
     texts_splitter_results: List[Document] = Field(None, description="Text splitter results")
     pinecone_search: Pinecone = Field(None, description="Pinecone search")
-    pinecone_index_name: str = Field(default="netec-ssm", description="Pinecone index name")
-    openai_embedding: OpenAIEmbeddings = Field(default_factory=lambda: OpenAIEmbeddings(model="ada"))
+    openai_embedding: OpenAIEmbeddings = Field(OpenAIEmbeddings())
     query_result: List[float] = Field(None, description="Vector database query result")

     def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
@@ -68,24 +85,72 @@ def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str

     def split_text(self, text: str) -> List[Document]:
         """Split text."""
-        # pylint: disable=no-member
-        retval = self.text_splitter.create_documents([text])
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=100,
+            chunk_overlap=0,
+        )
+        retval = text_splitter.create_documents([text])
         return retval

     def embed(self, text: str) -> List[float]:
         """Embed."""
-        texts_splitter_results = self.split_text(text)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=100,
+            chunk_overlap=0,
+        )
+        texts_splitter_results = text_splitter.create_documents([text])
         embedding = texts_splitter_results[0].page_content
         # pylint: disable=no-member
         self.openai_embedding.embed_query(embedding)

         self.pinecone_search = Pinecone.from_documents(
             texts_splitter_results,
             embedding=self.openai_embedding,
-            index_name=self.pinecone_index_name,
+            index_name=Credentials.PINECONE_INDEX_NAME,
         )

+    def rag(self, filepath: str, prompt: str):
+        """
+        Embed PDFs and answer a question about them.
+        1. Load PDF document text data
+        2. Split into pages
+        3. Embed each page
+        4. Store in Pinecone
+        5. Retrieve and generate an answer with a RAG chain
+        """
+
+        def format_docs(docs):
+            """Concatenate the retrieved documents into a single context string."""
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        # load every PDF found in the given directory
+        docs = []
+        for pdf_file in glob.glob(os.path.join(filepath, "*.pdf")):
+            loader = PyPDFLoader(file_path=pdf_file)
+            docs += loader.load()
+
+        for doc in docs:
+            self.embed(doc.page_content)
+
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+        splits = text_splitter.split_documents(docs)
+        vectorstore = Pinecone.from_documents(
+            documents=splits,
+            embedding=self.openai_embedding,
+            index_name=Credentials.PINECONE_INDEX_NAME,
+        )
+        retriever = vectorstore.as_retriever()
+        rag_prompt = hub.pull("rlm/rag-prompt")
+
+        rag_chain = (
+            {"context": retriever | format_docs, "question": RunnablePassthrough()}
+            | rag_prompt
+            | self.chat
+            | StrOutputParser()
+        )
+
+        return rag_chain.invoke(prompt)
+
     def embedded_prompt(self, prompt: str) -> List[Document]:
-        """Embedded prompt."""
+        """
+        Embedded prompt.
+        1. Retrieve prompt: Given a user input, relevant splits are retrieved
+           from storage using a Retriever.
+        2. Generate: A ChatModel / LLM produces an answer using a prompt that includes
+           the question and the retrieved data.
+        """
         result = self.pinecone_search.similarity_search(prompt)
         return result
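
The chain at the heart of `rag()` is LangChain Expression Language (LCEL) composition: the retriever piped through a formatting function fills the prompt's `context` variable, `RunnablePassthrough()` forwards the question unchanged, and the hub prompt, chat model, and `StrOutputParser()` complete the pipeline. A minimal standalone sketch of the same pattern (assumes langchain==0.0.343 as pinned below, an `OPENAI_API_KEY` in the environment, and an already-populated `vectorstore`):

```python
# Minimal LCEL RAG chain sketch; mirrors the pattern used in ssm.rag().
# `vectorstore` is assumed to be an already-populated LangChain vector store.
from langchain import hub
from langchain.chat_models import ChatOpenAI
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough


def format_docs(docs) -> str:
    """Join retrieved document chunks into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


retriever = vectorstore.as_retriever()   # similarity search over stored embeddings
rag_prompt = hub.pull("rlm/rag-prompt")  # community RAG prompt from the LangChain hub
chat = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.0)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | chat
    | StrOutputParser()
)

answer = rag_chain.invoke("What is Accounting Based Valuation?")
print(answer)
```

Each `|` step is a Runnable; the dict literal fans the input question out to both the retriever and the prompt, which is why `rag_chain.invoke()` takes the bare question string.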

requirements.txt

Lines changed: 3 additions & 0 deletions

@@ -22,4 +22,7 @@ codespell==2.2.6
 python-dotenv==1.0.0
 pydantic==2.5.2
 langchain==0.0.343
+openai==1.3.5
 pinecone-client==2.2.4
+pypdf==3.17.1
+tiktoken==0.5.1
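
The three new pins support the RAG pipeline: `openai` is the API client behind the chat and embedding models, `pypdf` backs `PyPDFLoader`, and `tiktoken` handles tokenization for the OpenAI embeddings.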
