# -*- coding: utf-8 -*-
# pylint: disable=too-few-public-methods
- """Sales Support Model (SSM) for the LangChain project."""
-
+ """
+ Sales Support Model (SSM) for the LangChain project.
+ See: https://python.langchain.com/docs/modules/model_io/llms/llm_caching
+ https://python.langchain.com/docs/modules/data_connection/document_loaders/pdf
+ """
+
+ import glob
+ import os
from typing import ClassVar, List

import pinecone
+ from langchain import hub
+ from langchain.cache import InMemoryCache
+
+ # prompting and chat
from langchain.chat_models import ChatOpenAI
+
+ # document loading
+ from langchain.document_loaders import PyPDFLoader
+
+ # embedding
from langchain.embeddings import OpenAIEmbeddings
+
+ # caching and globals
+ from langchain.globals import set_llm_cache
from langchain.llms.openai import OpenAI
from langchain.prompts import PromptTemplate
- from langchain.schema import HumanMessage, SystemMessage  # AIMessage (not used)
+ from langchain.schema import HumanMessage, StrOutputParser, SystemMessage
+ from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import Document, RecursiveCharacterTextSplitter
from langchain.vectorstores.pinecone import Pinecone
from pydantic import BaseModel, ConfigDict, Field  # ValidationError

+ # this project
from models.const import Credentials


+ ###############################################################################
+ # initializations
+ ###############################################################################
DEFAULT_MODEL_NAME = "text-davinci-003"
pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
+ set_llm_cache(InMemoryCache())
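+ # NOTE: with an InMemoryCache registered globally, repeated identical LLM calls
+ # are served from memory for the life of the process (see cache=True on the
+ # chat model below)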


class SalesSupportModel(BaseModel):
@@ -31,24 +55,17 @@ class SalesSupportModel(BaseModel):
        default_factory=lambda: ChatOpenAI(
            api_key=Credentials.OPENAI_API_KEY,
            organization=Credentials.OPENAI_API_ORGANIZATION,
+             cache=True,
            max_retries=3,
            model="gpt-3.5-turbo",
-             temperature=0.3,
+             temperature=0.0,
        )
    )

    # embeddings
-     text_splitter: RecursiveCharacterTextSplitter = Field(
-         default_factory=lambda: RecursiveCharacterTextSplitter(
-             chunk_size=100,
-             chunk_overlap=0,
-         )
-     )
-
    texts_splitter_results: List[Document] = Field(None, description="Text splitter results")
    pinecone_search: Pinecone = Field(None, description="Pinecone search")
-     pinecone_index_name: str = Field(default="netec-ssm", description="Pinecone index name")
-     openai_embedding: OpenAIEmbeddings = Field(default_factory=lambda: OpenAIEmbeddings(model="ada"))
+     openai_embedding: OpenAIEmbeddings = Field(OpenAIEmbeddings())
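+     # NOTE: left at its default, OpenAIEmbeddings uses OpenAI's
+     # text-embedding-ada-002 model; "ada" (removed above) names a legacy
+     # completion model, not an embedding model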
    query_result: List[float] = Field(None, description="Vector database query result")

    def cached_chat_request(self, system_message: str, human_message: str) -> SystemMessage:
@@ -68,24 +85,72 @@ def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str

    def split_text(self, text: str) -> List[Document]:
        """Split text."""
-         # pylint: disable=no-member
-         retval = self.text_splitter.create_documents([text])
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=100,
+             chunk_overlap=0,
+         )
+         retval = text_splitter.create_documents([text])
        return retval

    def embed(self, text: str) -> List[float]:
        """Embed."""
-         texts_splitter_results = self.split_text(text)
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=100,
+             chunk_overlap=0,
+         )
+         texts_splitter_results = text_splitter.create_documents([text])
        embedding = texts_splitter_results[0].page_content
        # pylint: disable=no-member
        self.openai_embedding.embed_query(embedding)

        self.pinecone_search = Pinecone.from_documents(
            texts_splitter_results,
            embedding=self.openai_embedding,
-             index_name=self.pinecone_index_name,
+             index_name=Credentials.PINECONE_INDEX_NAME,
        )

+     def rag(self, filepath: str, prompt: str):
+         """
+         Embed PDF documents and answer a prompt against them (RAG).
+         1. Load PDF document text data
+         2. Split into pages
+         3. Embed each page
+         4. Store in Pinecone
+         5. Retrieve relevant splits and generate an answer to the prompt
+         """
+
+         def format_docs(docs):
+             """Concatenate retrieved documents into a single context string."""
+             return "\n\n".join(doc.page_content for doc in docs)
+
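+         # load every PDF in the target directory; PyPDFLoader yields one
+         # Document per page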
+         pdf_docs = []
+         for pdf_file in glob.glob(os.path.join(filepath, "*.pdf")):
+             loader = PyPDFLoader(file_path=pdf_file)
+             docs = loader.load()
+             pdf_docs.extend(docs)
+             for doc in docs:
+                 self.embed(doc.page_content)
+
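+         # re-split the accumulated pages into 1,000-character chunks with a
+         # 200-character overlap; larger chunks keep more context per retrieval
+         # hit than the 100-character chunks used by embed()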
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+         splits = text_splitter.split_documents(pdf_docs)
+         vectorstore = Pinecone.from_documents(
+             documents=splits,
+             embedding=self.openai_embedding,
+             index_name=Credentials.PINECONE_INDEX_NAME,
+         )
+         retriever = vectorstore.as_retriever()
+         rag_prompt = hub.pull("rlm/rag-prompt")
+
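+         # LCEL pipeline: fetch and format the context, render the hub prompt,
+         # call the chat model, then parse the response down to a plain string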
+         rag_chain = (
+             {"context": retriever | format_docs, "question": RunnablePassthrough()}
+             | rag_prompt
+             | self.chat
+             | StrOutputParser()
+         )
+
+         return rag_chain.invoke(prompt)
+
    def embedded_prompt(self, prompt: str) -> List[Document]:
-         """Embedded prompt."""
+         """
+         Embedded prompt.
+         1. Retrieve: given a user input, relevant splits are retrieved from
+            storage using a Retriever.
+         2. Generate: a ChatModel / LLM produces an answer using a prompt that
+            includes the question and the retrieved data.
+         """
        result = self.pinecone_search.similarity_search(prompt)
        return result
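+
+ # Example usage (hypothetical directory and prompt):
+ #   ssm = SalesSupportModel()
+ #   answer = ssm.rag(filepath="./pdfs/", prompt="What services does the company offer?")
+ #   print(answer)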