From 70911f83bea56ae16ce2a052c1b34cc6611e5cf7 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Tue, 23 Jul 2024 11:07:35 +0200 Subject: [PATCH 1/2] Update generate_answer_pdf_node.py --- .../nodes/generate_answer_pdf_node.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py index db6152bc..69f97f08 100644 --- a/scrapegraphai/nodes/generate_answer_pdf_node.py +++ b/scrapegraphai/nodes/generate_answer_pdf_node.py @@ -114,24 +114,26 @@ def execute(self, state): format_instructions = output_parser.get_format_instructions() + if len(doc) == 1: + prompt = PromptTemplate( + template=template_no_chunks_pdf_prompt, + input_variables=["question"], + partial_variables={ + "context":chunk, + "format_instructions": format_instructions, + }, + ) + chain = prompt | self.llm_model | output_parser + answer = chain.invoke({"question": user_prompt}) + + + state.update({self.output[0]: answer}) + return state + chains_dict = {} # Use tqdm to add progress bar for i, chunk in enumerate( - tqdm(doc, desc="Processing chunks", disable=not self.verbose) - ): - if len(doc) == 1: - prompt = PromptTemplate( - template=template_no_chunks_pdf_prompt, - input_variables=["question"], - partial_variables={ - "context":chunk, - "format_instructions": format_instructions, - }, - ) - chain = prompt | self.llm_model | output_parser - answer = chain.invoke({"question": user_prompt}) - - break + tqdm(doc, desc="Processing chunks", disable=not self.verbose)): prompt = PromptTemplate( template=template_chunks_pdf_prompt, input_variables=["question"], @@ -146,20 +148,18 @@ def execute(self, state): chain_name = f"chunk{i+1}" chains_dict[chain_name] = prompt | self.llm_model | output_parser - if len(chains_dict) > 1: - # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel - map_chain = RunnableParallel(**chains_dict) - # Chain - answer = map_chain.invoke({"question": user_prompt}) - # Merge the answers from the chunks - merge_prompt = PromptTemplate( - template=template_merge_pdf_prompt, - input_variables=["context", "question"], - partial_variables={"format_instructions": format_instructions}, - ) - merge_chain = merge_prompt | self.llm_model | output_parser - answer = merge_chain.invoke({"context": answer, "question": user_prompt}) + # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel + map_chain = RunnableParallel(**chains_dict) + # Chain + answer = map_chain.ainvoke({"question": user_prompt}) + # Merge the answers from the chunks + merge_prompt = PromptTemplate( + template=template_merge_pdf_prompt, + input_variables=["context", "question"], + partial_variables={"format_instructions": format_instructions}, + ) + merge_chain = merge_prompt | self.llm_model | output_parser + answer = merge_chain.invoke({"context": answer, "question": user_prompt}) - # Update the state with the generated answer state.update({self.output[0]: answer}) return state From 8dd941e39e7602b3fcfd8f411951263487c98820 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Tue, 23 Jul 2024 12:40:11 +0200 Subject: [PATCH 2/2] update generate answer nodes Co-Authored-By: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> --- pyproject.toml | 3 +- requirements-dev.lock | 18 +++++- requirements.lock | 22 ++++++- .../nodes/generate_answer_csv_node.py | 53 ++++++++--------- .../nodes/generate_answer_omni_node.py | 59 +++++++++---------- 
.../nodes/generate_answer_pdf_node.py | 23 ++++---- 6 files changed, 103 insertions(+), 75 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 10e8b61f..122ec328 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ authors = [ { name = "Lorenzo Padoan", email = "lorenzo.padoan977@gmail.com" } ] dependencies = [ + "langchain-community==0.2.9", "langchain>=0.2.10", "langchain-google-genai>=1.0.7", "langchain-google-vertexai", @@ -92,4 +93,4 @@ dev-dependencies = [ [tool.rye.scripts] pylint-local = "pylint scrapegraphai/**/*.py" pylint-ci = "pylint --disable=C0114,C0115,C0116 --exit-zero scrapegraphai/**/*.py" -update-requirements = "python 'manual deployment/autorequirements.py'" +update-requirements = "python 'manual deployment/autorequirements.py'" \ No newline at end of file diff --git a/requirements-dev.lock b/requirements-dev.lock index 3e8ddc74..2c56f3db 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -12,6 +12,7 @@ aiofiles==24.1.0 # via burr aiohttp==3.9.5 # via langchain + # via langchain-community # via langchain-fireworks # via langchain-nvidia-ai-endpoints aiosignal==1.3.1 @@ -73,6 +74,8 @@ contourpy==1.2.1 # via matplotlib cycler==0.12.1 # via matplotlib +dataclasses-json==0.6.7 + # via langchain-community defusedxml==0.7.1 # via langchain-anthropic dill==0.3.8 @@ -177,7 +180,6 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright - # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 @@ -249,15 +251,19 @@ jsonschema-specifications==2023.12.1 kiwisolver==1.4.5 # via matplotlib langchain==0.2.10 + # via langchain-community # via scrapegraphai langchain-anthropic==0.1.20 # via scrapegraphai langchain-aws==0.1.12 # via scrapegraphai +langchain-community==0.2.9 + # via scrapegraphai langchain-core==0.2.22 # via langchain # via langchain-anthropic # via langchain-aws + # via langchain-community # via langchain-fireworks # via langchain-google-genai # via langchain-google-vertexai @@ -281,6 +287,7 @@ langchain-text-splitters==0.2.2 # via langchain langsmith==0.1.93 # via langchain + # via langchain-community # via langchain-core loguru==0.7.2 # via burr @@ -290,6 +297,8 @@ markdown-it-py==3.0.0 # via rich markupsafe==2.1.5 # via jinja2 +marshmallow==3.21.3 + # via dataclasses-json matplotlib==3.9.1 # via burr mccabe==0.7.0 @@ -313,6 +322,7 @@ numpy==1.26.4 # via faiss-cpu # via langchain # via langchain-aws + # via langchain-community # via matplotlib # via pandas # via pyarrow @@ -333,6 +343,7 @@ packaging==24.1 # via google-cloud-bigquery # via huggingface-hub # via langchain-core + # via marshmallow # via matplotlib # via pytest # via sphinx @@ -423,6 +434,7 @@ pytz==2024.1 pyyaml==6.0.1 # via huggingface-hub # via langchain + # via langchain-community # via langchain-core # via uvicorn referencing==0.35.1 @@ -438,6 +450,7 @@ requests==2.32.3 # via google-cloud-storage # via huggingface-hub # via langchain + # via langchain-community # via langchain-fireworks # via langsmith # via sphinx @@ -495,12 +508,14 @@ sphinxcontrib-serializinghtml==1.1.10 # via sphinx sqlalchemy==2.0.31 # via langchain + # via langchain-community starlette==0.37.2 # via fastapi streamlit==1.36.0 # via burr tenacity==8.5.0 # via langchain + # via langchain-community # via langchain-core # via streamlit tiktoken==0.7.0 @@ -551,6 +566,7 @@ typing-extensions==4.12.2 # via typing-inspect # via uvicorn typing-inspect==0.9.0 + # via dataclasses-json # via sf-hamilton tzdata==2024.1 # via pandas diff --git a/requirements.lock 
b/requirements.lock index d99559de..a943dff1 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,6 +10,7 @@ -e file:. aiohttp==3.9.5 # via langchain + # via langchain-community # via langchain-fireworks # via langchain-nvidia-ai-endpoints aiosignal==1.3.1 @@ -44,6 +45,8 @@ certifi==2024.7.4 # via requests charset-normalizer==3.3.2 # via requests +dataclasses-json==0.6.7 + # via langchain-community defusedxml==0.7.1 # via langchain-anthropic dill==0.3.8 @@ -125,7 +128,6 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright - # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 @@ -170,15 +172,19 @@ jsonpatch==1.33 jsonpointer==3.0.0 # via jsonpatch langchain==0.2.10 + # via langchain-community # via scrapegraphai langchain-anthropic==0.1.20 # via scrapegraphai langchain-aws==0.1.12 # via scrapegraphai +langchain-community==0.2.9 + # via scrapegraphai langchain-core==0.2.22 # via langchain # via langchain-anthropic # via langchain-aws + # via langchain-community # via langchain-fireworks # via langchain-google-genai # via langchain-google-vertexai @@ -202,9 +208,12 @@ langchain-text-splitters==0.2.2 # via langchain langsmith==0.1.93 # via langchain + # via langchain-community # via langchain-core lxml==5.2.2 # via free-proxy +marshmallow==3.21.3 + # via dataclasses-json minify-html==0.15.0 # via scrapegraphai mpire==2.10.2 @@ -214,10 +223,13 @@ multidict==6.0.5 # via yarl multiprocess==0.70.16 # via mpire +mypy-extensions==1.0.0 + # via typing-inspect numpy==1.26.4 # via faiss-cpu # via langchain # via langchain-aws + # via langchain-community # via pandas # via shapely openai==1.37.0 @@ -231,6 +243,7 @@ packaging==24.1 # via google-cloud-bigquery # via huggingface-hub # via langchain-core + # via marshmallow pandas==2.2.2 # via scrapegraphai pillow==10.4.0 @@ -288,6 +301,7 @@ pytz==2024.1 pyyaml==6.0.1 # via huggingface-hub # via langchain + # via langchain-community # via langchain-core regex==2024.5.15 # via tiktoken @@ -298,6 +312,7 @@ requests==2.32.3 # via google-cloud-storage # via huggingface-hub # via langchain + # via langchain-community # via langchain-fireworks # via langsmith # via tiktoken @@ -321,8 +336,10 @@ soupsieve==2.5 # via beautifulsoup4 sqlalchemy==2.0.31 # via langchain + # via langchain-community tenacity==8.5.0 # via langchain + # via langchain-community # via langchain-core tiktoken==0.7.0 # via langchain-openai @@ -347,6 +364,9 @@ typing-extensions==4.12.2 # via pydantic-core # via pyee # via sqlalchemy + # via typing-inspect +typing-inspect==0.9.0 + # via dataclasses-json tzdata==2024.1 # via pandas undetected-playwright==0.3.0 diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index 43657b50..6ce19ef2 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -116,24 +116,24 @@ def execute(self, state): chains_dict = {} - # Use tqdm to add progress bar + if len(doc) == 1: + prompt = PromptTemplate( + template=template_no_chunks_csv_prompt, + input_variables=["question"], + partial_variables={ + "context": doc, + "format_instructions": format_instructions, + }, + ) + + chain = prompt | self.llm_model | output_parser + answer = chain.invoke({"question": user_prompt}) + state.update({self.output[0]: answer}) + return state + for i, chunk in enumerate( tqdm(doc, desc="Processing chunks", disable=not self.verbose) ): - if len(doc) == 1: - prompt = PromptTemplate( - 
template=template_no_chunks_csv_prompt, - input_variables=["question"], - partial_variables={ - "context": chunk, - "format_instructions": format_instructions, - }, - ) - - chain = prompt | self.llm_model | output_parser - answer = chain.invoke({"question": user_prompt}) - break - prompt = PromptTemplate( template=template_chunks_csv_prompt, input_variables=["question"], @@ -144,24 +144,21 @@ def execute(self, state): }, ) - # Dynamically name the chains based on their index chain_name = f"chunk{i+1}" chains_dict[chain_name] = prompt | self.llm_model | output_parser - if len(chains_dict) > 1: - # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel - map_chain = RunnableParallel(**chains_dict) - # Chain - answer = map_chain.invoke({"question": user_prompt}) - # Merge the answers from the chunks - merge_prompt = PromptTemplate( - template=template_merge_csv_prompt, + async_runner = RunnableParallel(**chains_dict) + + batch_results = async_runner.invoke({"question": user_prompt}) + + merge_prompt = PromptTemplate( + template = template_merge_csv_prompt, input_variables=["context", "question"], partial_variables={"format_instructions": format_instructions}, ) - merge_chain = merge_prompt | self.llm_model | output_parser - answer = merge_chain.invoke({"context": answer, "question": user_prompt}) - # Update the state with the generated answer + merge_chain = merge_prompt | self.llm_model | output_parser + answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) + state.update({self.output[0]: answer}) - return state + return state \ No newline at end of file diff --git a/scrapegraphai/nodes/generate_answer_omni_node.py b/scrapegraphai/nodes/generate_answer_omni_node.py index 7a030c6f..c2f2b65d 100644 --- a/scrapegraphai/nodes/generate_answer_omni_node.py +++ b/scrapegraphai/nodes/generate_answer_omni_node.py @@ -100,26 +100,26 @@ def execute(self, state: dict) -> dict: chains_dict = {} + if len(doc) == 1: + prompt = PromptTemplate( + template=template_no_chunk_omni_prompt, + input_variables=["question"], + partial_variables={ + "context": chunk, + "format_instructions": format_instructions, + "img_desc": imag_desc, + }, + ) + + chain = prompt | self.llm_model | output_parser + answer = chain.invoke({"question": user_prompt}) + + state.update({self.output[0]: answer}) + return state - # Use tqdm to add progress bar for i, chunk in enumerate( tqdm(doc, desc="Processing chunks", disable=not self.verbose) ): - if len(doc) == 1: - prompt = PromptTemplate( - template=template_no_chunk_omni_prompt, - input_variables=["question"], - partial_variables={ - "context": chunk, - "format_instructions": format_instructions, - "img_desc": imag_desc, - }, - ) - - chain = prompt | self.llm_model | output_parser - answer = chain.invoke({"question": user_prompt}) - break - prompt = PromptTemplate( template=template_chunks_omni_prompt, input_variables=["question"], @@ -134,23 +134,18 @@ def execute(self, state: dict) -> dict: chain_name = f"chunk{i+1}" chains_dict[chain_name] = prompt | self.llm_model | output_parser - if len(chains_dict) > 1: - # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel - map_chain = RunnableParallel(**chains_dict) - # Chain - answer = map_chain.invoke({"question": user_prompt}) - # Merge the answers from the chunks - merge_prompt = PromptTemplate( - template=template_merge_omni_prompt, + async_runner = RunnableParallel(**chains_dict) + + batch_results = async_runner.invoke({"question": user_prompt}) + + 
merge_prompt = PromptTemplate( + template = template_merge_omni_prompt, input_variables=["context", "question"], - partial_variables={ - "format_instructions": format_instructions, - "img_desc": imag_desc, - }, + partial_variables={"format_instructions": format_instructions}, ) - merge_chain = merge_prompt | self.llm_model | output_parser - answer = merge_chain.invoke({"context": answer, "question": user_prompt}) - # Update the state with the generated answer + merge_chain = merge_prompt | self.llm_model | output_parser + answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) + state.update({self.output[0]: answer}) - return state + return state \ No newline at end of file diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py index 69f97f08..7add7948 100644 --- a/scrapegraphai/nodes/generate_answer_pdf_node.py +++ b/scrapegraphai/nodes/generate_answer_pdf_node.py @@ -131,7 +131,7 @@ def execute(self, state): return state chains_dict = {} - # Use tqdm to add progress bar + for i, chunk in enumerate( tqdm(doc, desc="Processing chunks", disable=not self.verbose)): prompt = PromptTemplate( @@ -144,22 +144,21 @@ def execute(self, state): }, ) - # Dynamically name the chains based on their index chain_name = f"chunk{i+1}" chains_dict[chain_name] = prompt | self.llm_model | output_parser - # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel - map_chain = RunnableParallel(**chains_dict) - # Chain - answer = map_chain.ainvoke({"question": user_prompt}) - # Merge the answers from the chunks + async_runner = RunnableParallel(**chains_dict) + + batch_results = async_runner.invoke({"question": user_prompt}) + merge_prompt = PromptTemplate( - template=template_merge_pdf_prompt, - input_variables=["context", "question"], - partial_variables={"format_instructions": format_instructions}, - ) + template = template_merge_pdf_prompt, + input_variables=["context", "question"], + partial_variables={"format_instructions": format_instructions}, + ) + merge_chain = merge_prompt | self.llm_model | output_parser - answer = merge_chain.invoke({"context": answer, "question": user_prompt}) + answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) state.update({self.output[0]: answer}) return state
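Both commits move the CSV, omni, and PDF answer nodes toward the same shape: short-circuit when the document fits in a single chunk, otherwise build one chain per chunk, run the chunk chains together through RunnableParallel, and merge the per-chunk answers with a final chain. Below is a minimal, runnable sketch of that flow; the prompt strings, the FakeListLLM stand-in for self.llm_model, and the map_and_merge helper are illustrative assumptions, not code shipped in scrapegraphai.

from langchain_community.llms.fake import FakeListLLM
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel

# Stand-ins for self.llm_model and the template_* constants used by the real nodes.
llm_model = FakeListLLM(responses=[
    '{"answer": "facts from chunk 1"}',
    '{"answer": "facts from chunk 2"}',
    '{"answer": "merged answer"}',
])
chunk_template = "Answer {question} using only this context:\n{context}\n{format_instructions}"
merge_template = "Merge these partial answers to {question}:\n{context}\n{format_instructions}"

def map_and_merge(doc: list[str], user_prompt: str) -> dict:
    output_parser = JsonOutputParser()
    format_instructions = output_parser.get_format_instructions()

    # Single-chunk short circuit: answer directly, no fan-out or merge round trip.
    if len(doc) == 1:
        prompt = PromptTemplate(
            template=chunk_template,
            input_variables=["question"],
            partial_variables={"context": doc[0], "format_instructions": format_instructions},
        )
        return (prompt | llm_model | output_parser).invoke({"question": user_prompt})

    # Fan out: one dynamically named chain per chunk, run together via RunnableParallel.
    chains_dict = {
        f"chunk{i + 1}": PromptTemplate(
            template=chunk_template,
            input_variables=["question"],
            partial_variables={"context": chunk, "format_instructions": format_instructions},
        ) | llm_model | output_parser
        for i, chunk in enumerate(doc)
    }
    batch_results = RunnableParallel(**chains_dict).invoke({"question": user_prompt})

    # Merge: feed the dict of per-chunk answers back through a single merge chain.
    merge_prompt = PromptTemplate(
        template=merge_template,
        input_variables=["context", "question"],
        partial_variables={"format_instructions": format_instructions},
    )
    merge_chain = merge_prompt | llm_model | output_parser
    return merge_chain.invoke({"context": batch_results, "question": user_prompt})

print(map_and_merge(["chunk one text", "chunk two text"], "What does the document say?"))

The sketch reads doc[0] in the pre-loop branch because the loop variable chunk is not bound yet at that point; the single-document branches added to the omni and PDF nodes still reference chunk there.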
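The first commit stores the result of map_chain.ainvoke() directly, while the second switches to the synchronous invoke(); if the asynchronous form is kept instead, the coroutine returned by ainvoke() has to be awaited. A small sketch under that assumption, reusing chains_dict and user_prompt as built above:

import asyncio

from langchain_core.runnables import RunnableParallel

async def run_chunks_async(chains_dict: dict, user_prompt: str) -> dict:
    # ainvoke runs the per-chunk chains concurrently on the event loop.
    map_chain = RunnableParallel(**chains_dict)
    return await map_chain.ainvoke({"question": user_prompt})

# From synchronous code such as a node's execute() method (assuming no event loop is
# already running there):
# batch_results = asyncio.run(run_chunks_async(chains_dict, user_prompt))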