diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f85a255..19b3b0db 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,18 +1,29 @@
-## [1.37.0](https://github.com/ScrapeGraphAI/Scrapegraph-ai/compare/v1.36.0...v1.37.0) (2025-01-21)
+## [1.37.1-beta.1](https://github.com/ScrapeGraphAI/Scrapegraph-ai/compare/v1.37.0...v1.37.1-beta.1) (2025-01-22)
 
-### Features
+### Bug Fixes
+
+* Schema parameter type ([2b5bd80](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/2b5bd80a945a24072e578133eacc751feeec6188))
+
+
+### CI
+
+* **release:** 1.36.1-beta.1 [skip ci] ([006a2aa](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/006a2aaa3fbafbd5b2030c48d5b04b605532c06f))
+
+## [1.36.1-beta.1](https://github.com/ScrapeGraphAI/Scrapegraph-ai/compare/v1.36.0...v1.36.1-beta.1) (2025-01-21)
 
-* add integration for search on web ([224ff07](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/224ff07032d006d75160a7094366fac17023aca1))
 
 ### Bug Fixes
 
+* Schema parameter type ([2b5bd80](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/2b5bd80a945a24072e578133eacc751feeec6188))
 * search ([ce25b6a](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/ce25b6a4b0e1ea15edf14a5867f6336bb27590cb))
 
+
 ### Docs
 
+
 * add requirements.dev ([6e12981](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/6e12981e637d078a6d3b3ce83f0d4901e9dd9996))
 * added first ollama example ([aa6a76e](https://github.com/ScrapeGraphAI/Scrapegraph-ai/commit/aa6a76e5bdf63544f62786b0d17effa205aab3d8))
diff --git a/codebeaver.yml b/codebeaver.yml
new file mode 100644
index 00000000..3fec9f4f
--- /dev/null
+++ b/codebeaver.yml
@@ -0,0 +1,2 @@
+from: pytest
+setup_commands: ['@merge', 'pip install -q selenium', 'pip install -q playwright', 'playwright install']
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 77b2cab9..b2300791 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,7 @@
 [project]
 name = "scrapegraphai"
-version = "1.37.0"
+
+version = "1.37.1b1"
 description = "A web scraping library based on LangChain which uses LLM and direct graph logic to create scraping pipelines."
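Note that the pre-release version is written `1.37.1-beta.1` in the changelog but `1.37.1b1` in `pyproject.toml`: PEP 440 normalizes the `-beta.1` suffix to `b1`, so both spellings denote the same version. A quick check of this (a sketch assuming the `packaging` library is installed; it is not part of the diff):

```python
# Both spellings normalize to the same PEP 440 pre-release version.
from packaging.version import Version

assert Version("1.37.1-beta.1") == Version("1.37.1b1")
print(Version("1.37.1-beta.1"))  # prints "1.37.1b1"
```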
diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py
index a56c9954..86311cb6 100644
--- a/scrapegraphai/graphs/abstract_graph.py
+++ b/scrapegraphai/graphs/abstract_graph.py
@@ -6,7 +6,7 @@
 import uuid
 import warnings
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Type
 
 from langchain.chat_models import init_chat_model
 from langchain_core.rate_limiters import InMemoryRateLimiter
@@ -51,7 +51,7 @@ def __init__(
         prompt: str,
         config: dict,
         source: Optional[str] = None,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         if config.get("llm").get("temperature") is None:
             config["llm"]["temperature"] = 0
diff --git a/scrapegraphai/graphs/code_generator_graph.py b/scrapegraphai/graphs/code_generator_graph.py
index 5b5b23d8..506b87a1 100644
--- a/scrapegraphai/graphs/code_generator_graph.py
+++ b/scrapegraphai/graphs/code_generator_graph.py
@@ -2,7 +2,7 @@
 SmartScraperGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -56,7 +56,11 @@ class CodeGeneratorGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/csv_scraper_graph.py b/scrapegraphai/graphs/csv_scraper_graph.py
index b2bcc712..11a18553 100644
--- a/scrapegraphai/graphs/csv_scraper_graph.py
+++ b/scrapegraphai/graphs/csv_scraper_graph.py
@@ -2,7 +2,7 @@
 Module for creating the smart scraper
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -22,7 +22,7 @@ class CSVScraperGraph(AbstractGraph):
         config (dict): Additional configuration parameters needed by some nodes in the graph.
 
     Methods:
-        __init__ (prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
+        __init__ (prompt: str, source: str, config: dict, schema: Optional[Type[BaseModel]] = None):
             Initializes the CSVScraperGraph with a prompt, source, and configuration.
 
         __init__ initializes the CSVScraperGraph class. It requires the user's prompt as input,
@@ -49,7 +49,11 @@ class CSVScraperGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         """
         Initializes the CSVScraperGraph with a prompt, source, and configuration.
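The substance of the "Schema parameter type" fix is visible above: `schema` receives a Pydantic model *class*, not a model *instance*, so the annotation must be `Optional[Type[BaseModel]]` rather than `Optional[BaseModel]`. A minimal sketch of the distinction (the `Product` model is a hypothetical example, not taken from the diff):

```python
from typing import Optional, Type

from pydantic import BaseModel


class Product(BaseModel):
    """Hypothetical output schema a caller might define."""

    name: str
    price: float


def annotated_as_instance(schema: Optional[BaseModel] = None) -> None:
    """Old annotation: declares that `schema` is an *instance* of a BaseModel subclass."""


def annotated_as_class(schema: Optional[Type[BaseModel]] = None) -> None:
    """New annotation: declares that `schema` is the model *class* itself."""


annotated_as_class(Product)     # OK: the class itself, which is how the graphs use it
annotated_as_instance(Product)  # flagged by type checkers: Product is a class, not an instance
```

The same one-line annotation change is then propagated mechanically through every graph class and the `GraphIteratorNode` in the diffs that follow.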
diff --git a/scrapegraphai/graphs/csv_scraper_multi_graph.py b/scrapegraphai/graphs/csv_scraper_multi_graph.py
index b495c6d7..5e3a398d 100644
--- a/scrapegraphai/graphs/csv_scraper_multi_graph.py
+++ b/scrapegraphai/graphs/csv_scraper_multi_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -47,7 +47,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.copy_config = safe_deepcopy(config)
diff --git a/scrapegraphai/graphs/depth_search_graph.py b/scrapegraphai/graphs/depth_search_graph.py
index 4dd0e49d..b68aa21d 100644
--- a/scrapegraphai/graphs/depth_search_graph.py
+++ b/scrapegraphai/graphs/depth_search_graph.py
@@ -2,7 +2,7 @@
 depth search graph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -54,7 +54,11 @@ class DepthSearchGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/document_scraper_graph.py b/scrapegraphai/graphs/document_scraper_graph.py
index 92a0f3d1..012ab0b7 100644
--- a/scrapegraphai/graphs/document_scraper_graph.py
+++ b/scrapegraphai/graphs/document_scraper_graph.py
@@ -2,7 +2,7 @@
 This module implements the Document Scraper Graph for the ScrapeGraphAI application.
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -44,7 +44,11 @@ class DocumentScraperGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/document_scraper_multi_graph.py b/scrapegraphai/graphs/document_scraper_multi_graph.py
index 555b3964..edd8c7ef 100644
--- a/scrapegraphai/graphs/document_scraper_multi_graph.py
+++ b/scrapegraphai/graphs/document_scraper_multi_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -47,7 +47,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.copy_config = safe_deepcopy(config)
         self.copy_schema = deepcopy(schema)
diff --git a/scrapegraphai/graphs/json_scraper_graph.py b/scrapegraphai/graphs/json_scraper_graph.py
index 29e96497..5d0dfbff 100644
--- a/scrapegraphai/graphs/json_scraper_graph.py
+++ b/scrapegraphai/graphs/json_scraper_graph.py
@@ -2,7 +2,7 @@
 JSONScraperGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -42,7 +42,11 @@ class JSONScraperGraph(AbstractGraph):
     """
 
    def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
    ):
        super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/json_scraper_multi_graph.py b/scrapegraphai/graphs/json_scraper_multi_graph.py
index 7984c7b9..6623718b 100644
--- a/scrapegraphai/graphs/json_scraper_multi_graph.py
+++ b/scrapegraphai/graphs/json_scraper_multi_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -47,7 +47,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.copy_config = safe_deepcopy(config)
diff --git a/scrapegraphai/graphs/omni_scraper_graph.py b/scrapegraphai/graphs/omni_scraper_graph.py
index c2c13f88..ef55dc75 100644
--- a/scrapegraphai/graphs/omni_scraper_graph.py
+++ b/scrapegraphai/graphs/omni_scraper_graph.py
@@ -2,7 +2,7 @@
 This module implements the Omni Scraper Graph for the ScrapeGraphAI application.
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -47,7 +47,11 @@ class OmniScraperGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.max_images = 5 if config is None else config.get("max_images", 5)
diff --git a/scrapegraphai/graphs/omni_search_graph.py b/scrapegraphai/graphs/omni_search_graph.py
index a02e31c6..c30033ab 100644
--- a/scrapegraphai/graphs/omni_search_graph.py
+++ b/scrapegraphai/graphs/omni_search_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -41,7 +41,9 @@ class OmniSearchGraph(AbstractGraph):
     >>> result = search_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
+    def __init__(
+        self, prompt: str, config: dict, schema: Optional[Type[BaseModel]] = None
+    ):
         self.max_results = config.get("max_results", 3)
diff --git a/scrapegraphai/graphs/screenshot_scraper_graph.py b/scrapegraphai/graphs/screenshot_scraper_graph.py
index c37e34f2..765bd428 100644
--- a/scrapegraphai/graphs/screenshot_scraper_graph.py
+++ b/scrapegraphai/graphs/screenshot_scraper_graph.py
@@ -2,7 +2,7 @@
 ScreenshotScraperGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -21,7 +21,7 @@ class ScreenshotScraperGraph(AbstractGraph):
         source (str): The source URL or image link to scrape from.
 
     Methods:
-        __init__(prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None)
+        __init__(prompt: str, source: str, config: dict, schema: Optional[Type[BaseModel]] = None)
             Initializes the ScreenshotScraperGraph instance with the given prompt,
             source, and configuration parameters.
@@ -33,7 +33,11 @@ class ScreenshotScraperGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/script_creator_graph.py b/scrapegraphai/graphs/script_creator_graph.py
index 98dd05e4..4d373a28 100644
--- a/scrapegraphai/graphs/script_creator_graph.py
+++ b/scrapegraphai/graphs/script_creator_graph.py
@@ -2,7 +2,7 @@
 ScriptCreatorGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -44,7 +44,11 @@ class ScriptCreatorGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.library = config["library"]
diff --git a/scrapegraphai/graphs/script_creator_multi_graph.py b/scrapegraphai/graphs/script_creator_multi_graph.py
index b3f21025..a8e01729 100644
--- a/scrapegraphai/graphs/script_creator_multi_graph.py
+++ b/scrapegraphai/graphs/script_creator_multi_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -46,7 +46,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.copy_config = safe_deepcopy(config)
diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py
index 394afb2a..2458c1d8 100644
--- a/scrapegraphai/graphs/search_graph.py
+++ b/scrapegraphai/graphs/search_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -42,7 +42,9 @@ class SearchGraph(AbstractGraph):
     >>> print(search_graph.get_considered_urls())
     """
 
-    def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
+    def __init__(
+        self, prompt: str, config: dict, schema: Optional[Type[BaseModel]] = None
+    ):
         self.max_results = config.get("max_results", 3)
         self.copy_config = safe_deepcopy(config)
diff --git a/scrapegraphai/graphs/search_link_graph.py b/scrapegraphai/graphs/search_link_graph.py
index ba781363..c46fe9be 100644
--- a/scrapegraphai/graphs/search_link_graph.py
+++ b/scrapegraphai/graphs/search_link_graph.py
@@ -2,7 +2,7 @@
 SearchLinkGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -36,7 +36,9 @@ class SearchLinkGraph(AbstractGraph):
     """
 
-    def __init__(self, source: str, config: dict, schema: Optional[BaseModel] = None):
+    def __init__(
+        self, source: str, config: dict, schema: Optional[Type[BaseModel]] = None
+    ):
         super().__init__("", config, source, schema)
         self.input_key = "url" if source.startswith("http") else "local_dir"
diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py
index 7719979d..16964d84 100644
--- a/scrapegraphai/graphs/smart_scraper_graph.py
+++ b/scrapegraphai/graphs/smart_scraper_graph.py
@@ -2,7 +2,7 @@
 SmartScraperGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -52,7 +52,11 @@ class SmartScraperGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/smart_scraper_lite_graph.py b/scrapegraphai/graphs/smart_scraper_lite_graph.py
index 7769e21b..a845dc7a 100644
--- a/scrapegraphai/graphs/smart_scraper_lite_graph.py
+++ b/scrapegraphai/graphs/smart_scraper_lite_graph.py
@@ -2,7 +2,7 @@
 SmartScraperGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -44,7 +44,7 @@ def __init__(
         source: str,
         config: dict,
         prompt: str = "",
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py b/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py
index 8c856c01..ebd7b936 100644
--- a/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py
+++ b/scrapegraphai/graphs/smart_scraper_multi_concat_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -51,7 +51,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.copy_config = safe_deepcopy(config)
diff --git a/scrapegraphai/graphs/smart_scraper_multi_graph.py b/scrapegraphai/graphs/smart_scraper_multi_graph.py
index fa4bfd0f..a0518fb7 100644
--- a/scrapegraphai/graphs/smart_scraper_multi_graph.py
+++ b/scrapegraphai/graphs/smart_scraper_multi_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -53,7 +53,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.max_results = config.get("max_results", 3)
diff --git a/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py b/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py
index ea57bab0..d212b08a 100644
--- a/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py
+++ b/scrapegraphai/graphs/smart_scraper_multi_lite_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -53,7 +53,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.copy_config = safe_deepcopy(config)
diff --git a/scrapegraphai/graphs/speech_graph.py b/scrapegraphai/graphs/speech_graph.py
index 32d5be8c..11caea9b 100644
--- a/scrapegraphai/graphs/speech_graph.py
+++ b/scrapegraphai/graphs/speech_graph.py
@@ -2,7 +2,7 @@
 SpeechGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -44,7 +44,11 @@ class SpeechGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/xml_scraper_graph.py b/scrapegraphai/graphs/xml_scraper_graph.py
index c7dcd62e..162aa322 100644
--- a/scrapegraphai/graphs/xml_scraper_graph.py
+++ b/scrapegraphai/graphs/xml_scraper_graph.py
@@ -2,7 +2,7 @@
 XMLScraperGraph Module
 """
 
-from typing import Optional
+from typing import Optional, Type
 
 from pydantic import BaseModel
 
@@ -44,7 +44,11 @@ class XMLScraperGraph(AbstractGraph):
     """
 
     def __init__(
-        self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None
+        self,
+        prompt: str,
+        source: str,
+        config: dict,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(prompt, config, source, schema)
diff --git a/scrapegraphai/graphs/xml_scraper_multi_graph.py b/scrapegraphai/graphs/xml_scraper_multi_graph.py
index f44bf343..2a3848c9 100644
--- a/scrapegraphai/graphs/xml_scraper_multi_graph.py
+++ b/scrapegraphai/graphs/xml_scraper_multi_graph.py
@@ -3,7 +3,7 @@
 """
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 
@@ -47,7 +47,7 @@ def __init__(
         prompt: str,
         source: List[str],
         config: dict,
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         self.copy_config = safe_deepcopy(config)
diff --git a/scrapegraphai/nodes/graph_iterator_node.py b/scrapegraphai/nodes/graph_iterator_node.py
index 15ae5524..82171258 100644
--- a/scrapegraphai/nodes/graph_iterator_node.py
+++ b/scrapegraphai/nodes/graph_iterator_node.py
@@ -3,7 +3,7 @@
 """
 
 import asyncio
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from pydantic import BaseModel
 from tqdm.asyncio import tqdm
 
@@ -34,7 +34,7 @@ def __init__(
         output: List[str],
         node_config: Optional[dict] = None,
         node_name: str = "GraphIterator",
-        schema: Optional[BaseModel] = None,
+        schema: Optional[Type[BaseModel]] = None,
     ):
         super().__init__(node_name, "node", input, output, 2, node_config)
diff --git a/tests/graphs/abstract_graph_test.py b/tests/graphs/abstract_graph_test.py
index c17ef09a..280f1f77 100644
--- a/tests/graphs/abstract_graph_test.py
+++ b/tests/graphs/abstract_graph_test.py
@@ -1,18 +1,16 @@
-"""
-Tests for the AbstractGraph.
-"""
-
-from unittest.mock import patch
-
 import pytest
+
 from langchain_aws import ChatBedrock
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
-
 from scrapegraphai.graphs import AbstractGraph, BaseGraph
 from scrapegraphai.models import DeepSeek, OneApi
 from scrapegraphai.nodes import FetchNode, ParseNode
+from unittest.mock import Mock, patch
+
+"""
+Tests for the AbstractGraph.
+"""
 
 class TestGraph(AbstractGraph):
     def __init__(self, prompt: str, config: dict):
@@ -50,7 +48,6 @@ def run(self) -> str:
 
         return self.final_state.get("answer", "No answer found.")
 
-
 class TestAbstractGraph:
     @pytest.mark.parametrize(
         "llm_config, expected_model",
@@ -161,3 +158,45 @@ async def test_run_safe_async(self):
         result = await graph.run_safe_async()
         assert result == "Async result"
         mock_run.assert_called_once()
+
+    def test_create_llm_with_custom_model_instance(self):
+        """
+        Test that the _create_llm method correctly uses a custom model instance
+        when provided in the configuration.
+        """
+        mock_model = Mock()
+        mock_model.model_name = "custom-model"
+
+        config = {
+            "llm": {
+                "model_instance": mock_model,
+                "model_tokens": 1000,
+                "model": "custom/model"
+            }
+        }
+
+        graph = TestGraph("Test prompt", config)
+
+        assert graph.llm_model == mock_model
+        assert graph.model_token == 1000
+
+    def test_set_common_params(self):
+        """
+        Test that the set_common_params method correctly updates the configuration
+        of all nodes in the graph.
+ """ + # Create a mock graph with mock nodes + mock_graph = Mock() + mock_node1 = Mock() + mock_node2 = Mock() + mock_graph.nodes = [mock_node1, mock_node2] + + # Create a TestGraph instance with the mock graph + with patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_graph', return_value=mock_graph): + graph = TestGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-test"}}) + + # Call set_common_params with test parameters + test_params = {"param1": "value1", "param2": "value2"} + graph.set_common_params(test_params) + + # Assert that update_config was called on each node with the correct parameters \ No newline at end of file diff --git a/tests/test_json_scraper_graph.py b/tests/test_json_scraper_graph.py new file mode 100644 index 00000000..1572650e --- /dev/null +++ b/tests/test_json_scraper_graph.py @@ -0,0 +1,136 @@ +import pytest + +from pydantic import BaseModel +from scrapegraphai.graphs.json_scraper_graph import JSONScraperGraph +from unittest.mock import Mock, patch + +class TestJSONScraperGraph: + @pytest.fixture + def mock_llm_model(self): + return Mock() + + @pytest.fixture + def mock_embedder_model(self): + return Mock() + + @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode') + @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode') + @patch.object(JSONScraperGraph, '_create_llm') + def test_json_scraper_graph_with_directory(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model): + """ + Test JSONScraperGraph with a directory of JSON files. + This test checks if the graph correctly handles multiple JSON files input + and processes them to generate an answer. + """ + # Mock the _create_llm method to return a mock LLM model + mock_create_llm.return_value = mock_llm_model + + # Mock the execute method of BaseGraph + with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute: + mock_execute.return_value = ({"answer": "Mocked answer for multiple JSON files"}, {}) + + # Create a JSONScraperGraph instance + graph = JSONScraperGraph( + prompt="Summarize the data from all JSON files", + source="path/to/json/directory", + config={"llm": {"model": "test-model", "temperature": 0}}, + schema=BaseModel + ) + + # Set mocked embedder model + graph.embedder_model = mock_embedder_model + + # Run the graph + result = graph.run() + + # Assertions + assert result == "Mocked answer for multiple JSON files" + assert graph.input_key == "json_dir" + mock_execute.assert_called_once_with({"user_prompt": "Summarize the data from all JSON files", "json_dir": "path/to/json/directory"}) + mock_fetch_node.assert_called_once() + mock_generate_answer_node.assert_called_once() + mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0}) + + @pytest.fixture + def mock_llm_model(self): + return Mock() + + @pytest.fixture + def mock_embedder_model(self): + return Mock() + + @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode') + @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode') + @patch.object(JSONScraperGraph, '_create_llm') + def test_json_scraper_graph_with_single_file(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model): + """ + Test JSONScraperGraph with a single JSON file. + This test checks if the graph correctly handles a single JSON file input + and processes it to generate an answer. 
+ """ + # Mock the _create_llm method to return a mock LLM model + mock_create_llm.return_value = mock_llm_model + + # Mock the execute method of BaseGraph + with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute: + mock_execute.return_value = ({"answer": "Mocked answer for single JSON file"}, {}) + + # Create a JSONScraperGraph instance with a single JSON file + graph = JSONScraperGraph( + prompt="Analyze the data from the JSON file", + source="path/to/single/file.json", + config={"llm": {"model": "test-model", "temperature": 0}}, + schema=BaseModel + ) + + # Set mocked embedder model + graph.embedder_model = mock_embedder_model + + # Run the graph + result = graph.run() + + # Assertions + assert result == "Mocked answer for single JSON file" + assert graph.input_key == "json" + mock_execute.assert_called_once_with({"user_prompt": "Analyze the data from the JSON file", "json": "path/to/single/file.json"}) + mock_fetch_node.assert_called_once() + mock_generate_answer_node.assert_called_once() + mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0}) + + @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode') + @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode') + @patch.object(JSONScraperGraph, '_create_llm') + def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model): + """ + Test JSONScraperGraph when no answer is found. + This test checks if the graph correctly handles the scenario where no answer is generated, + ensuring it returns the default "No answer found." message. + """ + # Mock the _create_llm method to return a mock LLM model + mock_create_llm.return_value = mock_llm_model + + # Mock the execute method of BaseGraph to return an empty answer + with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute: + mock_execute.return_value = ({}, {}) # Empty state and execution info + + # Create a JSONScraperGraph instance + graph = JSONScraperGraph( + prompt="Query that produces no answer", + source="path/to/empty/file.json", + config={"llm": {"model": "test-model", "temperature": 0}}, + schema=BaseModel + ) + + # Set mocked embedder model + graph.embedder_model = mock_embedder_model + + # Run the graph + result = graph.run() + + # Assertions + assert result == "No answer found." 
+            assert graph.input_key == "json"
+            mock_execute.assert_called_once_with({"user_prompt": "Query that produces no answer", "json": "path/to/empty/file.json"})
+            mock_fetch_node.assert_called_once()
+            mock_generate_answer_node.assert_called_once()
+            mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0})
\ No newline at end of file
diff --git a/tests/test_search_graph.py b/tests/test_search_graph.py
new file mode 100644
index 00000000..0b8209c0
--- /dev/null
+++ b/tests/test_search_graph.py
@@ -0,0 +1,82 @@
+import pytest
+
+from scrapegraphai.graphs.search_graph import SearchGraph
+from unittest.mock import MagicMock, call, patch
+
+class TestSearchGraph:
+    """Test class for SearchGraph"""
+
+    @pytest.mark.parametrize("urls", [
+        ["https://example.com", "https://test.com"],
+        [],
+        ["https://single-url.com"]
+    ])
+    @patch('scrapegraphai.graphs.search_graph.BaseGraph')
+    @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm')
+    def test_get_considered_urls(self, mock_create_llm, mock_base_graph, urls):
+        """
+        Test that get_considered_urls returns the correct list of URLs
+        considered during the search process.
+        """
+        # Arrange
+        prompt = "Test prompt"
+        config = {"llm": {"model": "test-model"}}
+
+        # Mock the _create_llm method to return a MagicMock
+        mock_create_llm.return_value = MagicMock()
+
+        # Mock the execute method to set the final_state
+        mock_base_graph.return_value.execute.return_value = ({"urls": urls}, {})
+
+        # Act
+        search_graph = SearchGraph(prompt, config)
+        search_graph.run()
+
+        # Assert
+        assert search_graph.get_considered_urls() == urls
+
+    @patch('scrapegraphai.graphs.search_graph.BaseGraph')
+    @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm')
+    def test_run_no_answer_found(self, mock_create_llm, mock_base_graph):
+        """
+        Test that the run() method returns "No answer found." when the final state
+        doesn't contain an "answer" key.
+        """
+        # Arrange
+        prompt = "Test prompt"
+        config = {"llm": {"model": "test-model"}}
+
+        # Mock the _create_llm method to return a MagicMock
+        mock_create_llm.return_value = MagicMock()
+
+        # Mock the execute method to set the final_state without an "answer" key
+        mock_base_graph.return_value.execute.return_value = ({"urls": []}, {})
+
+        # Act
+        search_graph = SearchGraph(prompt, config)
+        result = search_graph.run()
+
+        # Assert
+        assert result == "No answer found."
+
+    @patch('scrapegraphai.graphs.search_graph.SearchInternetNode')
+    @patch('scrapegraphai.graphs.search_graph.GraphIteratorNode')
+    @patch('scrapegraphai.graphs.search_graph.MergeAnswersNode')
+    @patch('scrapegraphai.graphs.search_graph.BaseGraph')
+    @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm')
+    def test_max_results_config(self, mock_create_llm, mock_base_graph, mock_merge_answers, mock_graph_iterator, mock_search_internet):
+        """
+        Test that the max_results parameter from the config is correctly passed to the SearchInternetNode.
+        """
+        # Arrange
+        prompt = "Test prompt"
+        max_results = 5
+        config = {"llm": {"model": "test-model"}, "max_results": max_results}
+
+        # Act
+        search_graph = SearchGraph(prompt, config)
+
+        # Assert
+        mock_search_internet.assert_called_once()
+        call_args = mock_search_internet.call_args
+        assert call_args.kwargs['node_config']['max_results'] == max_results
\ No newline at end of file
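The new test modules can be run on their own through pytest's Python API; a possible invocation (assuming the dev dependencies from `requirements.dev` are installed and the command is run from the repository root):

```python
# Runs the two new test modules verbosely; equivalent to invoking pytest from the shell.
import pytest

exit_code = pytest.main(
    ["tests/test_json_scraper_graph.py", "tests/test_search_graph.py", "-v"]
)
```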