
Commit 745ab2b

👍🏼 Update Demo with Ingestion, Retrieval and also Query - Adithya S K
1 parent 225a2ba commit 745ab2b

File tree: 1 file changed (+129 / -60 lines)

demo.py: 129 additions & 60 deletions
@@ -10,7 +10,7 @@
 import time
 from collections import namedtuple
 import pandas as pd
-
+import concurrent.futures
 from varag.rag import SimpleRAG, VisionRAG, ColpaliRAG, HybridColpaliRAG
 from varag.vlms import OpenAI
 from varag.llms import OpenAI as OpenAILLM
@@ -20,7 +20,7 @@
 load_dotenv()
 
 # Initialize shared database
-shared_db = lancedb.connect("~/demo_rag_db")
+shared_db = lancedb.connect("~/rag_db")
 
 # Initialize embedding models
 text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
@@ -62,8 +62,6 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
     total_start_time = time.time()
     progress_data = []
 
-    progress(0, desc="Starting ingestion process")
-
     # SimpleRAG
     yield IngestResult(
         status_text="Starting SimpleRAG ingestion...\n",
@@ -148,27 +146,80 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
     )
 
 
-def retrieve_data(query, top_k):
+def retrieve_data(query, top_k, sequential=False):
+    results = {}
+    timings = {}
+
+    def retrieve_simple():
+        start_time = time.time()
+        simple_results = simple_rag.search(query, k=top_k)
+
+        print(simple_results)
+
+        simple_context = "\n".join([r["text"] for r in simple_results])
+        end_time = time.time()
+        return "SimpleRAG", simple_context, end_time - start_time
+
+    def retrieve_vision():
+        start_time = time.time()
+        vision_results = vision_rag.search(query, k=top_k)
+        vision_images = [r["image"] for r in vision_results]
+        end_time = time.time()
+        return "VisionRAG", vision_images, end_time - start_time
+
+    def retrieve_colpali():
+        start_time = time.time()
+        colpali_results = colpali_rag.search(query, k=top_k)
+        colpali_images = [r["image"] for r in colpali_results]
+        end_time = time.time()
+        return "ColpaliRAG", colpali_images, end_time - start_time
+
+    def retrieve_hybrid():
+        start_time = time.time()
+        hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True)
+        hybrid_images = [r["image"] for r in hybrid_results]
+        end_time = time.time()
+        return "HybridColpaliRAG", hybrid_images, end_time - start_time
+
+    retrieval_functions = [
+        retrieve_simple,
+        retrieve_vision,
+        retrieve_colpali,
+        retrieve_hybrid,
+    ]
+
+    if sequential:
+        for func in retrieval_functions:
+            rag_type, content, timing = func()
+            results[rag_type] = content
+            timings[rag_type] = timing
+    else:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future_results = [executor.submit(func) for func in retrieval_functions]
+            for future in concurrent.futures.as_completed(future_results):
+                rag_type, content, timing = future.result()
+                results[rag_type] = content
+                timings[rag_type] = timing
+
+    return results, timings
+
+
+def query_data(query, retrieved_results):
     results = {}
 
     # SimpleRAG
-    simple_results = simple_rag.search(query, k=top_k)
-    simple_context = "\n".join([r["text"] for r in simple_results])
-    simple_response = vlm.query(
+    simple_context = retrieved_results["SimpleRAG"]
+    simple_response = llm.query(
         context=simple_context,
         system_prompt="Given the below information answer the questions",
         query=query,
     )
     results["SimpleRAG"] = {"response": simple_response, "context": simple_context}
 
     # VisionRAG
-    vision_results = vision_rag.search(query, k=top_k)
-    vision_images = [r["image"] for r in vision_results]
+    vision_images = retrieved_results["VisionRAG"]
     vision_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
-        [
-            f"Image {i+1}: From document '{r['document_name']}', page {r['page_number']}"
-            for i, r in enumerate(vision_results)
-        ]
+        [f"Image {i+1}" for i in range(len(vision_images))]
     )
     vision_response = vlm.query(vision_context, vision_images, max_tokens=500)
     results["VisionRAG"] = {
@@ -178,13 +229,9 @@ def retrieve_data(query, top_k):
     }
 
     # ColpaliRAG
-    colpali_results = colpali_rag.search(query, k=top_k)
-    colpali_images = [r["image"] for r in colpali_results]
+    colpali_images = retrieved_results["ColpaliRAG"]
     colpali_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
-        [
-            f"Image {i+1}: From document '{r['name']}', page {r['page_number']}\nText: {r['page_text'][:500]}..."
-            for i, r in enumerate(colpali_results)
-        ]
+        [f"Image {i+1}" for i in range(len(colpali_images))]
     )
     colpali_response = vlm.query(colpali_context, colpali_images, max_tokens=500)
     results["ColpaliRAG"] = {
@@ -194,13 +241,9 @@ def retrieve_data(query, top_k):
     }
 
    # HybridColpaliRAG
-    hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True)
-    hybrid_images = [r["image"] for r in hybrid_results]
+    hybrid_images = retrieved_results["HybridColpaliRAG"]
     hybrid_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
-        [
-            f"Image {i+1}: From document '{r['name']}', page {r['page_number']}\nText: {r['page_text'][:500]}..."
-            for i, r in enumerate(hybrid_results)
-        ]
+        [f"Image {i+1}" for i in range(len(hybrid_images))]
     )
     hybrid_response = vlm.query(hybrid_context, hybrid_images, max_tokens=500)
     results["HybridColpaliRAG"] = {
@@ -238,42 +281,34 @@ def gradio_interface():
                 50, 5000, value=200, step=10, label="Chunk Size (for SimpleRAG)"
             )
             ingest_button = gr.Button("Ingest PDFs")
-            ingest_output = gr.Textbox(label="Ingestion Status", lines=10)
+            ingest_output = gr.Markdown(label="Ingestion Status", lines=10)
             progress_table = gr.DataFrame(
                 label="Ingestion Progress", headers=["Technique", "Time Taken (s)"]
             )
 
-        with gr.Tab("Retrieve Data"):
+        with gr.Tab("Retrieve and Query Data"):
             query_input = gr.Textbox(label="Enter your query")
             top_k_slider = gr.Slider(1, 10, value=3, step=1, label="Top K Results")
-            search_button = gr.Button("Search and Analyze")
+            sequential_checkbox = gr.Checkbox(label="Sequential Retrieval", value=False)
+            retrieve_button = gr.Button("Retrieve")
+            query_button = gr.Button("Query")
+
+            retrieval_timing = gr.DataFrame(
+                label="Retrieval Timings", headers=["RAG Type", "Time (s)"]
+            )
+
+            with gr.Row():
+                simple_content = gr.Markdown(label="SimpleRAG Content")
+                vision_gallery = gr.Gallery(label="VisionRAG Images")
+                colpali_gallery = gr.Gallery(label="ColpaliRAG Images")
+                hybrid_gallery = gr.Gallery(label="HybridColpaliRAG Images")
 
             with gr.Row():
                 simple_response = gr.Markdown(label="SimpleRAG Response")
                 vision_response = gr.Markdown(label="VisionRAG Response")
                 colpali_response = gr.Markdown(label="ColpaliRAG Response")
                 hybrid_response = gr.Markdown(label="HybridColpaliRAG Response")
 
-            with gr.Row():
-                simple_context = gr.Accordion("SimpleRAG Context", open=False)
-                with simple_context:
-                    gr.Markdown(elem_id="simple_context")
-
-                vision_context = gr.Accordion("VisionRAG Context", open=False)
-                with vision_context:
-                    gr.Markdown(elem_id="vision_context")
-                    gr.Gallery(label="VisionRAG Images")
-
-                colpali_context = gr.Accordion("ColpaliRAG Context", open=False)
-                with colpali_context:
-                    gr.Markdown(elem_id="colpali_context")
-                    gr.Gallery(label="ColpaliRAG Images")
-
-                hybrid_context = gr.Accordion("HybridColpaliRAG Context", open=False)
-                with hybrid_context:
-                    gr.Markdown(elem_id="hybrid_context")
-                    gr.Gallery(label="HybridColpaliRAG Images")
-
         with gr.Tab("Settings"):
             api_key_input = gr.Textbox(label="OpenAI API Key", type="password")
             update_api_button = gr.Button("Update API Key")
@@ -294,27 +329,61 @@ def gradio_interface():
             update_table_button = gr.Button("Update Table Names")
             table_update_status = gr.Textbox(label="Table Update Status")
 
-        ingest_button.click(
-            ingest_data,
-            inputs=[pdf_input, use_ocr, chunk_size],
-            outputs=[ingest_output, progress_table],
+        retrieved_results = gr.State({})
+
+        def update_retrieval_results(query, top_k, sequential):
+            results, timings = retrieve_data(query, top_k, sequential)
+            timing_df = pd.DataFrame(
+                list(timings.items()), columns=["RAG Type", "Time (s)"]
+            )
+            return (
+                results["SimpleRAG"],
+                results["VisionRAG"],
+                results["ColpaliRAG"],
+                results["HybridColpaliRAG"],
+                timing_df,
+                results,
+            )
+
+        retrieve_button.click(
+            update_retrieval_results,
+            inputs=[query_input, top_k_slider, sequential_checkbox],
+            outputs=[
+                simple_content,
+                vision_gallery,
+                colpali_gallery,
+                hybrid_gallery,
+                retrieval_timing,
+                retrieved_results,
+            ],
         )
 
-        search_button.click(
-            retrieve_data,
-            inputs=[query_input, top_k_slider],
+        def update_query_results(query, retrieved_results):
+            results = query_data(query, retrieved_results)
+            return (
+                results["SimpleRAG"]["response"],
+                results["VisionRAG"]["response"],
+                results["ColpaliRAG"]["response"],
+                results["HybridColpaliRAG"]["response"],
+            )
+
+        query_button.click(
+            update_query_results,
+            inputs=[query_input, retrieved_results],
             outputs=[
                 simple_response,
                 vision_response,
                 colpali_response,
                 hybrid_response,
-                simple_context,
-                vision_context,
-                colpali_context,
-                hybrid_context,
             ],
         )
 
+        ingest_button.click(
+            ingest_data,
+            inputs=[pdf_input, use_ocr, chunk_size],
+            outputs=[ingest_output, progress_table],
+        )
+
         update_api_button.click(
            update_api_key, inputs=[api_key_input], outputs=api_update_status
         )
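
Note on the concurrency pattern above: retrieve_data() wraps each backend in a zero-argument function returning (rag_type, content, elapsed seconds) and fans the four calls out on a ThreadPoolExecutor unless sequential=True. The following is a minimal, self-contained sketch of that pattern only; run_retrievers, make_stub_retriever and the sleep-based stubs are illustrative stand-ins, not part of the demo.

# Minimal sketch of the fan-out-and-time pattern used by retrieve_data().
# The stub retrievers below only sleep; they stand in for the real searches.
import concurrent.futures
import time


def run_retrievers(retrievers, sequential=False):
    results, timings = {}, {}
    if sequential:
        completed = [func() for func in retrievers]
    else:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(func) for func in retrievers]
            completed = [f.result() for f in concurrent.futures.as_completed(futures)]
    for name, content, elapsed in completed:
        results[name] = content
        timings[name] = elapsed
    return results, timings


def make_stub_retriever(name, delay):
    def retrieve():
        start = time.time()
        time.sleep(delay)  # placeholder for a real vector / image search
        return name, f"{name} content", time.time() - start
    return retrieve


if __name__ == "__main__":
    stubs = [make_stub_retriever("SimpleRAG", 0.2), make_stub_retriever("VisionRAG", 0.3)]
    _, timings = run_retrievers(stubs)  # runs both on a thread pool
    print(timings)                      # per-backend wall-clock time in seconds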
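
Note on the UI wiring: the diff splits the old single "Search and Analyze" action into a "Retrieve" step and a "Query" step, passing the retrieved results between them through gr.State so retrieval is not repeated. A stripped-down sketch of that Gradio pattern follows; the component names and the fake_retrieve/fake_answer functions are placeholders, not the demo's own.

# Stripped-down sketch of the retrieve-then-query wiring; fake_retrieve and
# fake_answer stand in for retrieve_data() and query_data().
import gradio as gr


def fake_retrieve(query, top_k):
    retrieved = {"SimpleRAG": f"top-{top_k} context for {query!r}"}
    return retrieved["SimpleRAG"], retrieved  # displayed text + session state


def fake_answer(query, retrieved):
    return f"Answer to {query!r} based on: {retrieved.get('SimpleRAG', '')}"


with gr.Blocks() as demo:
    retrieved_state = gr.State({})  # persists between button clicks
    query_box = gr.Textbox(label="Enter your query")
    top_k = gr.Slider(1, 10, value=3, step=1, label="Top K Results")
    retrieve_btn = gr.Button("Retrieve")
    query_btn = gr.Button("Query")
    context_out = gr.Markdown()
    answer_out = gr.Markdown()

    retrieve_btn.click(fake_retrieve, inputs=[query_box, top_k],
                       outputs=[context_out, retrieved_state])
    query_btn.click(fake_answer, inputs=[query_box, retrieved_state],
                    outputs=answer_out)

if __name__ == "__main__":
    demo.launch()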
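
Finally, a hypothetical call sequence for the two new functions, assuming demo.py's RAG instances are initialised and the ingest step has already populated the shared LanceDB tables (the question string is invented):

# Assumes demo.py's globals (simple_rag, vision_rag, ...) are set up and the
# LanceDB tables were populated by the ingest tab.
question = "What are the key findings of the ingested report?"

retrieved, timings = retrieve_data(question, top_k=3, sequential=False)
print(timings)  # per-backend retrieval time in seconds

answers = query_data(question, retrieved)
print(answers["SimpleRAG"]["response"])    # text-only answer from the LLM
print(answers["ColpaliRAG"]["response"])   # answer grounded in retrieved page images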
