1010import  time 
1111from  collections  import  namedtuple 
1212import  pandas  as  pd 
13- 
13+ import   concurrent . futures 
1414from  varag .rag  import  SimpleRAG , VisionRAG , ColpaliRAG , HybridColpaliRAG 
1515from  varag .vlms  import  OpenAI 
1616from  varag .llms  import  OpenAI  as  OpenAILLM 
2020load_dotenv ()
2121
2222# Initialize shared database 
23- shared_db  =  lancedb .connect ("~/demo_rag_db " )
23+ shared_db  =  lancedb .connect ("~/rag_db " )
2424
2525# Initialize embedding models 
2626text_embedding_model  =  SentenceTransformer ("all-MiniLM-L6-v2" , trust_remote_code = True )
@@ -62,8 +62,6 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
6262    total_start_time  =  time .time ()
6363    progress_data  =  []
6464
65-     progress (0 , desc = "Starting ingestion process" )
66- 
6765    # SimpleRAG 
6866    yield  IngestResult (
6967        status_text = "Starting SimpleRAG ingestion...\n " ,
@@ -148,27 +146,80 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
148146    )
149147
150148
151- def  retrieve_data (query , top_k ):
def retrieve_data(query, top_k, sequential=False):
    """Run the retrieval step of every RAG pipeline for a single query.

    Args:
        query: The user query passed to each pipeline's ``search`` method.
        top_k: Number of hits requested from each pipeline (passed as ``k``).
        sequential: When true, the pipelines are searched one after another
            in the calling thread; otherwise each search is submitted to a
            thread pool so they can overlap.

    Returns:
        A ``(results, timings)`` pair of dicts keyed by pipeline name
        ("SimpleRAG", "VisionRAG", "ColpaliRAG", "HybridColpaliRAG").
        ``results`` maps to the joined hit text for SimpleRAG and to the list
        of hit images for the other pipelines; ``timings`` maps to the
        elapsed seconds of each search. Both dicts always list the pipelines
        in the order above, regardless of which search finishes first.

    Raises:
        Any exception raised by a pipeline's ``search`` call propagates to
        the caller.
    """

    def _joined_text(hits):
        # SimpleRAG hits carry their chunk under the "text" key.
        return "\n ".join([hit["text"] for hit in hits])

    def _images(hits):
        # The vision-based pipelines return one image per hit.
        return [hit["image"] for hit in hits]

    def _timed(label, search, extract):
        """Wrap one search in a zero-argument job returning (label, content, seconds)."""

        def job():
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by system clock adjustments.
            started = time.perf_counter()
            content = extract(search())
            return label, content, time.perf_counter() - started

        return job

    jobs = [
        _timed(
            "SimpleRAG",
            lambda: simple_rag.search(query, k=top_k),
            _joined_text,
        ),
        _timed(
            "VisionRAG",
            lambda: vision_rag.search(query, k=top_k),
            _images,
        ),
        _timed(
            "ColpaliRAG",
            lambda: colpali_rag.search(query, k=top_k),
            _images,
        ),
        _timed(
            "HybridColpaliRAG",
            lambda: hybrid_rag.search(query, k=top_k, use_image_search=True),
            _images,
        ),
    ]

    if sequential:
        outcomes = [job() for job in jobs]
    else:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(job) for job in jobs]
            # Collect in submission order (not completion order) so the
            # returned dicts, and any table built from them, have a stable
            # row order from run to run.
            outcomes = [future.result() for future in futures]

    results = {}
    timings = {}
    for rag_type, content, elapsed in outcomes:
        results[rag_type] = content
        timings[rag_type] = elapsed

    return results, timings
