diff --git a/.gitignore b/.gitignore index c1750078..aa84820c 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ docs/source/_static/ venv/ .venv/ .vscode/ +.conda/ # exclude pdf, mp3 *.pdf @@ -38,3 +39,6 @@ lib/ *.html .idea +# extras +cache/ +run_smart_scraper.py diff --git a/docs/source/scrapers/graph_config.rst b/docs/source/scrapers/graph_config.rst index 6b046d5b..9e1d49e0 100644 --- a/docs/source/scrapers/graph_config.rst +++ b/docs/source/scrapers/graph_config.rst @@ -13,6 +13,7 @@ Some interesting ones are: - `loader_kwargs`: A dictionary with additional parameters to be passed to the `Loader` class, such as `proxy`. - `burr_kwargs`: A dictionary with additional parameters to enable `Burr` graphical user interface. - `max_images`: The maximum number of images to be analyzed. Useful in `OmniScraperGraph` and `OmniSearchGraph`. +- `cache_path`: The path where the cache files will be saved. If already exists, the cache will be loaded from this path. .. _Burr: diff --git a/requirements-dev.txt b/requirements-dev.txt index 13f2257f..d33296d5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ sphinx==7.1.2 furo==2024.5.6 pytest==8.0.0 -burr[start]==0.19.1 \ No newline at end of file +burr[start]==0.22.1 \ No newline at end of file diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 7814efa8..70a81401 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -76,6 +76,7 @@ def __init__(self, prompt: str, config: dict, self.headless = True if config is None else config.get( "headless", True) self.loader_kwargs = config.get("loader_kwargs", {}) + self.cache_path = config.get("cache_path", False) # Create the graph self.graph = self._create_graph() @@ -91,15 +92,13 @@ def __init__(self, prompt: str, config: dict, else: set_verbosity_warning() - self.headless = True if config is None else config.get("headless", True) - self.loader_kwargs = config.get("loader_kwargs", {}) - common_params = { "headless": self.headless, "verbose": self.verbose, "loader_kwargs": self.loader_kwargs, "llm_model": self.llm_model, - "embedder_model": self.embedder_model + "embedder_model": self.embedder_model, + "cache_path": self.cache_path, } self.set_common_params(common_params, overwrite=False) diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index 6d26bd1c..a4f58191 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -3,6 +3,7 @@ """ from typing import List, Optional +import os from langchain.docstore.document import Document from langchain.retrievers import ContextualCompressionRetriever @@ -50,6 +51,7 @@ def __init__( self.verbose = ( False if node_config is None else node_config.get("verbose", False) ) + self.cache_path = node_config.get("cache_path", False) def execute(self, state: dict) -> dict: """ @@ -98,7 +100,24 @@ def execute(self, state: dict) -> dict: ) embeddings = self.embedder_model - retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever() + folder_name = self.node_config.get("cache_path", "cache") + + if self.node_config.get("cache_path", False) and not os.path.exists(folder_name): + index = FAISS.from_documents(chunked_docs, embeddings) + os.makedirs(folder_name) + index.save_local(folder_name) + self.logger.info("--- (indexes saved to cache) ---") + + elif self.node_config.get("cache_path", False) and os.path.exists(folder_name): + index = FAISS.load_local(folder_path=folder_name, + embeddings=embeddings, + allow_dangerous_deserialization=True) + self.logger.info("--- (indexes loaded from cache) ---") + + else: + index = FAISS.from_documents(chunked_docs, embeddings) + + retriever = index.as_retriever() redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) # similarity_threshold could be set, now k=20