Skip to content

Commit a26589b

Browse files
committed
Add: LLM model in agent
1 parent 2025fe7 commit a26589b

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

agent/agent.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from langchain import LLMMathChain, SerpAPIWrapper
66
from langchain.agents import AgentType, Tool, initialize_agent
77
from langchain.callbacks import get_openai_callback
8+
from langchain.chains import LLMChain
89
from langchain.llms import OpenAI
9-
10+
from langchain.prompts import PromptTemplate
1011

1112
openai.api_key = os.getenv('OPENAI_API_KEY')
1213
os.environ['SERPAPI_API_KEY'] = os.getenv('SERPAPI_API_KEY')
@@ -51,6 +52,20 @@ def create_doc_chat(self, docGPT) -> Tool:
5152
)
5253
return tool
5354

55+
def create_llm_chain(self) -> Tool:
56+
"""Add a llm tool"""
57+
prompt = PromptTemplate(
58+
input_variables = ['query'],
59+
template = '{query}'
60+
)
61+
llm_chain = LLMChain(llm=self.llm, prompt = prompt)
62+
63+
tool = Tool(
64+
name='LLM',
65+
func=llm_chain.run,
66+
description='useful for general purpose queries and logic'
67+
)
68+
return tool
5469
def initialize(self, tools):
5570
for tool in tools:
5671
if isinstance(tool, Tool):
@@ -66,6 +81,9 @@ def initialize(self, tools):
6681
def query(self, query: str) -> Optional[str]:
6782
response = None
6883
with get_openai_callback() as callback:
84+
# TODO: The true result will hide in 'Observation'
85+
# https://github.com/hwchase17/langchain/issues/4916
86+
# https://python.langchain.com/docs/modules/agents/how_to/intermediate_steps
6987
response = self.agent_.run(query)
7088
print(callback)
7189
return response

app.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ def load_api_key() -> None:
9696
docGPT.create_qa_chain(
9797
chain_type='refine',
9898
)
99+
99100
docGPT_tool = agent_.create_doc_chat(docGPT)
101+
calculate_tool = agent_.get_calculate_chain
102+
llm_tool = agent_.create_llm_chain()
100103

101104
except Exception as e:
102105
print(e)
@@ -107,11 +110,10 @@ def load_api_key() -> None:
107110
print(e)
108111

109112
try:
110-
calculate_tool = agent_.get_calculate_chain
111-
112113
tools = [
113114
docGPT_tool,
114-
search_tool
115+
search_tool,
116+
llm_tool
115117
]
116118
agent_.initialize(tools)
117119
except Exception as e:

0 commit comments

Comments
 (0)