Skip to content

Commit 1502dda

Browse files
committed
Update generate_answer_node: extract JSON response parsing into a shared parse_response_to_dict utility
1 parent fdce0a9 commit 1502dda

File tree

3 files changed

+58
-8
lines changed

3 files changed

+58
-8
lines changed

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
import requests
1+
"""
2+
Generate answer_node
3+
"""
4+
import re
25
import json
36
from typing import List, Optional
7+
import requests
8+
from tqdm import tqdm
49
from langchain.prompts import PromptTemplate
510
from langchain_core.output_parsers import JsonOutputParser
611
from langchain_core.runnables import RunnableParallel
712
from langchain_openai import ChatOpenAI
813
from langchain_community.chat_models import ChatOllama
9-
from tqdm import tqdm
1014
from ..utils.logging import get_logger
15+
from ..utils import parse_response_to_dict
1116
from .base_node import BaseNode
1217
from ..prompts import (
1318
TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
@@ -52,6 +57,8 @@ def __init__(
5257
self.additional_info = node_config.get("additional_info", "")
5358
self.api_key = node_config.get("config", {}).get("llm", {}).get("api_key", "")
5459

60+
61+
5562
def execute(self, state: dict) -> dict:
5663
self.logger.info(f"--- Executing {self.node_name} Node ---")
5764

@@ -86,7 +93,10 @@ def execute(self, state: dict) -> dict:
8693
"messages": [{"role": "user", "content": prompt}],
8794
"temperature": 0
8895
}, timeout=10)
89-
state.update({self.output[0]: response.json()})
96+
97+
response_text = response.json()['choices'][0]['message']['content']
98+
cleaned_response = parse_response_to_dict(response_text)
99+
state.update({self.output[0]: cleaned_response})
90100
return state
91101

92102
chunks_responses = []
@@ -105,7 +115,7 @@ def execute(self, state: dict) -> dict:
105115
"temperature": 0
106116
}, timeout=10)
107117
chunk_response = response.json()['choices'][0]['message']['content']
108-
cleaned_chunk_response = json.loads(chunk_response.replace('\\n', '').replace('\\', ''))
118+
cleaned_chunk_response = parse_response_to_dict(chunk_response)
109119
chunks_responses.append(cleaned_chunk_response)
110120

111121
merge_context = " ".join([json.dumps(chunk) for chunk in chunks_responses])
@@ -120,7 +130,7 @@ def execute(self, state: dict) -> dict:
120130
"temperature": 0
121131
}, timeout=10)
122132
response_text = response.json()['choices'][0]['message']['content']
123-
cleaned_response = json.loads(response_text.replace('\\n', '').replace('\\', ''))
133+
cleaned_response = parse_response_to_dict(response_text)
124134
state.update({self.output[0]: cleaned_response})
125135
return state
126136

@@ -146,11 +156,14 @@ def execute(self, state: dict) -> dict:
146156
return state
147157

148158
chains_dict = {}
149-
for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
159+
for i, chunk in enumerate(tqdm(doc,
160+
desc="Processing chunks",
161+
disable=not self.verbose)):
150162
prompt = PromptTemplate(
151163
template=templates['chunks'],
152164
input_variables=["question"],
153-
partial_variables={"context": chunk, "chunk_id": i + 1, "format_instructions": format_instructions}
165+
partial_variables={"context": chunk, "chunk_id": i + 1,
166+
"format_instructions": format_instructions}
154167
)
155168
chain_name = f"chunk{i+1}"
156169
chains_dict[chain_name] = prompt | self.llm_model | output_parser

scrapegraphai/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
__init__.py file for utils folder
33
"""
4-
4+
from .response_to_dict import parse_response_to_dict
55
from .convert_to_csv import convert_to_csv
66
from .convert_to_json import convert_to_json
77
from .prettify_exec_info import prettify_exec_info
scrapegraphai/utils/response_to_dict.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""
2+
parse_response_to_dict module
3+
"""
4+
import re
5+
import json
6+
7+
def parse_response_to_dict(response_text: str) -> dict:
    """
    Parse the response text to a dictionary, handling different formats.

    The model may wrap its JSON answer in a ```json ... ``` fenced block,
    return bare JSON, or return JSON polluted with stray literal escape
    sequences. The payload is extracted and parsed strictly first; an
    escape-stripping cleanup is applied only when the strict parse fails,
    so valid JSON escapes (e.g. \\" or \\n inside string values) are
    never corrupted.

    Args:
        response_text (str): The raw text response from the model.

    Returns:
        dict: The parsed dictionary.

    Raises:
        ValueError: If the response cannot be parsed into valid JSON.
    """
    # Regex to capture text between ```json and ``` fences, if present.
    json_pattern = r'```json\s*(.*?)\s*```'
    match = re.search(json_pattern, response_text, re.DOTALL)
    # If no fenced block, consider the whole response as potential JSON.
    json_str = match.group(1) if match else response_text

    # Strict parse first: stripping backslashes up front would corrupt
    # valid JSON escape sequences inside string values.
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass

    # Fallback: remove stray literal escape characters and whitespace
    # that some models emit around otherwise-valid JSON, then retry.
    cleaned_json_str = json_str.replace('\\n', '').replace('\\', '').strip()
    try:
        return json.loads(cleaned_json_str)
    except json.JSONDecodeError as err:
        raise ValueError(
            "The response could not be parsed into a valid JSON dictionary."
        ) from err

0 commit comments

Comments
 (0)