1
1
# -*- coding: utf-8 -*-
2
2
# pylint: disable=too-few-public-methods
3
- """Sales Support Model (SSM) for the LangChain project."""
3
+ """
4
+ Sales Support Model (SSM) for the LangChain project.
5
+ See: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
6
+ https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf
7
+ """
4
8
5
- from typing import ClassVar , List
9
+ import glob
10
+ import os
11
+ from typing import List # ClassVar
6
12
7
13
import pinecone
14
+ from langchain .cache import InMemoryCache
15
+
16
+ # prompting and chat
8
17
from langchain .chat_models import ChatOpenAI
18
+
19
+ # document loading
20
+ from langchain .document_loaders import PyPDFLoader
21
+
22
+ # embedding
9
23
from langchain .embeddings import OpenAIEmbeddings
24
+
25
+ # vector database
26
+ from langchain .globals import set_llm_cache
10
27
from langchain .llms .openai import OpenAI
11
28
from langchain .prompts import PromptTemplate
12
- from langchain .schema import HumanMessage , SystemMessage # AIMessage (not used)
29
+ from langchain .schema import HumanMessage , SystemMessage
13
30
from langchain .text_splitter import Document , RecursiveCharacterTextSplitter
14
31
from langchain .vectorstores .pinecone import Pinecone
15
- from pydantic import BaseModel , ConfigDict , Field # ValidationError
16
32
33
+ # this project
17
34
from models .const import Credentials
18
35
19
36
37
+ # from pydantic import BaseModel, ConfigDict, Field
38
+
39
+
40
###############################################################################
# initializations
###############################################################################
# Legacy completion-model name; the chat wrapper on the class below uses
# gpt-3.5-turbo, so this is kept for callers that still reference it.
DEFAULT_MODEL_NAME = "text-davinci-003"
# Connect to the Pinecone vector database with project credentials at import time.
pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
# Cache LLM responses in process memory (paired with cache=True on the chat model).
set_llm_cache(InMemoryCache())
46
23
47
24
class SalesSupportModel:
    """Sales Support Model (SSM)."""

    # prompting wrapper: one chat model shared by all instances (class attribute).
    # cache=True routes responses through the InMemoryCache installed at module load.
    chat = ChatOpenAI(
        api_key=Credentials.OPENAI_API_KEY,
        organization=Credentials.OPENAI_API_ORGANIZATION,
        cache=True,
        max_retries=3,
        model="gpt-3.5-turbo",
        temperature=0.0,  # deterministic responses, required for useful caching
    )

    # embeddings: splitter configuration shared with load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,
        chunk_overlap=0,
    )
    openai_embedding = OpenAIEmbeddings()
    # NOTE(review): evaluated at class-definition time — this connects to the
    # existing Pinecone index as soon as the module is imported; confirm that
    # import-time network access is acceptable here.
    pinecone_search = Pinecone.from_existing_index(
        Credentials.PINECONE_INDEX_NAME,
        embedding=openai_embedding,
    )
54
72
def cached_chat_request (self , system_message : str , human_message : str ) -> SystemMessage :
55
73
"""Cached chat request."""
@@ -68,24 +86,70 @@ def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str
68
86
69
87
def split_text (self , text : str ) -> List [Document ]:
70
88
"""Split text."""
71
- # pylint: disable=no-member
72
- retval = self .text_splitter .create_documents ([text ])
73
- return retval
74
-
75
- def embed (self , text : str ) -> List [float ]:
76
- """Embed."""
77
- texts_splitter_results = self .split_text (text )
78
- embedding = texts_splitter_results [0 ].page_content
79
- # pylint: disable=no-member
80
- self .openai_embedding .embed_query (embedding )
81
-
82
- self .pinecone_search = Pinecone .from_documents (
83
- texts_splitter_results ,
84
- embedding = self .openai_embedding ,
85
- index_name = self .pinecone_index_name ,
89
+ text_splitter = RecursiveCharacterTextSplitter (
90
+ chunk_size = 100 ,
91
+ chunk_overlap = 0 ,
86
92
)
93
+ retval = text_splitter .create_documents ([text ])
94
+ return retval
87
95
88
- def embedded_prompt (self , prompt : str ) -> List [Document ]:
89
- """Embedded prompt."""
90
- result = self .pinecone_search .similarity_search (prompt )
91
- return result
96
+ def load (self , filepath : str ):
97
+ """
98
+ Embed PDF.
99
+ 1. Load PDF document text data
100
+ 2. Split into pages
101
+ 3. Embed each page
102
+ 4. Store in Pinecone
103
+ """
104
+
105
+ pdf_files = glob .glob (os .path .join (filepath , "*.pdf" ))
106
+ i = 0
107
+ for pdf_file in pdf_files :
108
+ i += 1
109
+ j = len (pdf_files )
110
+ print (f"Loading PDF { i } of { j } : " , pdf_file )
111
+ loader = PyPDFLoader (file_path = pdf_file )
112
+ docs = loader .load ()
113
+ k = 0
114
+ for doc in docs :
115
+ k += 1
116
+ print (k * "-" , end = "\r " )
117
+ texts_splitter_results = self .text_splitter .create_documents ([doc .page_content ])
118
+ self .pinecone_search .from_existing_index (
119
+ index_name = Credentials .PINECONE_INDEX_NAME ,
120
+ embedding = self .openai_embedding ,
121
+ text_key = texts_splitter_results ,
122
+ )
123
+
124
+ print ("Finished loading PDFs" )
125
+
126
+ def rag (self , prompt : str ):
127
+ """
128
+ Embedded prompt.
129
+ 1. Retrieve prompt: Given a user input, relevant splits are retrieved
130
+ from storage using a Retriever.
131
+ 2. Generate: A ChatModel / LLM produces an answer using a prompt that includes
132
+ the question and the retrieved data
133
+ """
134
+
135
+ # pylint: disable=unused-variable
136
+ def format_docs (docs ):
137
+ """Format docs."""
138
+ return "\n \n " .join (doc .page_content for doc in docs )
139
+
140
+ retriever = self .pinecone_search .as_retriever ()
141
+
142
+ # Use the retriever to get relevant documents
143
+ documents = retriever .get_relevant_documents (query = prompt )
144
+ print (f"Retrieved { len (documents )} related documents from Pinecone" )
145
+
146
+ # Generate a prompt from the retrieved documents
147
+ prompt += " " .join (doc .page_content for doc in documents )
148
+ print (f"Prompt contains { len (prompt .split ())} words" )
149
+ print ("Prompt:" , prompt )
150
+ print (doc for doc in documents )
151
+
152
+ # Get a response from the GPT-3.5-turbo model
153
+ response = self .cached_chat_request (system_message = "You are a helpful assistant." , human_message = prompt )
154
+
155
+ return response
0 commit comments