Skip to content

Commit 0ebc387

Browse files
authored
Merge pull request #487 from ScrapeGraphAI/update-generate-answer-node
Update generate answer node
2 parents 255e569 + 8dd941e commit 0ebc387

File tree

6 files changed

+119
-91
lines changed

6 files changed

+119
-91
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ authors = [
1313
{ name = "Lorenzo Padoan", email = "lorenzo.padoan977@gmail.com" }
1414
]
1515
dependencies = [
16+
"langchain-community==0.2.9",
1617
"langchain>=0.2.10",
1718
"langchain-fireworks>=0.1.3",
1819
"langchain_community>=0.2.9",
@@ -93,4 +94,4 @@ dev-dependencies = [
9394
[tool.rye.scripts]
9495
pylint-local = "pylint scrapegraphai/**/*.py"
9596
pylint-ci = "pylint --disable=C0114,C0115,C0116 --exit-zero scrapegraphai/**/*.py"
96-
update-requirements = "python 'manual deployment/autorequirements.py'"
97+
update-requirements = "python 'manual deployment/autorequirements.py'"

requirements-dev.lock

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ aiofiles==24.1.0
1212
# via burr
1313
aiohttp==3.9.5
1414
# via langchain
15+
# via langchain-community
1516
# via langchain-fireworks
1617
# via langchain-nvidia-ai-endpoints
1718
aiosignal==1.3.1
@@ -73,6 +74,8 @@ contourpy==1.2.1
7374
# via matplotlib
7475
cycler==0.12.1
7576
# via matplotlib
77+
dataclasses-json==0.6.7
78+
# via langchain-community
7679
defusedxml==0.7.1
7780
# via langchain-anthropic
7881
dill==0.3.8
@@ -177,7 +180,6 @@ graphviz==0.20.3
177180
# via scrapegraphai
178181
greenlet==3.0.3
179182
# via playwright
180-
# via sqlalchemy
181183
groq==0.9.0
182184
# via langchain-groq
183185
grpc-google-iam-v1==0.13.1
@@ -249,15 +251,19 @@ jsonschema-specifications==2023.12.1
249251
kiwisolver==1.4.5
250252
# via matplotlib
251253
langchain==0.2.10
254+
# via langchain-community
252255
# via scrapegraphai
253256
langchain-anthropic==0.1.20
254257
# via scrapegraphai
255258
langchain-aws==0.1.12
256259
# via scrapegraphai
260+
langchain-community==0.2.9
261+
# via scrapegraphai
257262
langchain-core==0.2.22
258263
# via langchain
259264
# via langchain-anthropic
260265
# via langchain-aws
266+
# via langchain-community
261267
# via langchain-fireworks
262268
# via langchain-google-genai
263269
# via langchain-google-vertexai
@@ -281,6 +287,7 @@ langchain-text-splitters==0.2.2
281287
# via langchain
282288
langsmith==0.1.93
283289
# via langchain
290+
# via langchain-community
284291
# via langchain-core
285292
loguru==0.7.2
286293
# via burr
@@ -290,6 +297,8 @@ markdown-it-py==3.0.0
290297
# via rich
291298
markupsafe==2.1.5
292299
# via jinja2
300+
marshmallow==3.21.3
301+
# via dataclasses-json
293302
matplotlib==3.9.1
294303
# via burr
295304
mccabe==0.7.0
@@ -313,6 +322,7 @@ numpy==1.26.4
313322
# via faiss-cpu
314323
# via langchain
315324
# via langchain-aws
325+
# via langchain-community
316326
# via matplotlib
317327
# via pandas
318328
# via pyarrow
@@ -333,6 +343,7 @@ packaging==24.1
333343
# via google-cloud-bigquery
334344
# via huggingface-hub
335345
# via langchain-core
346+
# via marshmallow
336347
# via matplotlib
337348
# via pytest
338349
# via sphinx
@@ -423,6 +434,7 @@ pytz==2024.1
423434
pyyaml==6.0.1
424435
# via huggingface-hub
425436
# via langchain
437+
# via langchain-community
426438
# via langchain-core
427439
# via uvicorn
428440
referencing==0.35.1
@@ -438,6 +450,7 @@ requests==2.32.3
438450
# via google-cloud-storage
439451
# via huggingface-hub
440452
# via langchain
453+
# via langchain-community
441454
# via langchain-fireworks
442455
# via langsmith
443456
# via sphinx
@@ -495,12 +508,14 @@ sphinxcontrib-serializinghtml==1.1.10
495508
# via sphinx
496509
sqlalchemy==2.0.31
497510
# via langchain
511+
# via langchain-community
498512
starlette==0.37.2
499513
# via fastapi
500514
streamlit==1.36.0
501515
# via burr
502516
tenacity==8.5.0
503517
# via langchain
518+
# via langchain-community
504519
# via langchain-core
505520
# via streamlit
506521
tiktoken==0.7.0
@@ -551,6 +566,7 @@ typing-extensions==4.12.2
551566
# via typing-inspect
552567
# via uvicorn
553568
typing-inspect==0.9.0
569+
# via dataclasses-json
554570
# via sf-hamilton
555571
tzdata==2024.1
556572
# via pandas

requirements.lock

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
-e file:.
1111
aiohttp==3.9.5
1212
# via langchain
13+
# via langchain-community
1314
# via langchain-fireworks
1415
# via langchain-nvidia-ai-endpoints
1516
aiosignal==1.3.1
@@ -44,6 +45,8 @@ certifi==2024.7.4
4445
# via requests
4546
charset-normalizer==3.3.2
4647
# via requests
48+
dataclasses-json==0.6.7
49+
# via langchain-community
4750
defusedxml==0.7.1
4851
# via langchain-anthropic
4952
dill==0.3.8
@@ -125,7 +128,6 @@ graphviz==0.20.3
125128
# via scrapegraphai
126129
greenlet==3.0.3
127130
# via playwright
128-
# via sqlalchemy
129131
groq==0.9.0
130132
# via langchain-groq
131133
grpc-google-iam-v1==0.13.1
@@ -170,15 +172,19 @@ jsonpatch==1.33
170172
jsonpointer==3.0.0
171173
# via jsonpatch
172174
langchain==0.2.10
175+
# via langchain-community
173176
# via scrapegraphai
174177
langchain-anthropic==0.1.20
175178
# via scrapegraphai
176179
langchain-aws==0.1.12
177180
# via scrapegraphai
181+
langchain-community==0.2.9
182+
# via scrapegraphai
178183
langchain-core==0.2.22
179184
# via langchain
180185
# via langchain-anthropic
181186
# via langchain-aws
187+
# via langchain-community
182188
# via langchain-fireworks
183189
# via langchain-google-genai
184190
# via langchain-google-vertexai
@@ -202,9 +208,12 @@ langchain-text-splitters==0.2.2
202208
# via langchain
203209
langsmith==0.1.93
204210
# via langchain
211+
# via langchain-community
205212
# via langchain-core
206213
lxml==5.2.2
207214
# via free-proxy
215+
marshmallow==3.21.3
216+
# via dataclasses-json
208217
minify-html==0.15.0
209218
# via scrapegraphai
210219
mpire==2.10.2
@@ -214,10 +223,13 @@ multidict==6.0.5
214223
# via yarl
215224
multiprocess==0.70.16
216225
# via mpire
226+
mypy-extensions==1.0.0
227+
# via typing-inspect
217228
numpy==1.26.4
218229
# via faiss-cpu
219230
# via langchain
220231
# via langchain-aws
232+
# via langchain-community
221233
# via pandas
222234
# via shapely
223235
openai==1.37.0
@@ -231,6 +243,7 @@ packaging==24.1
231243
# via google-cloud-bigquery
232244
# via huggingface-hub
233245
# via langchain-core
246+
# via marshmallow
234247
pandas==2.2.2
235248
# via scrapegraphai
236249
pillow==10.4.0
@@ -288,6 +301,7 @@ pytz==2024.1
288301
pyyaml==6.0.1
289302
# via huggingface-hub
290303
# via langchain
304+
# via langchain-community
291305
# via langchain-core
292306
regex==2024.5.15
293307
# via tiktoken
@@ -298,6 +312,7 @@ requests==2.32.3
298312
# via google-cloud-storage
299313
# via huggingface-hub
300314
# via langchain
315+
# via langchain-community
301316
# via langchain-fireworks
302317
# via langsmith
303318
# via tiktoken
@@ -321,8 +336,10 @@ soupsieve==2.5
321336
# via beautifulsoup4
322337
sqlalchemy==2.0.31
323338
# via langchain
339+
# via langchain-community
324340
tenacity==8.5.0
325341
# via langchain
342+
# via langchain-community
326343
# via langchain-core
327344
tiktoken==0.7.0
328345
# via langchain-openai
@@ -347,6 +364,9 @@ typing-extensions==4.12.2
347364
# via pydantic-core
348365
# via pyee
349366
# via sqlalchemy
367+
# via typing-inspect
368+
typing-inspect==0.9.0
369+
# via dataclasses-json
350370
tzdata==2024.1
351371
# via pandas
352372
undetected-playwright==0.3.0

scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -116,24 +116,24 @@ def execute(self, state):
116116

117117
chains_dict = {}
118118

119-
# Use tqdm to add progress bar
119+
if len(doc) == 1:
120+
prompt = PromptTemplate(
121+
template=template_no_chunks_csv_prompt,
122+
input_variables=["question"],
123+
partial_variables={
124+
"context": doc,
125+
"format_instructions": format_instructions,
126+
},
127+
)
128+
129+
chain = prompt | self.llm_model | output_parser
130+
answer = chain.invoke({"question": user_prompt})
131+
state.update({self.output[0]: answer})
132+
return state
133+
120134
for i, chunk in enumerate(
121135
tqdm(doc, desc="Processing chunks", disable=not self.verbose)
122136
):
123-
if len(doc) == 1:
124-
prompt = PromptTemplate(
125-
template=template_no_chunks_csv_prompt,
126-
input_variables=["question"],
127-
partial_variables={
128-
"context": chunk,
129-
"format_instructions": format_instructions,
130-
},
131-
)
132-
133-
chain = prompt | self.llm_model | output_parser
134-
answer = chain.invoke({"question": user_prompt})
135-
break
136-
137137
prompt = PromptTemplate(
138138
template=template_chunks_csv_prompt,
139139
input_variables=["question"],
@@ -144,24 +144,21 @@ def execute(self, state):
144144
},
145145
)
146146

147-
# Dynamically name the chains based on their index
148147
chain_name = f"chunk{i+1}"
149148
chains_dict[chain_name] = prompt | self.llm_model | output_parser
150149

151-
if len(chains_dict) > 1:
152-
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
153-
map_chain = RunnableParallel(**chains_dict)
154-
# Chain
155-
answer = map_chain.invoke({"question": user_prompt})
156-
# Merge the answers from the chunks
157-
merge_prompt = PromptTemplate(
158-
template=template_merge_csv_prompt,
150+
async_runner = RunnableParallel(**chains_dict)
151+
152+
batch_results = async_runner.invoke({"question": user_prompt})
153+
154+
merge_prompt = PromptTemplate(
155+
template = template_merge_csv_prompt,
159156
input_variables=["context", "question"],
160157
partial_variables={"format_instructions": format_instructions},
161158
)
162-
merge_chain = merge_prompt | self.llm_model | output_parser
163-
answer = merge_chain.invoke({"context": answer, "question": user_prompt})
164159

165-
# Update the state with the generated answer
160+
merge_chain = merge_prompt | self.llm_model | output_parser
161+
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
162+
166163
state.update({self.output[0]: answer})
167-
return state
164+
return state

0 commit comments

Comments
 (0)