205+ 
206+ 
207+ def  query_data (query , retrieved_results ):
152208    results  =  {}
153209
154210    # SimpleRAG 
155-     simple_results  =  simple_rag .search (query , k = top_k )
156-     simple_context  =  "\n " .join ([r ["text" ] for  r  in  simple_results ])
157-     simple_response  =  vlm .query (
211+     simple_context  =  retrieved_results ["SimpleRAG" ]
212+     simple_response  =  llm .query (
158213        context = simple_context ,
159214        system_prompt = "Given the below information answer the questions" ,
160215        query = query ,
161216    )
162217    results ["SimpleRAG" ] =  {"response" : simple_response , "context" : simple_context }
163218
164219    # VisionRAG 
165-     vision_results  =  vision_rag .search (query , k = top_k )
166-     vision_images  =  [r ["image" ] for  r  in  vision_results ]
220+     vision_images  =  retrieved_results ["VisionRAG" ]
167221    vision_context  =  f"Query: { query } \n \n Relevant image information:\n "  +  "\n " .join (
168-         [
169-             f"Image { i + 1 }  : From document '{ r ['document_name' ]}  ', page { r ['page_number' ]}  " 
170-             for  i , r  in  enumerate (vision_results )
171-         ]
222+         [f"Image { i + 1 }  "  for  i  in  range (len (vision_images ))]
172223    )
173224    vision_response  =  vlm .query (vision_context , vision_images , max_tokens = 500 )
174225    results ["VisionRAG" ] =  {
@@ -178,13 +229,9 @@ def retrieve_data(query, top_k):
178229    }
179230
180231    # ColpaliRAG 
181-     colpali_results  =  colpali_rag .search (query , k = top_k )
182-     colpali_images  =  [r ["image" ] for  r  in  colpali_results ]
232+     colpali_images  =  retrieved_results ["ColpaliRAG" ]
183233    colpali_context  =  f"Query: { query } \n \n Relevant image information:\n "  +  "\n " .join (
184-         [
185-             f"Image { i + 1 }  : From document '{ r ['name' ]}  ', page { r ['page_number' ]} \n Text: { r ['page_text' ][:500 ]}  ..." 
186-             for  i , r  in  enumerate (colpali_results )
187-         ]
234+         [f"Image { i + 1 }  "  for  i  in  range (len (colpali_images ))]
188235    )
189236    colpali_response  =  vlm .query (colpali_context , colpali_images , max_tokens = 500 )
190237    results ["ColpaliRAG" ] =  {
@@ -194,13 +241,9 @@ def retrieve_data(query, top_k):
194241    }
195242
196243    # HybridColpaliRAG 
197-     hybrid_results  =  hybrid_rag .search (query , k = top_k , use_image_search = True )
198-     hybrid_images  =  [r ["image" ] for  r  in  hybrid_results ]
244+     hybrid_images  =  retrieved_results ["HybridColpaliRAG" ]
199245    hybrid_context  =  f"Query: { query } \n \n Relevant image information:\n "  +  "\n " .join (
200-         [
201-             f"Image { i + 1 }  : From document '{ r ['name' ]}  ', page { r ['page_number' ]} \n Text: { r ['page_text' ][:500 ]}  ..." 
202-             for  i , r  in  enumerate (hybrid_results )
203-         ]
246+         [f"Image { i + 1 }  "  for  i  in  range (len (hybrid_images ))]
204247    )
205248    hybrid_response  =  vlm .query (hybrid_context , hybrid_images , max_tokens = 500 )
206249    results ["HybridColpaliRAG" ] =  {
@@ -238,42 +281,34 @@ def gradio_interface():
238281                50 , 5000 , value = 200 , step = 10 , label = "Chunk Size (for SimpleRAG)" 
239282            )
240283            ingest_button  =  gr .Button ("Ingest PDFs" )
241-             ingest_output  =  gr .Textbox (label = "Ingestion Status" , lines = 10 )
284+             ingest_output  =  gr .Markdown (label = "Ingestion Status" , lines = 10 )
242285            progress_table  =  gr .DataFrame (
243286                label = "Ingestion Progress" , headers = ["Technique" , "Time Taken (s)" ]
244287            )
245288
246-         with  gr .Tab ("Retrieve Data" ):
289+         with  gr .Tab ("Retrieve and Query  Data" ):
247290            query_input  =  gr .Textbox (label = "Enter your query" )
248291            top_k_slider  =  gr .Slider (1 , 10 , value = 3 , step = 1 , label = "Top K Results" )
249-             search_button  =  gr .Button ("Search and Analyze" )
292+             sequential_checkbox  =  gr .Checkbox (label = "Sequential Retrieval" , value = False )
293+             retrieve_button  =  gr .Button ("Retrieve" )
294+             query_button  =  gr .Button ("Query" )
295+ 
296+             retrieval_timing  =  gr .DataFrame (
297+                 label = "Retrieval Timings" , headers = ["RAG Type" , "Time (s)" ]
298+             )
299+ 
300+             with  gr .Row ():
301+                 simple_content  =  gr .Markdown (label = "SimpleRAG Content" )
302+                 vision_gallery  =  gr .Gallery (label = "VisionRAG Images" )
303+                 colpali_gallery  =  gr .Gallery (label = "ColpaliRAG Images" )
304+                 hybrid_gallery  =  gr .Gallery (label = "HybridColpaliRAG Images" )
250305
251306            with  gr .Row ():
252307                simple_response  =  gr .Markdown (label = "SimpleRAG Response" )
253308                vision_response  =  gr .Markdown (label = "VisionRAG Response" )
254309                colpali_response  =  gr .Markdown (label = "ColpaliRAG Response" )
255310                hybrid_response  =  gr .Markdown (label = "HybridColpaliRAG Response" )
256311
257-             with  gr .Row ():
258-                 simple_context  =  gr .Accordion ("SimpleRAG Context" , open = False )
259-                 with  simple_context :
260-                     gr .Markdown (elem_id = "simple_context" )
261- 
262-                 vision_context  =  gr .Accordion ("VisionRAG Context" , open = False )
263-                 with  vision_context :
264-                     gr .Markdown (elem_id = "vision_context" )
265-                     gr .Gallery (label = "VisionRAG Images" )
266- 
267-                 colpali_context  =  gr .Accordion ("ColpaliRAG Context" , open = False )
268-                 with  colpali_context :
269-                     gr .Markdown (elem_id = "colpali_context" )
270-                     gr .Gallery (label = "ColpaliRAG Images" )
271- 
272-                 hybrid_context  =  gr .Accordion ("HybridColpaliRAG Context" , open = False )
273-                 with  hybrid_context :
274-                     gr .Markdown (elem_id = "hybrid_context" )
275-                     gr .Gallery (label = "HybridColpaliRAG Images" )
276- 
277312        with  gr .Tab ("Settings" ):
278313            api_key_input  =  gr .Textbox (label = "OpenAI API Key" , type = "password" )
279314            update_api_button  =  gr .Button ("Update API Key" )
@@ -294,27 +329,61 @@ def gradio_interface():
294329            update_table_button  =  gr .Button ("Update Table Names" )
295330            table_update_status  =  gr .Textbox (label = "Table Update Status" )
296331
297-         ingest_button .click (
298-             ingest_data ,
299-             inputs = [pdf_input , use_ocr , chunk_size ],
300-             outputs = [ingest_output , progress_table ],
332+         retrieved_results  =  gr .State ({})
333+ 
334+         def  update_retrieval_results (query , top_k , sequential ):
335+             results , timings  =  retrieve_data (query , top_k , sequential )
336+             timing_df  =  pd .DataFrame (
337+                 list (timings .items ()), columns = ["RAG Type" , "Time (s)" ]
338+             )
339+             return  (
340+                 results ["SimpleRAG" ],
341+                 results ["VisionRAG" ],
342+                 results ["ColpaliRAG" ],
343+                 results ["HybridColpaliRAG" ],
344+                 timing_df ,
345+                 results ,
346+             )
347+ 
348+         retrieve_button .click (
349+             update_retrieval_results ,
350+             inputs = [query_input , top_k_slider , sequential_checkbox ],
351+             outputs = [
352+                 simple_content ,
353+                 vision_gallery ,
354+                 colpali_gallery ,
355+                 hybrid_gallery ,
356+                 retrieval_timing ,
357+                 retrieved_results ,
358+             ],
301359        )
302360
303-         search_button .click (
304-             retrieve_data ,
305-             inputs = [query_input , top_k_slider ],
361+         def  update_query_results (query , retrieved_results ):
362+             results  =  query_data (query , retrieved_results )
363+             return  (
364+                 results ["SimpleRAG" ]["response" ],
365+                 results ["VisionRAG" ]["response" ],
366+                 results ["ColpaliRAG" ]["response" ],
367+                 results ["HybridColpaliRAG" ]["response" ],
368+             )
369+ 
370+         query_button .click (
371+             update_query_results ,
372+             inputs = [query_input , retrieved_results ],
306373            outputs = [
307374                simple_response ,
308375                vision_response ,
309376                colpali_response ,
310377                hybrid_response ,
311-                 simple_context ,
312-                 vision_context ,
313-                 colpali_context ,
314-                 hybrid_context ,
315378            ],
316379        )
317380
381+         ingest_button .click (
382+             ingest_data ,
383+             inputs = [pdf_input , use_ocr , chunk_size ],
384+             outputs = [ingest_output , progress_table ],
385+         )
386+ 
318387        update_api_button .click (
319388            update_api_key , inputs = [api_key_input ], outputs = api_update_status 
320389        )
0 commit comments