
Commit f6ef909

add conf for using request inside openai
1 parent 8d6c0b7 commit f6ef909

File tree

4 files changed: +122 −157 lines changed


examples/openai/smart_scraper_openai.py

Lines changed: 0 additions & 43 deletions
This file was deleted.
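The standalone usage example is removed; the updated test file at the bottom of this commit now covers the same flow. For orientation, here is a minimal sketch of what such an example looks like. It is a hypothetical reconstruction, not the deleted file's exact contents: the prompt, source, and config values are borrowed from the test file below, and the OPENAI_API_KEY environment variable name is an assumption.

# Hypothetical reconstruction of the removed example, not its exact contents.
import os

from scrapegraphai.graphs import SmartScraperGraph

graph_config = {
    "llm": {
        "api_key": os.environ["OPENAI_API_KEY"],  # assumed env var name
        "model": "gpt-4o",
    },
    "verbose": True,
    "headless": False,
}

smart_scraper_graph = SmartScraperGraph(
    prompt="List me all the projects with their description.",
    source="https://perinim.github.io/projects/",
    config=graph_config,
)

print(smart_scraper_graph.run())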

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ def _create_graph(self) -> BaseGraph:
                 "llm_model": self.llm_model,
                 "additional_info": self.config.get("additional_info"),
                 "schema": self.schema,
+                "config": self.config
             }
         )
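The practical effect of this one-line addition is that the entire user-facing config dict now travels inside node_config, which is what lets GenerateAnswerNode (next file) dig the OpenAI API key out of it. A minimal sketch of that nested lookup, with an illustrative config whose key names match the diff and whose values are placeholders:

# Illustrative node_config shape; all values are placeholders.
node_config = {
    "llm_model": None,  # stand-in for the chat model instance
    "config": {         # the full graph config, now forwarded by SmartScraperGraph
        "llm": {"api_key": "sk-...", "model": "gpt-4o"},
    },
}

# Mirrors the new lookup in GenerateAnswerNode.__init__
api_key = node_config.get("config", {}).get("llm", {}).get("api_key", "")
assert api_key == "sk-..."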

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 119 additions & 97 deletions
@@ -1,6 +1,5 @@
-"""
-GenerateAnswerNode Module
-"""
+import requests
+import json
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -10,7 +9,10 @@
 from tqdm import tqdm
 from ..utils.logging import get_logger
 from .base_node import BaseNode
-from ..prompts import TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
+from ..prompts import (
+    TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
+    TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
+)
 
 class GenerateAnswerNode(BaseNode):
     """
@@ -39,110 +41,130 @@ def __init__(
     ):
         super().__init__(node_name, "node", input, output, 2, node_config)
 
-        self.llm_model = node_config["llm_model"]
+        self.llm_model = node_config.get("llm_model")
+        if isinstance(self.llm_model, ChatOllama):
+            self.llm_model.format = "json"
 
-        if isinstance(node_config["llm_model"], ChatOllama):
-            self.llm_model.format="json"
-
-        self.verbose = (
-            True if node_config is None else node_config.get("verbose", False)
-        )
-        self.force = (
-            False if node_config is None else node_config.get("force", False)
-        )
-        self.script_creator = (
-            False if node_config is None else node_config.get("script_creator", False)
-        )
-        self.is_md_scraper = (
-            False if node_config is None else node_config.get("is_md_scraper", False)
-        )
-
-        self.additional_info = node_config.get("additional_info")
+        self.verbose = node_config.get("verbose", False)
+        self.force = node_config.get("force", False)
+        self.script_creator = node_config.get("script_creator", False)
+        self.is_md_scraper = node_config.get("is_md_scraper", False)
+        self.additional_info = node_config.get("additional_info", "")
+        self.api_key = node_config.get("config", {}).get("llm", {}).get("api_key", "")
 
     def execute(self, state: dict) -> dict:
-        """
-        Generates an answer by constructing a prompt from the user's input and the scraped
-        content, querying the language model, and parsing its response.
-
-        Args:
-            state (dict): The current state of the graph. The input keys will be used
-                to fetch the correct data from the state.
-
-        Returns:
-            dict: The updated state with the output key containing the generated answer.
-
-        Raises:
-            KeyError: If the input keys are not found in the state, indicating
-                that the necessary information for generating an answer is missing.
-        """
-
         self.logger.info(f"--- Executing {self.node_name} Node ---")
 
-        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)
-        # Fetching data from the state based on the input keys
-        input_data = [state[key] for key in input_keys]
-        user_prompt = input_data[0]
-        doc = input_data[1]
-
-        # Initialize the output parser
-        if self.node_config.get("schema", None) is not None:
-            output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
-        else:
-            output_parser = JsonOutputParser()
+        user_prompt, doc = [state[key] for key in input_keys]
 
+        schema = self.node_config.get("schema")
+        output_parser = JsonOutputParser(pydantic_object=schema) if schema else JsonOutputParser()
         format_instructions = output_parser.get_format_instructions()
 
-        if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
-            template_chunks_prompt = TEMPLATE_CHUNKS_MD
-            template_merge_prompt = TEMPLATE_MERGE_MD
-        else:
-            template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
-            template_chunks_prompt = TEMPLATE_CHUNKS
-            template_merge_prompt = TEMPLATE_MERGE
-
-        if self.additional_info is not None:
-            template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
-            template_chunks_prompt = self.additional_info + template_chunks_prompt
-            template_merge_prompt = self.additional_info + template_merge_prompt
-
-        if len(doc) == 1:
-            prompt = PromptTemplate(
-                template=template_no_chunks_prompt ,
-                input_variables=["question"],
-                partial_variables={"context": doc,
-                                   "format_instructions": format_instructions})
-            chain = prompt | self.llm_model | output_parser
-            answer = chain.invoke({"question": user_prompt})
-
-            state.update({self.output[0]: answer})
+        if isinstance(self.llm_model, ChatOpenAI) and (not self.script_creator or self.force) or self.is_md_scraper:
+            templates = {
+                'no_chunks': TEMPLATE_NO_CHUNKS_MD,
+                'chunks': TEMPLATE_CHUNKS_MD,
+                'merge': TEMPLATE_MERGE_MD
+            }
+
+            url = "https://api.openai.com/v1/chat/completions"
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.api_key}"
+            }
+
+            if len(doc) == 1:
+                prompt = templates['no_chunks'].format(
+                    question=user_prompt,
+                    context=doc[0],
+                    format_instructions=format_instructions
+                )
+                response = requests.post(url, headers=headers, json={
+                    "model": self.llm_model.model_name,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "temperature": 0
+                }, timeout=10)
+                response_text = response.json()['choices'][0]['message']['content']
+                cleaned_response = json.loads(response_text.replace('\\n', '').replace('\\', ''))
+                state.update({self.output[0]: cleaned_response})
+                return state
+
+            chunks_responses = []
+            for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+                prompt = templates['chunks'].format(
+                    question=user_prompt,
+                    context=chunk,
+                    chunk_id=i + 1,
+                    format_instructions=format_instructions
+                )
+                response = requests.post(url, headers=headers, json={
+                    "model": self.llm_model.model_name,
+                    "messages": [{"role": "user", "content": prompt}],
+                    "temperature": 0
+                }, timeout=10)
+                chunk_response = response.json()['choices'][0]['message']['content']
+                cleaned_chunk_response = json.loads(chunk_response.replace('\\n', '').replace('\\', ''))
+                chunks_responses.append(cleaned_chunk_response)
+
+            merge_context = " ".join([json.dumps(chunk) for chunk in chunks_responses])
+            merge_prompt = templates['merge'].format(
+                question=user_prompt,
+                context=merge_context,
+                format_instructions=format_instructions
+            )
+            response = requests.post(url, headers=headers, json={
+                "model": self.llm_model.model_name,
+                "messages": [{"role": "user", "content": merge_prompt}],
+                "temperature": 0
+            }, timeout=10)
+            response_text = response.json()['choices'][0]['message']['content']
+            cleaned_response = json.loads(response_text.replace('\\n', '').replace('\\', ''))
+            state.update({self.output[0]: cleaned_response})
             return state
 
-        chains_dict = {}
-        for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-
-            prompt = PromptTemplate(
-                template=TEMPLATE_CHUNKS,
-                input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
-            chain_name = f"chunk{i+1}"
-            chains_dict[chain_name] = prompt | self.llm_model | output_parser
-
-        async_runner = RunnableParallel(**chains_dict)
-
-        batch_results = async_runner.invoke({"question": user_prompt})
-
-        merge_prompt = PromptTemplate(
-            template = template_merge_prompt ,
+        else:
+            templates = {
+                'no_chunks': TEMPLATE_NO_CHUNKS,
+                'chunks': TEMPLATE_CHUNKS,
+                'merge': TEMPLATE_MERGE
+            }
+
+            if self.additional_info:
+                templates = {key: self.additional_info + template for key, template in templates.items()}
+
+            if len(doc) == 1:
+                prompt = PromptTemplate(
+                    template=templates['no_chunks'],
+                    input_variables=["question"],
+                    partial_variables={"context": doc, "format_instructions": format_instructions}
+                )
+                chain = prompt | self.llm_model | output_parser
+                answer = chain.invoke({"question": user_prompt})
+                state.update({self.output[0]: answer})
+                return state
+
+            chains_dict = {}
+            for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+                prompt = PromptTemplate(
+                    template=templates['chunks'],
+                    input_variables=["question"],
+                    partial_variables={"context": chunk, "chunk_id": i + 1, "format_instructions": format_instructions}
+                )
+                chain_name = f"chunk{i+1}"
+                chains_dict[chain_name] = prompt | self.llm_model | output_parser
+
+            async_runner = RunnableParallel(**chains_dict)
+            batch_results = async_runner.invoke({"question": user_prompt})
+
+            merge_prompt = PromptTemplate(
+                template=templates['merge'],
                 input_variables=["context", "question"],
-            partial_variables={"format_instructions": format_instructions},
+                partial_variables={"format_instructions": format_instructions}
             )
+            merge_chain = merge_prompt | self.llm_model | output_parser
+            answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
 
-        merge_chain = merge_prompt | self.llm_model | output_parser
-        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
-
-        state.update({self.output[0]: answer})
-        return state
+            state.update({self.output[0]: answer})
+            return state
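To try the new request path in isolation, here is a self-contained sketch of the pattern the node now uses: a direct POST to the OpenAI chat completions endpoint, followed by the same string cleanup before json.loads. The environment variable name and the prompt are placeholders; model, temperature, and timeout mirror the diff above. Note that stripping every backslash is a blunt heuristic and will mangle responses that contain legitimately escaped characters.

# Standalone sketch of the request pattern introduced in GenerateAnswerNode.
# The env var name and prompt are placeholders.
import json
import os

import requests

url = "https://api.openai.com/v1/chat/completions"
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
}
payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": 'Reply with the JSON object {"ok": true}'}],
    "temperature": 0,
}

response = requests.post(url, headers=headers, json=payload, timeout=10)
response.raise_for_status()  # the node skips this check; added here for clarity
content = response.json()["choices"][0]["message"]["content"]

# Same cleanup the node applies before parsing; brittle for responses
# that contain intentional escape sequences.
print(json.loads(content.replace("\\n", "").replace("\\", "")))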

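The else branch keeps the original LangChain pipeline: one chain per chunk fanned out through RunnableParallel, then a merge chain over the keyed results. A minimal offline sketch of that fan-out/merge shape, where a RunnableLambda stands in for the chat model and the templates are toy stand-ins for TEMPLATE_CHUNKS / TEMPLATE_MERGE:

# Offline sketch of the fan-out/merge shape kept in the non-OpenAI branch.
# RunnableLambda substitutes for the chat model so this runs without an API.
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel

fake_llm = RunnableLambda(lambda prompt: f"summary of: {prompt.to_string()[:40]}")

chunks = ["first chunk of the page", "second chunk of the page"]
chains_dict = {}
for i, chunk in enumerate(chunks):
    prompt = PromptTemplate(
        template="Chunk {chunk_id}: {context}\nQuestion: {question}",
        input_variables=["question"],
        partial_variables={"context": chunk, "chunk_id": i + 1},
    )
    chains_dict[f"chunk{i+1}"] = prompt | fake_llm

# Run every per-chunk chain in parallel, keyed by chunk name
batch_results = RunnableParallel(**chains_dict).invoke({"question": "What is on the page?"})

merge_prompt = PromptTemplate(
    template="Merge these partial answers: {context}\nQuestion: {question}",
    input_variables=["context", "question"],
)
answer = (merge_prompt | fake_llm).invoke(
    {"context": batch_results, "question": "What is on the page?"}
)
print(answer)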
tests/graphs/smart_scraper_openai_test.py

Lines changed: 2 additions & 17 deletions
@@ -18,7 +18,7 @@ def graph_config():
     return {
         "llm": {
             "api_key": openai_key,
-            "model": "gpt-3.5-turbo",
+            "model": "gpt-4o",
         },
         "verbose": True,
         "headless": False,
@@ -34,19 +34,4 @@ def test_scraping_pipeline(graph_config):
 
     result = smart_scraper_graph.run()
 
-    assert result is not None
-    assert isinstance(result, dict)
-
-def test_get_execution_info(graph_config):
-    """Get the execution info"""
-    smart_scraper_graph = SmartScraperGraph(
-        prompt="List me all the projects with their description.",
-        source="https://perinim.github.io/projects/",
-        config=graph_config,
-    )
-
-    smart_scraper_graph.run()
-
-    graph_exec_info = smart_scraper_graph.get_execution_info()
-
-    assert graph_exec_info is not None
+    assert result is not None
