Update generate answer node #487

Merged: 2 commits, Jul 23, 2024
pyproject.toml (3 changes: 2 additions & 1 deletion)
@@ -13,6 +13,7 @@ authors = [
{ name = "Lorenzo Padoan", email = "lorenzo.padoan977@gmail.com" }
]
dependencies = [
"langchain-community==0.2.9",
"langchain>=0.2.10",
"langchain-google-genai>=1.0.7",
"langchain-google-vertexai",
@@ -92,4 +93,4 @@ dev-dependencies = [
[tool.rye.scripts]
pylint-local = "pylint scrapegraphai/**/*.py"
pylint-ci = "pylint --disable=C0114,C0115,C0116 --exit-zero scrapegraphai/**/*.py"
update-requirements = "python 'manual deployment/autorequirements.py'"
update-requirements = "python 'manual deployment/autorequirements.py'"
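
Note: langchain-community is the distribution that hosts LangChain's third-party integrations after they were split out of the core langchain package, so any community import now needs this explicit pin. Below is a minimal, illustrative sketch of the kind of usage that depends on it; CSVLoader and the file path are hypothetical stand-ins, not code from this PR.

# Illustrative sketch: assumes the project imports an integration from
# langchain_community; CSVLoader and data.csv are hypothetical stand-ins.
from langchain_community.document_loaders import CSVLoader

loader = CSVLoader(file_path="data.csv")  # hypothetical input file
docs = loader.load()                      # one Document per CSV row
print(len(docs))
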
requirements-dev.lock (18 changes: 17 additions & 1 deletion)
@@ -12,6 +12,7 @@ aiofiles==24.1.0
# via burr
aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
# via langchain-nvidia-ai-endpoints
aiosignal==1.3.1
@@ -73,6 +74,8 @@ contourpy==1.2.1
# via matplotlib
cycler==0.12.1
# via matplotlib
dataclasses-json==0.6.7
# via langchain-community
defusedxml==0.7.1
# via langchain-anthropic
dill==0.3.8
@@ -177,7 +180,6 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.9.0
# via langchain-groq
grpc-google-iam-v1==0.13.1
@@ -249,15 +251,19 @@ jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
# via matplotlib
langchain==0.2.10
# via langchain-community
# via scrapegraphai
langchain-anthropic==0.1.20
# via scrapegraphai
langchain-aws==0.1.12
# via scrapegraphai
langchain-community==0.2.9
# via scrapegraphai
langchain-core==0.2.22
# via langchain
# via langchain-anthropic
# via langchain-aws
# via langchain-community
# via langchain-fireworks
# via langchain-google-genai
# via langchain-google-vertexai
@@ -281,6 +287,7 @@ langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.93
# via langchain
# via langchain-community
# via langchain-core
loguru==0.7.2
# via burr
@@ -290,6 +297,8 @@ markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.3
# via dataclasses-json
matplotlib==3.9.1
# via burr
mccabe==0.7.0
@@ -313,6 +322,7 @@ numpy==1.26.4
# via faiss-cpu
# via langchain
# via langchain-aws
# via langchain-community
# via matplotlib
# via pandas
# via pyarrow
@@ -333,6 +343,7 @@ packaging==24.1
# via google-cloud-bigquery
# via huggingface-hub
# via langchain-core
# via marshmallow
# via matplotlib
# via pytest
# via sphinx
@@ -423,6 +434,7 @@ pytz==2024.1
pyyaml==6.0.1
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-core
# via uvicorn
referencing==0.35.1
@@ -438,6 +450,7 @@ requests==2.32.3
# via google-cloud-storage
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-fireworks
# via langsmith
# via sphinx
@@ -495,12 +508,14 @@ sphinxcontrib-serializinghtml==1.1.10
# via sphinx
sqlalchemy==2.0.31
# via langchain
# via langchain-community
starlette==0.37.2
# via fastapi
streamlit==1.36.0
# via burr
tenacity==8.5.0
# via langchain
# via langchain-community
# via langchain-core
# via streamlit
tiktoken==0.7.0
@@ -551,6 +566,7 @@ typing-extensions==4.12.2
# via typing-inspect
# via uvicorn
typing-inspect==0.9.0
# via dataclasses-json
# via sf-hamilton
tzdata==2024.1
# via pandas
requirements.lock (22 changes: 21 additions & 1 deletion)
@@ -10,6 +10,7 @@
-e file:.
aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
# via langchain-nvidia-ai-endpoints
aiosignal==1.3.1
@@ -44,6 +45,8 @@ certifi==2024.7.4
# via requests
charset-normalizer==3.3.2
# via requests
dataclasses-json==0.6.7
# via langchain-community
defusedxml==0.7.1
# via langchain-anthropic
dill==0.3.8
@@ -125,7 +128,6 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.9.0
# via langchain-groq
grpc-google-iam-v1==0.13.1
@@ -170,15 +172,19 @@ jsonpatch==1.33
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.10
# via langchain-community
# via scrapegraphai
langchain-anthropic==0.1.20
# via scrapegraphai
langchain-aws==0.1.12
# via scrapegraphai
langchain-community==0.2.9
# via scrapegraphai
langchain-core==0.2.22
# via langchain
# via langchain-anthropic
# via langchain-aws
# via langchain-community
# via langchain-fireworks
# via langchain-google-genai
# via langchain-google-vertexai
@@ -202,9 +208,12 @@ langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.93
# via langchain
# via langchain-community
# via langchain-core
lxml==5.2.2
# via free-proxy
marshmallow==3.21.3
# via dataclasses-json
minify-html==0.15.0
# via scrapegraphai
mpire==2.10.2
@@ -214,10 +223,13 @@ multidict==6.0.5
# via yarl
multiprocess==0.70.16
# via mpire
mypy-extensions==1.0.0
# via typing-inspect
numpy==1.26.4
# via faiss-cpu
# via langchain
# via langchain-aws
# via langchain-community
# via pandas
# via shapely
openai==1.37.0
@@ -231,6 +243,7 @@ packaging==24.1
# via google-cloud-bigquery
# via huggingface-hub
# via langchain-core
# via marshmallow
pandas==2.2.2
# via scrapegraphai
pillow==10.4.0
@@ -288,6 +301,7 @@ pytz==2024.1
pyyaml==6.0.1
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-core
regex==2024.5.15
# via tiktoken
@@ -298,6 +312,7 @@ requests==2.32.3
# via google-cloud-storage
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-fireworks
# via langsmith
# via tiktoken
@@ -321,8 +336,10 @@ soupsieve==2.5
# via beautifulsoup4
sqlalchemy==2.0.31
# via langchain
# via langchain-community
tenacity==8.5.0
# via langchain
# via langchain-community
# via langchain-core
tiktoken==0.7.0
# via langchain-openai
@@ -347,6 +364,9 @@ typing-extensions==4.12.2
# via pydantic-core
# via pyee
# via sqlalchemy
# via typing-inspect
typing-inspect==0.9.0
# via dataclasses-json
tzdata==2024.1
# via pandas
undetected-playwright==0.3.0
scrapegraphai/nodes/generate_answer_csv_node.py (53 changes: 25 additions & 28 deletions)
@@ -116,24 +116,24 @@ def execute(self, state):
 
         chains_dict = {}
 
-        # Use tqdm to add progress bar
+        if len(doc) == 1:
+            prompt = PromptTemplate(
+                template=template_no_chunks_csv_prompt,
+                input_variables=["question"],
+                partial_variables={
+                    "context": doc,
+                    "format_instructions": format_instructions,
+                },
+            )
+
+            chain = prompt | self.llm_model | output_parser
+            answer = chain.invoke({"question": user_prompt})
+            state.update({self.output[0]: answer})
+            return state
+
         for i, chunk in enumerate(
             tqdm(doc, desc="Processing chunks", disable=not self.verbose)
         ):
-            if len(doc) == 1:
-                prompt = PromptTemplate(
-                    template=template_no_chunks_csv_prompt,
-                    input_variables=["question"],
-                    partial_variables={
-                        "context": chunk,
-                        "format_instructions": format_instructions,
-                    },
-                )
-
-                chain = prompt | self.llm_model | output_parser
-                answer = chain.invoke({"question": user_prompt})
-                break
-
             prompt = PromptTemplate(
                 template=template_chunks_csv_prompt,
                 input_variables=["question"],
@@ -144,24 +144,21 @@ def execute(self, state):
                 },
             )
 
             # Dynamically name the chains based on their index
             chain_name = f"chunk{i+1}"
             chains_dict[chain_name] = prompt | self.llm_model | output_parser
 
-        if len(chains_dict) > 1:
-            # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
-            map_chain = RunnableParallel(**chains_dict)
-            # Chain
-            answer = map_chain.invoke({"question": user_prompt})
-            # Merge the answers from the chunks
-            merge_prompt = PromptTemplate(
-                template=template_merge_csv_prompt,
+        async_runner = RunnableParallel(**chains_dict)
+
+        batch_results = async_runner.invoke({"question": user_prompt})
+
+        merge_prompt = PromptTemplate(
+            template = template_merge_csv_prompt,
                 input_variables=["context", "question"],
                 partial_variables={"format_instructions": format_instructions},
             )
-            merge_chain = merge_prompt | self.llm_model | output_parser
-            answer = merge_chain.invoke({"context": answer, "question": user_prompt})
-
-        # Update the state with the generated answer
+        merge_chain = merge_prompt | self.llm_model | output_parser
+        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
+
         state.update({self.output[0]: answer})
-        return state
+        return state
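
In short, the refactor moves the single-document check out of the chunk loop (the old code tested len(doc) == 1 inside the loop and relied on break) and always fans the per-chunk chains out through RunnableParallel before a final merge chain, instead of guarding the merge with if len(chains_dict) > 1. The sketch below condenses that flow; the prompt strings, model, and function name are illustrative stand-ins, not the node's real prompts.

# Hedged sketch of the updated flow; llm, parser, and the prompt strings are
# placeholders, while the chain wiring mirrors the diff above.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI  # assumption: any LCEL-compatible chat model works

llm = ChatOpenAI(model="gpt-4o-mini")
parser = StrOutputParser()

def answer_from_chunks(doc: list, question: str) -> str:
    # Single-document case: answer directly, no fan-out or merge step.
    if len(doc) == 1:
        prompt = PromptTemplate(
            template="Answer {question} using this context:\n{context}",
            input_variables=["question"],
            partial_variables={"context": doc},
        )
        return (prompt | llm | parser).invoke({"question": question})

    # One chain per chunk, keyed chunk1..chunkN, run in parallel.
    chains_dict = {}
    for i, chunk in enumerate(doc):
        prompt = PromptTemplate(
            template="Answer {question} using this chunk:\n{context}",
            input_variables=["question"],
            partial_variables={"context": chunk},
        )
        chains_dict[f"chunk{i+1}"] = prompt | llm | parser

    batch_results = RunnableParallel(**chains_dict).invoke({"question": question})

    # Merge the per-chunk answers into one final response.
    merge_prompt = PromptTemplate(
        template="Combine these partial answers:\n{context}\ninto one answer to {question}",
        input_variables=["context", "question"],
    )
    return (merge_prompt | llm | parser).invoke(
        {"context": batch_results, "question": question}
    )

One practical consequence of the change: the single-document shortcut now returns before any chain bookkeeping, and every multi-chunk document takes the same parallel-then-merge path.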