diff --git a/tests/test_depth_search_graph.py b/tests/test_depth_search_graph.py
new file mode 100644
index 00000000..0197a6b8
--- /dev/null
+++ b/tests/test_depth_search_graph.py
@@ -0,0 +1,34 @@
+from unittest.mock import patch, MagicMock
+from scrapegraphai.graphs.depth_search_graph import DepthSearchGraph
+from scrapegraphai.graphs.abstract_graph import AbstractGraph
+import pytest
+
+
+class TestDepthSearchGraph:
+    """Test suite for DepthSearchGraph class"""
+
+    @pytest.mark.parametrize(
+        "source, expected_input_key",
+        [
+            ("https://example.com", "url"),
+            ("/path/to/local/directory", "local_dir"),
+        ],
+    )
+    def test_depth_search_graph_initialization(self, source, expected_input_key):
+        """
+        Test that DepthSearchGraph initializes correctly with different source types.
+        This test verifies that the input_key is set to 'url' for web sources and
+        'local_dir' for local directory sources.
+        """
+        prompt = "Test prompt"
+        config = {"llm": {"model": "mock_model"}}
+
+        # Mock both BaseGraph and _create_llm method
+        with patch("scrapegraphai.graphs.depth_search_graph.BaseGraph"), \
+             patch.object(AbstractGraph, '_create_llm', return_value=MagicMock()):
+            graph = DepthSearchGraph(prompt, source, config)
+
+        assert graph.prompt == prompt
+        assert graph.source == source
+        assert graph.config == config
+        assert graph.input_key == expected_input_key
diff --git a/tests/test_json_scraper_graph.py b/tests/test_json_scraper_graph.py
index 1572650e..2abcaa85 100644
--- a/tests/test_json_scraper_graph.py
+++ b/tests/test_json_scraper_graph.py
@@ -1,6 +1,6 @@
 import pytest
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from scrapegraphai.graphs.json_scraper_graph import JSONScraperGraph
 from unittest.mock import Mock, patch
 
@@ -133,4 +133,60 @@ def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate
         mock_execute.assert_called_once_with({"user_prompt": "Query that produces no answer", "json": "path/to/empty/file.json"})
         mock_fetch_node.assert_called_once()
         mock_generate_answer_node.assert_called_once()
+        mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0})
+
+    @pytest.fixture
+    def mock_llm_model(self):
+        return Mock()
+
+    @pytest.fixture
+    def mock_embedder_model(self):
+        return Mock()
+
+    @patch('scrapegraphai.graphs.json_scraper_graph.FetchNode')
+    @patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode')
+    @patch.object(JSONScraperGraph, '_create_llm')
+    def test_json_scraper_graph_with_custom_schema(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model):
+        """
+        Test JSONScraperGraph with a custom schema.
+        This test checks if the graph correctly handles a custom schema input
+        and passes it to the GenerateAnswerNode.
+        """
+        # Define a custom schema
+        class CustomSchema(BaseModel):
+            name: str = Field(..., description="Name of the attraction")
+            description: str = Field(..., description="Description of the attraction")
+
+        # Mock the _create_llm method to return a mock LLM model
+        mock_create_llm.return_value = mock_llm_model
+
+        # Mock the execute method of BaseGraph
+        with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute:
+            mock_execute.return_value = ({"answer": "Mocked answer with custom schema"}, {})
+
+            # Create a JSONScraperGraph instance with a custom schema
+            graph = JSONScraperGraph(
+                prompt="List attractions in Chioggia",
+                source="path/to/chioggia.json",
+                config={"llm": {"model": "test-model", "temperature": 0}},
+                schema=CustomSchema
+            )
+
+            # Set mocked embedder model
+            graph.embedder_model = mock_embedder_model
+
+            # Run the graph
+            result = graph.run()
+
+        # Assertions
+        assert result == "Mocked answer with custom schema"
+        assert graph.input_key == "json"
+        mock_execute.assert_called_once_with({"user_prompt": "List attractions in Chioggia", "json": "path/to/chioggia.json"})
+        mock_fetch_node.assert_called_once()
+        mock_generate_answer_node.assert_called_once()
+
+        # Check if the custom schema was passed to GenerateAnswerNode
+        generate_answer_node_call = mock_generate_answer_node.call_args[1]
+        assert generate_answer_node_call['node_config']['schema'] == CustomSchema
+        mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0})
\ No newline at end of file
diff --git a/tests/test_search_graph.py b/tests/test_search_graph.py
index 0b8209c0..099385da 100644
--- a/tests/test_search_graph.py
+++ b/tests/test_search_graph.py
@@ -79,4 +79,29 @@ def test_max_results_config(self, mock_create_llm, mock_base_graph, mock_merge_a
         # Assert
         mock_search_internet.assert_called_once()
         call_args = mock_search_internet.call_args
-        assert call_args.kwargs['node_config']['max_results'] == max_results
\ No newline at end of file
+        assert call_args.kwargs['node_config']['max_results'] == max_results
+
+    @patch('scrapegraphai.graphs.search_graph.SearchInternetNode')
+    @patch('scrapegraphai.graphs.search_graph.GraphIteratorNode')
+    @patch('scrapegraphai.graphs.search_graph.MergeAnswersNode')
+    @patch('scrapegraphai.graphs.search_graph.BaseGraph')
+    @patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm')
+    def test_custom_search_engine_config(self, mock_create_llm, mock_base_graph, mock_merge_answers, mock_graph_iterator, mock_search_internet):
+        """
+        Test that the custom search_engine parameter from the config is correctly passed to the SearchInternetNode.
+        """
+        # Arrange
+        prompt = "Test prompt"
+        custom_search_engine = "custom_engine"
+        config = {
+            "llm": {"model": "test-model"},
+            "search_engine": custom_search_engine
+        }
+
+        # Act
+        search_graph = SearchGraph(prompt, config)
+
+        # Assert
+        mock_search_internet.assert_called_once()
+        call_args = mock_search_internet.call_args
+        assert call_args.kwargs['node_config']['search_engine'] == custom_search_engine
\ No newline at end of file