@@ -8,10 +8,9 @@
 
 import glob
 import os
-from typing import ClassVar, List
+from typing import List  # ClassVar
 
 import pinecone
-from langchain import hub
 from langchain.cache import InMemoryCache
 
 # prompting and chat
@@ -27,16 +26,17 @@
 from langchain.globals import set_llm_cache
 from langchain.llms.openai import OpenAI
 from langchain.prompts import PromptTemplate
-from langchain.schema import HumanMessage, StrOutputParser, SystemMessage
-from langchain.schema.runnable import RunnablePassthrough
+from langchain.schema import HumanMessage, SystemMessage
 from langchain.text_splitter import Document, RecursiveCharacterTextSplitter
 from langchain.vectorstores.pinecone import Pinecone
-from pydantic import BaseModel, ConfigDict, Field  # ValidationError
 
 # this project
 from models.const import Credentials
 
 
+# from pydantic import BaseModel, ConfigDict, Field
+
+
 ###############################################################################
 # initializations
 ###############################################################################
@@ -45,28 +45,23 @@
 set_llm_cache(InMemoryCache())
 
 
-class SalesSupportModel(BaseModel):
+class SalesSupportModel:
     """Sales Support Model (SSM)."""
 
-    Config: ClassVar = ConfigDict(arbitrary_types_allowed=True)
-
     # prompting wrapper
-    chat: ChatOpenAI = Field(
-        default_factory=lambda: ChatOpenAI(
-            api_key=Credentials.OPENAI_API_KEY,
-            organization=Credentials.OPENAI_API_ORGANIZATION,
-            cache=True,
-            max_retries=3,
-            model="gpt-3.5-turbo",
-            temperature=0.0,
-        )
+    chat = ChatOpenAI(
+        api_key=Credentials.OPENAI_API_KEY,
+        organization=Credentials.OPENAI_API_ORGANIZATION,
+        cache=True,
+        max_retries=3,
+        model="gpt-3.5-turbo",
+        temperature=0.0,
     )
 
     # embeddings
-    texts_splitter_results: List[Document] = Field(None, description="Text splitter results")
-    pinecone_search: Pinecone = Field(None, description="Pinecone search")
-    openai_embedding: OpenAIEmbeddings = Field(OpenAIEmbeddings())
-    query_result: List[float] = Field(None, description="Vector database query result")
+    texts_splitter_results: List[Document]
+    openai_embedding = OpenAIEmbeddings()
+    query_result: List[float]
 
     def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
         """Cached chat request."""
@@ -103,13 +98,13 @@ def embed(self, text: str) -> List[float]:
         # pylint: disable=no-member
         self.openai_embedding.embed_query(embedding)
 
-        self.pinecone_search = Pinecone.from_documents(
-            texts_splitter_results,
+        Pinecone.from_documents(
+            documents=texts_splitter_results,
             embedding=self.openai_embedding,
             index_name=Credentials.PINECONE_INDEX_NAME,
         )
 
-    def rag(self, filepath: str, prompt: str):
+    def load(self, filepath: str):
         """
         Embed PDF.
         1. Load PDF document text data
@@ -118,39 +113,52 @@ def rag(self, filepath: str, prompt: str):
         4. Store in Pinecone
         """
 
-        # pylint: disable=unused-variable
-        def format_docs(docs):
-            """Format docs."""
-            return "\n\n".join(doc.page_content for doc in docs)
-
-        for pdf_file in glob.glob(os.path.join(filepath, "*.pdf")):
+        pdf_files = glob.glob(os.path.join(filepath, "*.pdf"))
+        i = 0
+        for pdf_file in pdf_files:
+            i += 1
+            j = len(pdf_files)
+            print(f"Loading PDF {i} of {j}: ")
             loader = PyPDFLoader(file_path=pdf_file)
             docs = loader.load()
+            k = 0
             for doc in docs:
+                k += 1
+                print(k * "-", end="\r")
                 self.embed(doc.page_content)
+        print("Finished loading PDFs")
 
-            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-            splits = text_splitter.split_documents(docs)
-            vectorstore = Pinecone.from_documents(documents=splits, embedding=self.openai_embedding)
-            retriever = vectorstore.as_retriever()
-            prompt = hub.pull("rlm/rag-prompt")
-
-            rag_chain = (
-                {"context": retriever | self.format_docs, "question": RunnablePassthrough()}
-                | prompt
-                | self.chat
-                | StrOutputParser()
-            )
-
-            return rag_chain.invoke(prompt)
-
-    def embedded_prompt(self, prompt: str) -> List[Document]:
+    def rag(self, prompt: str):
         """
         Embedded prompt.
         1. Retrieve prompt: Given a user input, relevant splits are retrieved
            from storage using a Retriever.
         2. Generate: A ChatModel / LLM produces an answer using a prompt that includes
            the question and the retrieved data
         """
-        result = self.pinecone_search.similarity_search(prompt)
-        return result
+
+        # pylint: disable=unused-variable
+        def format_docs(docs):
+            """Format docs."""
+            return "\n\n".join(doc.page_content for doc in docs)
+
+        pinecone_search = Pinecone.from_existing_index(
+            Credentials.PINECONE_INDEX_NAME,
+            embedding=self.openai_embedding,
+        )
+        retriever = pinecone_search.as_retriever()
+
+        # Use the retriever to get relevant documents
+        documents = retriever.get_relevant_documents(query=prompt)
+        print(f"Retrieved {len(documents)} related documents from Pinecone")
+
+        # Generate a prompt from the retrieved documents
+        prompt += " ".join(doc.page_content for doc in documents)
+        print(f"Prompt contains {len(prompt.split())} words")
+        print("Prompt:", prompt)
+        print(doc for doc in documents)
+
+        # Get a response from the GPT-3.5-turbo model
+        response = self.cached_chat_request(system_message="You are a helpful assistant.", human_message=prompt)
+
+        return response
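
A minimal usage sketch of the refactored interface, assuming the class lives in models/ssm.py and that the OpenAI and Pinecone credentials referenced by models.const.Credentials are already configured; the module path, directory, and prompt text below are illustrative, not part of this commit:

# Hypothetical usage of the refactored SalesSupportModel (names are illustrative).
from models.ssm import SalesSupportModel

ssm = SalesSupportModel()

# Chunk and embed every PDF found in the directory, storing the vectors in Pinecone.
ssm.load(filepath="./data/")

# Retrieve related chunks from the Pinecone index and ask gpt-3.5-turbo for an answer.
response = ssm.rag(prompt="Summarize our sales support policy for enterprise customers.")
print(response)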