diff --git a/pyproject.toml b/pyproject.toml
index 9fbc763d..8611e778 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,19 +3,15 @@
 name = "scrapegraphai"
 version = "1.11.0b1"
-
-
 description = "A web scraping library based on LangChain which uses LLM and direct graph logic to create scraping pipelines."
 authors = [
     { name = "Marco Vinciguerra", email = "mvincig11@gmail.com" },
     { name = "Marco Perini", email = "perinim.98@gmail.com" },
     { name = "Lorenzo Padoan", email = "lorenzo.padoan977@gmail.com" }
 ]
+
 dependencies = [
     "langchain>=0.2.10",
-
-    "langchain-fireworks>=0.1.3",
-    "langchain_community>=0.2.9",
     "langchain-google-genai>=1.0.7",
     "langchain-google-vertexai",
     "langchain-openai>=0.1.17",
@@ -37,6 +33,8 @@ dependencies = [
     "google>=3.0.0",
     "undetected-playwright>=0.3.0",
     "semchunk>=1.0.1",
+    "langchain-fireworks>=0.1.3",
+    "langchain-community>=0.2.9"
 ]
 
 license = "MIT"
diff --git a/requirements.txt b/requirements.txt
index 124840e5..440bf78a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,5 @@
 langchain>=0.2.10
-langchain_community>=0.2.9
 langchain-google-genai>=1.0.7
-langchain-fireworks>=0.1.3
 langchain-google-vertexai
 langchain-openai>=0.1.17
 langchain-groq>=0.1.3
@@ -22,4 +20,5 @@ playwright>=1.43.0
 google>=3.0.0
 undetected-playwright>=0.3.0
 semchunk>=1.0.1
-
+langchain-fireworks>=0.1.3
+langchain-community>=0.2.9
diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py
index f764e58b..81812598 100644
--- a/scrapegraphai/nodes/generate_answer_node.py
+++ b/scrapegraphai/nodes/generate_answer_node.py
@@ -1,7 +1,7 @@
 """
 GenerateAnswerNode Module
 """
-
+import asyncio
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -107,44 +107,43 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt
 
-        chains_dict = {}
+        if len(doc) == 1:
+            prompt = PromptTemplate(
+                template=template_no_chunks_prompt,
+                input_variables=["question"],
+                partial_variables={"context": doc,
+                                   "format_instructions": format_instructions})
+            chain = prompt | self.llm_model | output_parser
+            answer = chain.invoke({"question": user_prompt})
+
+            state.update({self.output[0]: answer})
+            return state
 
-        # Use tqdm to add progress bar
+        chains_dict = {}
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-            if len(doc) == 1:
-                prompt = PromptTemplate(
-                    template=template_no_chunks_prompt,
-                    input_variables=["question"],
-                    partial_variables={"context": chunk,
-                                       "format_instructions": format_instructions})
-                chain = prompt | self.llm_model | output_parser
-                answer = chain.invoke({"question": user_prompt})
-                break
 
             prompt = PromptTemplate(
-                template=template_chunks_prompt,
-                input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
-            # Dynamically name the chains based on their index
+                template=template_chunks_prompt,
+                input_variables=["question"],
+                partial_variables={"context": chunk,
+                                   "chunk_id": i + 1,
+                                   "format_instructions": format_instructions})
+            # Add chain to dictionary with dynamic name
             chain_name = f"chunk{i+1}"
             chains_dict[chain_name] = prompt | self.llm_model | output_parser
 
-        if len(chains_dict) > 1:
-            # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
-            map_chain = RunnableParallel(**chains_dict)
-            # Chain
-            answer = map_chain.invoke({"question": user_prompt})
-            # Merge the answers from the chunks
-            merge_prompt = PromptTemplate(
+        async_runner = RunnableParallel(**chains_dict)
+
+        batch_results = async_runner.invoke({"question": user_prompt})
+
+        merge_prompt = PromptTemplate(
             template = template_merge_prompt,
             input_variables=["context", "question"],
             partial_variables={"format_instructions": format_instructions},
         )
-            merge_chain = merge_prompt | self.llm_model | output_parser
-            answer = merge_chain.invoke({"context": answer, "question": user_prompt})
-        # Update the state with the generated answer
+        merge_chain = merge_prompt | self.llm_model | output_parser
+        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
+
         state.update({self.output[0]: answer})
         return state
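For readers following the refactor in generate_answer_node.py, the sketch below rehearses the new control flow outside the library: a single chunk is answered in one pass and returned early, while multiple chunks are fanned out through RunnableParallel and then merged. This is a hedged, standalone illustration rather than the node itself: `fake_llm` stands in for `self.llm_model`, the inline template strings replace `template_no_chunks_prompt` / `template_chunks_prompt` / `template_merge_prompt`, and `StrOutputParser` is used instead of the node's `JsonOutputParser` so the example runs without a real model.

```python
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnableParallel

# Stand-in for self.llm_model: echoes a truncated view of the rendered prompt.
fake_llm = RunnableLambda(lambda p: "answer based on -> " + p.to_string()[:60])
parser = StrOutputParser()

doc = ["first chunk of scraped text", "second chunk of scraped text"]
user_prompt = "What is the page about?"

if len(doc) == 1:
    # Single chunk: answer directly, no merge step (mirrors the early return).
    prompt = PromptTemplate(
        template="Context: {context}\nQuestion: {question}",
        input_variables=["question"],
        partial_variables={"context": doc})
    answer = (prompt | fake_llm | parser).invoke({"question": user_prompt})
else:
    # Several chunks: one chain per chunk, executed together via RunnableParallel.
    chains_dict = {}
    for i, chunk in enumerate(doc):
        prompt = PromptTemplate(
            template="Chunk {chunk_id}: {context}\nQuestion: {question}",
            input_variables=["question"],
            partial_variables={"context": chunk, "chunk_id": i + 1})
        chains_dict[f"chunk{i+1}"] = prompt | fake_llm | parser

    batch_results = RunnableParallel(**chains_dict).invoke({"question": user_prompt})

    # Merge the per-chunk answers into a single response.
    merge_prompt = PromptTemplate(
        template="Partial answers: {context}\nQuestion: {question}",
        input_variables=["context", "question"])
    answer = (merge_prompt | fake_llm | parser).invoke(
        {"context": batch_results, "question": user_prompt})

print(answer)
```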