diff --git a/LLM/09_rag_langchain.ipynb b/LLM/09_rag_langchain.ipynb index f26045c..86833d0 100644 --- a/LLM/09_rag_langchain.ipynb +++ b/LLM/09_rag_langchain.ipynb @@ -627,7 +627,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "#from langchain_community.vectorstores.chroma import Chroma\n", @@ -703,6 +705,251 @@ " print(f\"An error occurred: {str(e)}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ipywidgets Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ipywidgets as widgets\n", + "from IPython.display import display, clear_output\n", + "import time\n", + "import os\n", + "import sys\n", + "\n", + "# Set up widgets for the interface\n", + "source_type_dropdown = widgets.Dropdown(\n", + " options=['url', 'local'],\n", + " value='local',\n", + " description='Source Type:',\n", + " style={'description_width': 'initial'},\n", + " layout=widgets.Layout(width='300px')\n", + ")\n", + "\n", + "source_input = widgets.Text(\n", + " description='Source:',\n", + " placeholder='Enter URL or filename from data folder',\n", + " style={'description_width': 'initial'},\n", + " layout=widgets.Layout(width='600px')\n", + ")\n", + "\n", + "file_list = widgets.Select(\n", + " options=[],\n", + " description='Available Files:',\n", + " disabled=False,\n", + " layout=widgets.Layout(width='400px', height='150px', display='none')\n", + ")\n", + "\n", + "question_input = widgets.Text(\n", + " placeholder='Ask a question about the document...',\n", + " layout=widgets.Layout(width='800px')\n", + ")\n", + "\n", + "submit_button = widgets.Button(\n", + " description='Ask',\n", + " button_style='primary',\n", + " layout=widgets.Layout(width='100px')\n", + ")\n", + "\n", + "exit_button = widgets.Button(\n", + " description='Exit',\n", + " button_style='danger',\n", + " layout=widgets.Layout(width='100px')\n", + ")\n", + "\n", + "model_status = widgets.HTML(value=\"Status: Ready to load document\")\n", + "response_area = widgets.Output(layout=widgets.Layout(border='1px solid #ddd', padding='10px', width='900px', height='300px'))\n", + "\n", + "# Main widget container (for easier cleanup)\n", + "main_container = widgets.VBox()\n", + "\n", + "# Setup components\n", + "embedding_fn = initialize_embedding_fn()\n", + "vector_store = None\n", + "qachain = None\n", + "\n", + "model_name_or_path = \"TheBloke/Llama-2-7B-Chat-GGUF\"\n", + "model_basename = \"llama-2-7b-chat.Q4_K_M.gguf\"\n", + "MODEL_PATH = hf_hub_download(repo_id=model_name_or_path, filename=model_basename)\n", + "chat_model = create_llm(MODEL_PATH)\n", + "\n", + "# Create base prompt template\n", + "prompt_template = \"\"\"\n", + "Answer the question based on the context below. Keep your answer concise.\n", + "If you don't know, just say \"I don't know.\"\n", + "\n", + "Context: {context}\n", + "\n", + "Question: {question}\n", + "Answer:\n", + "\"\"\"\n", + "prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n", + "chain_type_kwargs = {\"prompt\": prompt}\n", + "\n", + "# Function to update file list\n", + "def update_file_list():\n", + " data_dir = os.path.join(os.getcwd(), \"data\")\n", + " if os.path.exists(data_dir):\n", + " files = os.listdir(data_dir)\n", + " file_list.options = files\n", + " if files:\n", + " file_list.value = files[0]\n", + " else:\n", + " file_list.options = [\"No files found in data directory\"]\n", + "\n", + "# Handler for source type change\n", + "def source_type_changed(change):\n", + " if change['new'] == 'local':\n", + " update_file_list()\n", + " file_list.layout.display = 'block'\n", + " else:\n", + " file_list.layout.display = 'none'\n", + "\n", + "# Handler for file selection\n", + "def file_selected(change):\n", + " if change['new'] and source_type_dropdown.value == 'local':\n", + " source_input.value = change['new']\n", + "\n", + "# Widget handlers\n", + "def load_document_button_click(b):\n", + " global vector_store, qachain\n", + " \n", + " with response_area:\n", + " clear_output()\n", + " source = source_input.value\n", + " source_type = source_type_dropdown.value\n", + " \n", + " if not source:\n", + " print(\"Please enter a source URL or filename\")\n", + " return\n", + " \n", + " model_status.value = \"Status: Loading document...\"\n", + " try:\n", + " # Get or create embeddings for the current document\n", + " vector_store = get_or_create_embeddings(source, source_type, embedding_fn)\n", + " \n", + " # Setup retriever and chain\n", + " retriever = vector_store.as_retriever(search_kwargs={\"k\": 4})\n", + " qachain = chains.RetrievalQA.from_chain_type(\n", + " llm=chat_model,\n", + " retriever=retriever,\n", + " chain_type=\"stuff\",\n", + " chain_type_kwargs=chain_type_kwargs,\n", + " return_source_documents=False\n", + " )\n", + " \n", + " model_status.value = \"Status: Model ready! Ask your questions.\"\n", + " print(\"Document loaded successfully! You can now ask questions.\")\n", + " except Exception as e:\n", + " model_status.value = f\"Status: Error loading document\"\n", + " print(f\"Error: {str(e)}\")\n", + "\n", + "class WidgetStreamHandler(BaseCallbackHandler):\n", + " def __init__(self, output_widget):\n", + " self.output_widget = output_widget\n", + " self.generated_text = \"\"\n", + " \n", + " def on_llm_new_token(self, token, **kwargs):\n", + " self.generated_text += token\n", + " with self.output_widget:\n", + " clear_output(wait=True)\n", + " print(self.generated_text)\n", + "\n", + "def ask_question_button_click(b):\n", + " question = question_input.value\n", + " \n", + " if not question:\n", + " with response_area:\n", + " clear_output()\n", + " print(\"Please enter a question\")\n", + " return\n", + " \n", + " if vector_store is None or qachain is None:\n", + " with response_area:\n", + " clear_output()\n", + " print(\"Please load a document first\")\n", + " return\n", + " \n", + " with response_area:\n", + " clear_output()\n", + " model_status.value = \"Status: Generating response...\"\n", + " stream_handler = WidgetStreamHandler(response_area)\n", + " \n", + " try:\n", + " start_time = time.time()\n", + " qachain.invoke(\n", + " {\n", + " \"query\": question,\n", + " \"max_tokens\": 512,\n", + " \"temperature\": 0.7\n", + " },\n", + " config={\"callbacks\": [stream_handler]}\n", + " )\n", + " elapsed = time.time() - start_time\n", + " model_status.value = f\"Status: Response generated in {elapsed:.2f} seconds\"\n", + " \n", + " except Exception as e:\n", + " model_status.value = \"Status: Error generating response\"\n", + " print(f\"Error: {str(e)}\")\n", + "\n", + "def exit_application(b):\n", + " # Clear all widget outputs\n", + " with response_area:\n", + " clear_output()\n", + " \n", + " # Clean up resources\n", + " global vector_store, qachain\n", + " vector_store = None\n", + " qachain = None\n", + " \n", + " # Clear the main container and show exit message\n", + " main_container.children = [widgets.HTML(\"