Commit 96f8364

Update generate_answer_node.py
1 parent 1502dda commit 96f8364

File tree

1 file changed: +70 -57 lines changed

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 70 additions & 57 deletions

```diff
@@ -1,10 +1,8 @@
-"""
-Generate answer_node
-"""
 import re
 import json
 from typing import List, Optional
 import requests
+import asyncio
 from tqdm import tqdm
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
```

```diff
@@ -25,16 +23,6 @@ class GenerateAnswerNode(BaseNode):
     and the content extracted from a webpage. It constructs a prompt from the user's input
     and the scraped content, feeds it to the LLM, and parses the LLM's response to produce
     an answer.
-
-    Attributes:
-        llm_model: An instance of a language model client, configured for generating answers.
-        verbose (bool): A flag indicating whether to show print statements during execution.
-
-    Args:
-        input (str): Boolean expression defining the input keys needed from the state.
-        output (List[str]): List of output keys to be updated in the state.
-        node_config (dict): Additional configuration for the node.
-        node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
     """

     def __init__(
```

```diff
@@ -57,7 +45,33 @@ def __init__(
         self.additional_info = node_config.get("additional_info", "")
         self.api_key = node_config.get("config", {}).get("llm", {}).get("api_key", "")

+    async def _process_chunks_async(self, chunks, templates, user_prompt, format_instructions):
+        async def send_request(prompt):
+            url = "https://api.openai.com/v1/chat/completions"
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.api_key}"
+            }
+            response = await requests.post(url, headers=headers, json={
+                "model": self.llm_model.model_name,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": 0
+            }, timeout=10)
+            response_text = response.json()['choices'][0]['message']['content']
+            return parse_response_to_dict(response_text)

+        tasks = []
+        for i, chunk in enumerate(chunks):
+            prompt = templates['chunks'].format(
+                question=user_prompt,
+                context=chunk,
+                chunk_id=i + 1,
+                format_instructions=format_instructions
+            )
+            tasks.append(send_request(prompt))
+
+        results = await asyncio.gather(*tasks)
+        return results

     def execute(self, state: dict) -> dict:
         self.logger.info(f"--- Executing {self.node_name} Node ---")
```
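
A note on the new helper: `requests` is a blocking HTTP client and its `Response` object is not awaitable, so `await requests.post(...)` inside `send_request` raises a `TypeError` once `asyncio.gather` runs the tasks. A minimal sketch of the same per-chunk fan-out that stays awaitable, assuming Python 3.9+ for `asyncio.to_thread`; the function names and payload fields here are illustrative placeholders, not the project's API:

```python
import asyncio
import requests

API_URL = "https://api.openai.com/v1/chat/completions"

async def send_request(prompt: str, api_key: str, model: str) -> str:
    # requests.post is blocking, so run it in a worker thread; awaiting
    # requests.post directly would raise TypeError, because its Response
    # is not awaitable.
    response = await asyncio.to_thread(
        requests.post,
        API_URL,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        json={
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0,
        },
        timeout=10,
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]

async def process_chunks(prompts: list[str], api_key: str, model: str) -> list[str]:
    # One task per chunk, gathered concurrently; results keep input order.
    return await asyncio.gather(*(send_request(p, api_key, model) for p in prompts))
```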

```diff
@@ -76,63 +90,60 @@ def execute(self, state: dict) -> dict:
                 'merge': TEMPLATE_MERGE_MD
             }

-            url = "https://api.openai.com/v1/chat/completions"
-            headers = {
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.api_key}"
-            }
-
             if len(doc) == 1:
                 prompt = templates['no_chunks'].format(
                     question=user_prompt,
                     context=doc[0],
                     format_instructions=format_instructions
                 )
-                response = requests.post(url, headers=headers, json={
-                    "model": self.llm_model.model_name,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "temperature": 0
-                }, timeout=10)
+                response = requests.post(
+                    url="https://api.openai.com/v1/chat/completions",
+                    headers={
+                        "Content-Type": "application/json",
+                        "Authorization": f"Bearer {self.api_key}"
+                    },
+                    json={
+                        "model": self.llm_model.model_name,
+                        "messages": [{"role": "user", "content": prompt}],
+                        "temperature": 0
+                    },
+                    timeout=10
+                )

                 response_text = response.json()['choices'][0]['message']['content']
                 cleaned_response = parse_response_to_dict(response_text)
                 state.update({self.output[0]: cleaned_response})
                 return state

-            chunks_responses = []
-            for i, chunk in enumerate(
-                    tqdm(doc, desc="Processing chunks",
-                         disable=not self.verbose)):
-                prompt = templates['chunks'].format(
+            else:
+                chunks_responses = asyncio.run(
+                    self._process_chunks_async(doc, templates, user_prompt, format_instructions)
+                )
+
+                merge_context = " ".join([json.dumps(chunk) for chunk in chunks_responses])
+                merge_prompt = templates['merge'].format(
                     question=user_prompt,
-                    context=chunk,
-                    chunk_id=i + 1,
+                    context=merge_context,
                     format_instructions=format_instructions
                 )
-                response = requests.post(url, headers=headers, json={
-                    "model": self.llm_model.model_name,
-                    "messages": [{"role": "user", "content": prompt}],
-                    "temperature": 0
-                }, timeout=10)
-                chunk_response = response.json()['choices'][0]['message']['content']
-                cleaned_chunk_response = parse_response_to_dict(chunk_response)
-                chunks_responses.append(cleaned_chunk_response)
-
-            merge_context = " ".join([json.dumps(chunk) for chunk in chunks_responses])
-            merge_prompt = templates['merge'].format(
-                question=user_prompt,
-                context=merge_context,
-                format_instructions=format_instructions
-            )
-            response = requests.post(url, headers=headers, json={
-                "model": self.llm_model.model_name,
-                "messages": [{"role": "user", "content": merge_prompt}],
-                "temperature": 0
-            }, timeout=10)
-            response_text = response.json()['choices'][0]['message']['content']
-            cleaned_response = parse_response_to_dict(response_text)
-            state.update({self.output[0]: cleaned_response})
-            return state
+                response = requests.post(
+                    url="https://api.openai.com/v1/chat/completions",
+                    headers={
+                        "Content-Type": "application/json",
+                        "Authorization": f"Bearer {self.api_key}"
+                    },
+                    json={
+                        "model": self.llm_model.model_name,
+                        "messages": [{"role": "user", "content": merge_prompt}],
+                        "temperature": 0
+                    },
+                    timeout=10
+                )
+
+                response_text = response.json()['choices'][0]['message']['content']
+                cleaned_response = parse_response_to_dict(response_text)
+                state.update({self.output[0]: cleaned_response})
+                return state

         else:
             templates = {
```
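
For readers following the new `else:` branch above: once the per-chunk answers come back, each parsed dict is re-serialized to JSON and space-joined into a single context string for the merge prompt. A minimal illustration with made-up chunk answers:

```python
import json

# Hypothetical per-chunk answers, already parsed to dicts by the node.
chunks_responses = [
    {"products": ["alpha", "beta"]},
    {"products": ["gamma"]},
]

# The same join the diff performs before formatting the merge prompt.
merge_context = " ".join(json.dumps(chunk) for chunk in chunks_responses)
print(merge_context)
# {"products": ["alpha", "beta"]} {"products": ["gamma"]}
```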

```diff
@@ -142,13 +153,15 @@ def execute(self, state: dict) -> dict:
             }

             if self.additional_info:
-                templates = {key: self.additional_info + template for key, template in templates.items()}
+                templates = {key: self.additional_info +
+                             template for key, template in templates.items()}

             if len(doc) == 1:
                 prompt = PromptTemplate(
                     template=templates['no_chunks'],
                     input_variables=["question"],
-                    partial_variables={"context": doc, "format_instructions": format_instructions}
+                    partial_variables={"context": doc[0],
+                                       "format_instructions": format_instructions}
                 )
                 chain = prompt | self.llm_model | output_parser
                 answer = chain.invoke({"question": user_prompt})
```
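
The `doc` to `doc[0]` fix in this hunk matters because `partial_variables` values are substituted into the template as-is: passing the list renders its `repr` (brackets and quotes included) into the prompt, while `doc[0]` injects the page text itself. A short sketch with an illustrative template string:

```python
from langchain.prompts import PromptTemplate

doc = ["Example page text scraped from the site."]

# partial_variables are filled now; "question" is filled at format/invoke time.
prompt = PromptTemplate(
    template="Context: {context}\nQuestion: {question}\nAnswer:",  # illustrative
    input_variables=["question"],
    partial_variables={"context": doc[0]},
)

print(prompt.format(question="What does the page describe?"))
# Passing `doc` instead of `doc[0]` would have rendered the list's repr,
# i.e. "['Example page text ...']", into the prompt.
```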
