@@ -12,7 +12,7 @@
 from streamlit_chat import message
 
 from agent import AgentHelper
-from docGPT import DocGPT
+from docGPT import DocGPT, OpenAiAPI, SerpAPI
 from model import PDFLoader
 
 langchain.llm_cache = InMemoryCache()
@@ -35,9 +35,6 @@ def theme():
     st.title('PDF Chatbot')
 
 
-theme()
-
-
 def load_api_key() -> None:
     with st.sidebar:
         if st.session_state.openai_api_key:
@@ -70,10 +67,7 @@ def load_api_key() -> None:
         os.environ['SERPAPI_API_KEY'] = SERPAPI_API_KEY
 
 
-load_api_key()
-
-
-with st.container():
+def upload_and_process_pdf():
     upload_file = st.file_uploader('#### Upload a PDF file:', type='pdf')
     if upload_file:
         temp_file = tempfile.NamedTemporaryFile(delete=False)
@@ -87,10 +81,33 @@ def load_api_key() -> None:
         if temp_file_path:
             os.remove(temp_file_path)
 
-docGPT_tool, calculate_tool, search_tool, llm_tool = [None] * 4
+        return docs
+
+
+@lru_cache(maxsize=20)
+def get_response(query: str):
+    try:
+        if agent_.agent_ is not None:
+            response = agent_.query(query)
+            return response
+    except Exception as e:
+        app_logger.info(e)
+
+
+theme()
+load_api_key()
+
+doc_container = st.container()
+
+
+with doc_container:
+    docs = upload_and_process_pdf()
+
+agent_, docGPT_tool, calculate_tool, search_tool, llm_tool = [None]*5
+if OpenAiAPI.is_valid():
+    agent_ = AgentHelper()
 
-try:
-    agent_ = AgentHelper()
+if docs:
     docGPT = DocGPT(docs=docs)
     docGPT.create_qa_chain(
         chain_type='refine',
@@ -100,59 +117,37 @@ def load_api_key() -> None:
     calculate_tool = agent_.get_calculate_chain
     llm_tool = agent_.create_llm_chain()
 
-except Exception as e:
-    app_logger.info(e)
+if SerpAPI.is_valid():
+    search_tool = agent_.get_searp_chain
 
-try:
-    search_tool = agent_.get_searp_chain
-except Exception as e:
-    app_logger.info(e)
-
-try:
-    tools = [
-        docGPT_tool,
-        search_tool,
-        # llm_tool, # This will cause agent confuse
-        calculate_tool
-    ]
-    agent_.initialize(tools)
-except Exception as e:
-    app_logger.info(e)
-
-
-if not st.session_state['openai_api_key']:
-    st.error('⚠️ :red[You have not pass OpenAPI key. (Or your api key cannot use.)] Necessary Pass')
-
-if not st.session_state['serpapi_api_key']:
-    st.warning('⚠️ You have not pass SEARPAPI key. (You cannot ask current events.)')
+try:
+    tools = [
+        docGPT_tool,
+        search_tool,
+        # llm_tool, # This will cause agent confuse
+        calculate_tool
+    ]
+    agent_.initialize(tools)
+except Exception as e:
+    app_logger.info(e)
 
-st.write('---')
+st.write('---')
 
 if 'response' not in st.session_state:
     st.session_state['response'] = ['How can I help you?']
 
 if 'query' not in st.session_state:
     st.session_state['query'] = ['Hi']
 
-
-@lru_cache(maxsize=20)
-def get_response(query: str):
-    try:
-        if agent_.agent_ is not None:
-            response = agent_.query(query)
-            return response
-    except Exception as e:
-        app_logger.info(e)
-
-query = st.text_input(
-    "#### Question:",
-    placeholder='Enter your question'
-)
-
-response_container = st.container()
 user_container = st.container()
+response_container = st.container()
 
 with user_container:
+    query = st.text_input(
+        "#### Question:",
+        placeholder='Enter your question'
+    )
+
     if query and query != '':
         response = get_response(query)
         st.session_state.query.append(query)
@@ -161,5 +156,19 @@ def get_response(query: str):
 with response_container:
     if st.session_state['response']:
         for i in range(len(st.session_state['response'])-1, -1, -1):
-            message(st.session_state["response"][i], key=str(i))
-            message(st.session_state['query'][i], is_user=True, key=str(i) + '_user')
+            message(
+                st.session_state["response"][i], key=str(i),
+                logo=(
+                    'https://api.dicebear.com/6.x/bottts/svg?'
+                    'baseColor=fb8c00&eyes=bulging'
+                )
+            )
+            message(
+                st.session_state['query'][i], is_user=True, key=str(i) + '_user',
+                logo=(
+                    'https://api.dicebear.com/6.x/adventurer/svg?'
+                    'hair=short16&hairColor=85c2c6&'
+                    'eyes=variant12&size=100&'
+                    'mouth=variant26&skinColor=f2d3b1'
+                )
+            )
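A note on the caching kept in this diff: `functools.lru_cache(maxsize=20)` memoizes `get_response` by its `query` string, so within one running Streamlit process a repeated question returns the stored answer instead of calling the agent again. A minimal sketch of that behavior, using a hypothetical `slow_agent_query` stand-in for `agent_.query()`:

```python
from functools import lru_cache

CALLS = 0  # counts how often the underlying "agent" actually runs


def slow_agent_query(query: str) -> str:
    # Hypothetical stand-in for agent_.query(); the real call hits an LLM.
    global CALLS
    CALLS += 1
    return f"answer to: {query}"


@lru_cache(maxsize=20)  # same decorator and maxsize as the diff's get_response
def get_response(query: str) -> str:
    return slow_agent_query(query)


print(get_response("What is this PDF about?"))  # cache miss -> agent runs
print(get_response("What is this PDF about?"))  # cache hit  -> cached answer
print(CALLS)  # 1: the repeated question never reached the agent
```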