Generate answer parallel #480


Closed

wants to merge 16 commits into from
Changes from 2 commits
73 changes: 39 additions & 34 deletions scrapegraphai/nodes/generate_answer_node.py
@@ -1,16 +1,20 @@
"""
GenerateAnswerNode Module
"""

import asyncio
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm
from ..utils.merge_results import merge_results
from ..utils.logging import get_logger
from ..models import Ollama, OpenAI
from .base_node import BaseNode
from ..helpers import template_chunks, template_no_chunks, template_merge, template_chunks_md, template_no_chunks_md, template_merge_md
from ..helpers import (
template_chunks, template_no_chunks, template_merge,
template_chunks_md, template_no_chunks_md, template_merge_md
)

class GenerateAnswerNode(BaseNode):
"""
@@ -38,12 +42,9 @@ def __init__(
node_name: str = "GenerateAnswer",
):
super().__init__(node_name, "node", input, output, 2, node_config)

self.llm_model = node_config["llm_model"]

if isinstance(node_config["llm_model"], Ollama):
self.llm_model.format="json"

self.verbose = (
True if node_config is None else node_config.get("verbose", False)
)
@@ -89,7 +90,7 @@ def execute(self, state: dict) -> dict:

format_instructions = output_parser.get_format_instructions()

if isinstance(self.llm_model, OpenAI) and not self.script_creator or self.force and not self.script_creator:
template_no_chunks_prompt = template_no_chunks_md
template_chunks_prompt = template_chunks_md
template_merge_prompt = template_merge_md
@@ -99,44 +100,48 @@ def execute(self, state: dict) -> dict:
template_merge_prompt = template_merge

chains_dict = {}
answers = []

# Use tqdm to add progress bar
for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
if len(doc) == 1:
# No batching needed for single chunk
Collaborator

Why check the document length inside the for-loop and not outside of it?

Rather, the for-loop should be in place only if the document has multiple chunks. (A minimal sketch of that restructuring follows this file's diff.)

prompt = PromptTemplate(
    template=template_no_chunks_prompt,
    input_variables=["question"],
    partial_variables={"context": chunk.page_content,
                       "format_instructions": format_instructions})
chain = prompt | self.llm_model | output_parser
answer = chain.invoke({"question": user_prompt})

else:
# Prepare prompt with chunk information
prompt = PromptTemplate(
    template=template_chunks_prompt,
    input_variables=["question"],
    partial_variables={"context": chunk.page_content,
                       "chunk_id": i + 1,
                       "format_instructions": format_instructions})

# Add chain to dictionary with dynamic name
chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm_model | output_parser

# Batch process chains if there are multiple chunks
if len(chains_dict) > 1:
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
map_chain = RunnableParallel(**chains_dict)
# Chain
answer = map_chain.invoke({"question": user_prompt})
# Merge the answers from the chunks
merge_prompt = PromptTemplate(
template = template_merge_prompt,
input_variables=["context", "question"],
partial_variables={"format_instructions": format_instructions},
)
merge_chain = merge_prompt | self.llm_model | output_parser
answer = merge_chain.invoke({"context": answer, "question": user_prompt})

# Update the state with the generated answer
state.update({self.output[0]: answer})
async def process_chains():
    # Run each chunk's chain concurrently with the same user question
    tasks = [chain.ainvoke({"question": user_prompt}) for chain in chains_dict.values()]
    batch_results = await asyncio.gather(*tasks)
    return batch_results

# Run asynchronous batch processing and get results
loop = asyncio.get_event_loop()
batch_answers = loop.run_until_complete(process_chains())

# Merge batch results (assuming same structure)
merged_answer = merge_results(answers, batch_answers)
answers = merged_answer

state.update({self.output[0]: answers})
return state
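
Picking up the review comment above: a minimal sketch, under the assumption that the node keeps its current PromptTemplate / JsonOutputParser pipeline, of how the single-chunk check could live outside the loop, with the loop (and the concurrent batch) used only when the document has multiple chunks. The helper name `generate_answer` and its argument list are illustrative, not part of this PR.

```python
import asyncio

from langchain.prompts import PromptTemplate


def generate_answer(doc, user_prompt, llm_model, output_parser,
                    template_no_chunks_prompt, template_chunks_prompt,
                    format_instructions):
    # Single chunk: build and invoke one chain directly, no batching needed
    if len(doc) == 1:
        prompt = PromptTemplate(
            template=template_no_chunks_prompt,
            input_variables=["question"],
            partial_variables={"context": doc[0].page_content,
                               "format_instructions": format_instructions})
        chain = prompt | llm_model | output_parser
        return chain.invoke({"question": user_prompt})

    # Multiple chunks: one chain per chunk, executed concurrently
    chains = []
    for i, chunk in enumerate(doc):
        prompt = PromptTemplate(
            template=template_chunks_prompt,
            input_variables=["question"],
            partial_variables={"context": chunk.page_content,
                               "chunk_id": i + 1,
                               "format_instructions": format_instructions})
        chains.append(prompt | llm_model | output_parser)

    async def run_all():
        return await asyncio.gather(
            *(chain.ainvoke({"question": user_prompt}) for chain in chains))

    return asyncio.run(run_all())
```

The per-chunk results returned here would still go through the merge template (or merge_results) afterwards, as the node already does.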
1 change: 1 addition & 0 deletions scrapegraphai/utils/__init__.py
@@ -11,3 +11,4 @@
from .cleanup_html import cleanup_html
from .logging import *
from .convert_to_md import convert_to_md
from .merge_results import merge_results
31 changes: 31 additions & 0 deletions scrapegraphai/utils/merge_results.py
@@ -0,0 +1,31 @@
def merge_results(answers, batch_answers):
    """
    Merges the results from single-chunk processing and batch processing, and adds separators between the chunks.

    Parameters:
    -----------
    answers : list of str
        A list of strings containing the results from single-chunk processing.

    batch_answers : list of dict
        A list of dictionaries, where each dictionary contains a key "text" with the batch processing result as a string.

    Returns:
    --------
    str
        A single string containing all merged results, with each result separated by a newline character.

    Example:
    --------
    >>> answers = ["Result from single-chunk 1", "Result from single-chunk 2"]
    >>> batch_answers = [{"text": "Result from batch 1"}, {"text": "Result from batch 2"}]
    >>> merge_results(answers, batch_answers)
    'Result from single-chunk 1\nResult from single-chunk 2\nResult from batch 1\nResult from batch 2'
    """
    # Combine answers from single-chunk processing and batch processing
    merged_answers = answers + [answer["text"] for answer in batch_answers]

    # Add separators between chunks
    merged_answers = "\n".join(merged_answers)

    return merged_answers
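
For completeness, a short usage sketch of the new helper with illustrative values; it joins the single-chunk strings and the "text" field of each batch dictionary with newlines:

```python
from scrapegraphai.utils.merge_results import merge_results

answers = ["Answer from the single-chunk pass"]
batch_answers = [{"text": "Answer for chunk 1"}, {"text": "Answer for chunk 2"}]

print(merge_results(answers, batch_answers))
# Answer from the single-chunk pass
# Answer for chunk 1
# Answer for chunk 2
```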