Skip to content

Commit 55199e8

Browse files
committed
add first iterations of the nodes
1 parent ea27b24 commit 55199e8

File tree

5 files changed

+164
-16
lines changed

5 files changed

+164
-16
lines changed

scrapegraphai/nodes/description_node.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
DescriptionNode Module
33
"""
44
from typing import List, Optional
5+
from tqdm import tqdm
6+
from langchain.prompts import PromptTemplate
7+
from langchain_core.runnables import RunnableParallel
58
from .base_node import BaseNode
9+
from ..prompts.description_node_prompts import DESCRIPTION_NODE_PROMPT
610

711
class DescriptionNode(BaseNode):
812
"""
@@ -39,4 +43,32 @@ def __init__(
3943
self.cache_path = node_config.get("cache_path", False)
4044

4145
def execute(self, state: dict) -> dict:
    """
    Generate a short description for every scraped document chunk.

    Builds one prompt per chunk from DESCRIPTION_NODE_PROMPT, runs all the
    resulting chains in parallel, and stores a mapping of
    ``{summary_text: original_document}`` under ``state["descriptions"]``.

    Args:
        state (dict): Current graph state. Must contain the keys named by
            this node's input expression; the second resolved value is the
            list of document chunks to summarise.

    Returns:
        dict: The updated state with the new ``descriptions`` key.
    """
    self.logger.info(f"--- Executing {self.node_name} Node ---")

    input_keys = self.get_input_keys(state)
    input_data = [state[key] for key in input_keys]
    docs = input_data[1]

    chains_dict = {}

    for i, chunk in enumerate(tqdm(docs, desc="Processing chunks",
                                   disable=not self.verbose)):
        # DESCRIPTION_NODE_PROMPT exposes a single {content} variable; the
        # original passed "context"/"chunk_id", which fails at format time.
        prompt = PromptTemplate(
            template=DESCRIPTION_NODE_PROMPT,
            partial_variables={"content": chunk},
        )
        chain_name = f"chunk{i+1}"
        chains_dict[chain_name] = prompt | self.llm_model

    async_runner = RunnableParallel(**chains_dict)
    # invoke() requires an input mapping; every template variable is already
    # partial, so an empty dict suffices (the original bare invoke() raised).
    batch_results = async_runner.invoke({})

    temp_res = {}

    # batch_results is a dict keyed by chain name ("chunk1", ...). Look each
    # summary up explicitly instead of zipping over the dict's keys, which
    # would pair the *names* (not the summaries) with the documents.
    for i, document in enumerate(docs):
        summary = batch_results[f"chunk{i+1}"]
        # Chat models return a message object; fall back to the raw value.
        temp_res[getattr(summary, "content", summary)] = document

    state["descriptions"] = temp_res

    return state

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
generate_answer_node module
3+
"""
14
from typing import List, Optional
25
from langchain.prompts import PromptTemplate
36
from langchain_core.output_parsers import JsonOutputParser
@@ -15,6 +18,26 @@
1518
)
1619

1720
class GenerateAnswerNode(BaseNode):
21+
"""
22+
Initializes the GenerateAnswerNode class.
23+
24+
Args:
25+
input (str): The input data type for the node.
26+
output (List[str]): The output data type(s) for the node.
27+
node_config (Optional[dict]): Configuration dictionary for the node,
28+
which includes the LLM model, verbosity, schema, and other settings.
29+
Defaults to None.
30+
node_name (str): The name of the node. Defaults to "GenerateAnswer".
31+
32+
Attributes:
33+
llm_model: The language model specified in the node configuration.
34+
verbose (bool): Whether verbose mode is enabled.
35+
force (bool): Whether to force certain behaviors, overriding defaults.
36+
script_creator (bool): Whether the node is in script creation mode.
37+
is_md_scraper (bool): Whether the node is scraping markdown data.
38+
additional_info (Optional[str]): Any additional information to be
39+
included in the prompt templates.
40+
"""
1841
def __init__(
1942
self,
2043
input: str,
@@ -100,7 +123,9 @@ def execute(self, state: dict) -> dict:
100123
prompt = PromptTemplate(
101124
template=template_chunks_prompt,
102125
input_variables=["question"],
103-
partial_variables={"context": chunk, "chunk_id": i + 1, "format_instructions": format_instructions}
126+
partial_variables={"context": chunk,
127+
"chunk_id": i + 1,
128+
"format_instructions": format_instructions}
104129
)
105130
chain_name = f"chunk{i+1}"
106131
chains_dict[chain_name] = prompt | self.llm_model

scrapegraphai/nodes/generate_answer_node_k_level.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,19 @@
22
GenerateAnswerNodeKLevel Module
33
"""
44
from typing import List, Optional
5+
from langchain.prompts import PromptTemplate
6+
from tqdm import tqdm
7+
from langchain_core.output_parsers import JsonOutputParser
8+
from langchain_core.runnables import RunnableParallel
9+
from langchain_openai import ChatOpenAI, AzureChatOpenAI
10+
from langchain_mistralai import ChatMistralAI
11+
from langchain_aws import ChatBedrock
12+
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
513
from .base_node import BaseNode
14+
from ..prompts import (
15+
TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
16+
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
17+
)
618

719
class GenerateAnswerNodeKLevel(BaseNode):
820
"""
@@ -33,18 +45,92 @@ def __init__(
3345

3446
self.llm_model = node_config["llm_model"]
3547
self.embedder_model = node_config.get("embedder_model", None)
36-
self.verbose = (
37-
False if node_config is None else node_config.get("verbose", False)
38-
)
48+
self.verbose = node_config.get("verbose", False)
49+
self.force = node_config.get("force", False)
50+
self.script_creator = node_config.get("script_creator", False)
51+
self.is_md_scraper = node_config.get("is_md_scraper", False)
52+
self.additional_info = node_config.get("additional_info")
3953

4054
def execute(self, state: dict) -> dict:
    """
    Answer the user's question using the vectorial DB built upstream.

    Queries the client stored in ``state["vectorial_db"]`` with the user's
    prompt, runs one chain per retrieved chunk in parallel, merges the
    partial answers with a merge prompt, and stores the final result under
    ``state["answer"]``.

    Args:
        state (dict): Current graph state. Must contain the keys named by
            this node's input expression (the first resolved value is the
            user's question) plus ``vectorial_db``.

    Returns:
        dict: The updated state with the new ``answer`` key.
    """
    input_keys = self.get_input_keys(state)
    input_data = [state[key] for key in input_keys]
    user_prompt = input_data[0]

    # Pick an output parser / format-instructions pair based on whether a
    # schema was supplied and which provider the LLM belongs to.
    if self.node_config.get("schema", None) is not None:
        if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
            self.llm_model = self.llm_model.with_structured_output(
                schema=self.node_config["schema"]
            )
            output_parser = get_structured_output_parser(self.node_config["schema"])
            format_instructions = "NA"
        else:
            if not isinstance(self.llm_model, ChatBedrock):
                output_parser = get_pydantic_output_parser(self.node_config["schema"])
                format_instructions = output_parser.get_format_instructions()
            else:
                output_parser = None
                format_instructions = ""
    else:
        if not isinstance(self.llm_model, ChatBedrock):
            output_parser = JsonOutputParser()
            format_instructions = output_parser.get_format_instructions()
        else:
            output_parser = None
            format_instructions = ""

    if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) \
        and not self.script_creator \
        or self.force \
        and not self.script_creator or self.is_md_scraper:
        template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
        template_chunks_prompt = TEMPLATE_CHUNKS_MD
        template_merge_prompt = TEMPLATE_MERGE_MD
    else:
        template_no_chunks_prompt = TEMPLATE_NO_CHUNKS
        template_chunks_prompt = TEMPLATE_CHUNKS
        template_merge_prompt = TEMPLATE_MERGE

    if self.additional_info is not None:
        template_no_chunks_prompt = self.additional_info + template_no_chunks_prompt
        template_chunks_prompt = self.additional_info + template_chunks_prompt
        template_merge_prompt = self.additional_info + template_merge_prompt

    client = state["vectorial_db"]

    # Query with the user's prompt — the original read state["question"],
    # a key this node never guarantees is present.
    answer_db = client.query(
        collection_name="vectorial_collection",
        query_text=user_prompt
    )

    # answer_db already IS the retrieval result; the original indexed
    # state[answer_db], which raises on an unhashable/unknown key.
    results_db = list(answer_db)

    chains_dict = {}
    for i, chunk in enumerate(tqdm(results_db,
                                   desc="Processing chunks",
                                   disable=not self.verbose)):
        prompt = PromptTemplate(
            template=template_chunks_prompt,
            input_variables=["question"],
            partial_variables={"context": chunk,
                               "chunk_id": i + 1,
                               # required by TEMPLATE_CHUNKS*; the original
                               # omitted it (cf. GenerateAnswerNode)
                               "format_instructions": format_instructions}
        )
        chain_name = f"chunk{i+1}"
        chains_dict[chain_name] = prompt | self.llm_model

    async_runner = RunnableParallel(**chains_dict)
    batch_results = async_runner.invoke({"question": user_prompt})

    merge_prompt = PromptTemplate(
        template=template_merge_prompt,
        input_variables=["context", "question"],
        partial_variables={"format_instructions": format_instructions}
    )

    merge_chain = merge_prompt | self.llm_model
    if output_parser:
        merge_chain = merge_chain | output_parser
    answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})

    state["answer"] = answer

    return state

scrapegraphai/nodes/rag_node.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,13 @@ def execute(self, state: dict) -> dict:
4949
else:
5050
raise ValueError("client_type provided not correct")
5151

52-
docs = ["Qdrant has Langchain integrations", "Qdrant also has Llama Index integrations"]
53-
metadata = [
54-
{"source": "Langchain-docs"},
55-
{"source": "Linkedin-docs"},
56-
]
57-
ids = [42, 2]
52+
docs = [elem for elem in state.get("descriptions").keys()]
53+
metadata = []
5854

5955
client.add(
60-
collection_name="demo_collection",
56+
collection_name="vectorial_collection",
6157
documents=docs,
6258
metadata=metadata,
63-
ids=ids
6459
)
6560

6661
state["vectorial_db"] = client
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
"""
Prompt templates used by DescriptionNode.
"""

# Template with a single input variable, {content}: the scraped chunk to be
# summarised. Callers must supply it via PromptTemplate partial/input vars.
DESCRIPTION_NODE_PROMPT = """
You are a scraper and you have just scraped the
following content from a website. \n
Please provide a description summary of maximum of 10 words
Content of the website: {content}
"""

0 commit comments

Comments
 (0)