diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b9d9be..4feb52b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +## [1.3.0-beta.1](https://github.com/ScrapeGraphAI/langchain-scrapegraph/compare/v1.2.1-beta.1...v1.3.0-beta.1) (2025-02-22) + + +### Features + +* searchscraper ([6a96801](https://github.com/ScrapeGraphAI/langchain-scrapegraph/commit/6a968015d9c8f4ce798111850b0f000c3317c467)) +* updated tests searchscraper ([a771564](https://github.com/ScrapeGraphAI/langchain-scrapegraph/commit/a771564838b637f6aef0277e5ca3d723208d6701)) + +## [1.2.1-beta.1](https://github.com/ScrapeGraphAI/langchain-scrapegraph/compare/v1.2.0...v1.2.1-beta.1) (2025-01-02) + + +### Bug Fixes + +* updated docs url ([f7b640c](https://github.com/ScrapeGraphAI/langchain-scrapegraph/commit/f7b640c29d9780a30212acb19b09247b765a41ff)) + ## [1.2.0](https://github.com/ScrapeGraphAI/langchain-scrapegraph/compare/v1.1.0...v1.2.0) (2024-12-18) diff --git a/README.md b/README.md index 03e9746..6e29639 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Python Support](https://img.shields.io/pypi/pyversions/langchain-scrapegraph.svg)](https://pypi.org/project/langchain-scrapegraph/) -[![Documentation](https://img.shields.io/badge/Documentation-Latest-green)](https://scrapegraphai.com/documentation) +[![Documentation](https://img.shields.io/badge/Documentation-Latest-green)](https://docs.scrapegraphai.com/integrations/langchain) Supercharge your LangChain agents with AI-powered web scraping capabilities. LangChain-ScrapeGraph provides a seamless integration between [LangChain](https://github.com/langchain-ai/langchain) and [ScrapeGraph AI](https://scrapegraphai.com), enabling your agents to extract structured data from websites using natural language. @@ -58,98 +58,76 @@ result = tool.invoke({ print(result) ``` -
-🔍 Using Output Schemas with SmartscraperTool - -You can define the structure of the output using Pydantic models: +### 🌐 SearchscraperTool +Search and extract structured information from the web using natural language prompts. ```python -from typing import List -from pydantic import BaseModel, Field -from langchain_scrapegraph.tools import SmartScraperTool +from langchain_scrapegraph.tools import SearchScraperTool -class WebsiteInfo(BaseModel): - title: str = Field(description="The main title of the webpage") - description: str = Field(description="The main description or first paragraph") - urls: List[str] = Field(description="The URLs inside the webpage") - -# Initialize with schema -tool = SmartScraperTool(llm_output_schema=WebsiteInfo) +# Initialize the tool (uses SGAI_API_KEY from environment) +tool = SearchScraperTool() -# The output will conform to the WebsiteInfo schema +# Search and extract information using natural language result = tool.invoke({ - "website_url": "https://www.example.com", - "user_prompt": "Extract the website information" + "user_prompt": "What are the key features and pricing of ChatGPT Plus?" }) print(result) # { -# "title": "Example Domain", -# "description": "This domain is for use in illustrative examples...", -# "urls": ["https://www.iana.org/domains/example"] +# "product": { +# "name": "ChatGPT Plus", +# "description": "Premium version of ChatGPT..." +# }, +# "features": [...], +# "pricing": {...}, +# "reference_urls": [ +# "https://openai.com/chatgpt", +# ... +# ] # } ``` -
- -### 💻 LocalscraperTool -Extract information from HTML content using AI. - -```python -from langchain_scrapegraph.tools import LocalScraperTool - -tool = LocalScraperTool() -result = tool.invoke({ - "user_prompt": "Extract all contact information", - "website_html": "..." -}) - -print(result) -```
-🔍 Using Output Schemas with LocalscraperTool +🔍 Using Output Schemas with SearchscraperTool You can define the structure of the output using Pydantic models: ```python -from typing import Optional +from typing import Any, Dict, List from pydantic import BaseModel, Field -from langchain_scrapegraph.tools import LocalScraperTool +from langchain_scrapegraph.tools import SearchScraperTool -class CompanyInfo(BaseModel): - name: str = Field(description="The company name") - description: str = Field(description="The company description") - email: Optional[str] = Field(description="Contact email if available") - phone: Optional[str] = Field(description="Contact phone if available") +class ProductInfo(BaseModel): + name: str = Field(description="Product name") + features: List[str] = Field(description="List of product features") + pricing: Dict[str, Any] = Field(description="Pricing information") + reference_urls: List[str] = Field(description="Source URLs for the information") # Initialize with schema -tool = LocalScraperTool(llm_output_schema=CompanyInfo) - -html_content = """ - - -

TechCorp Solutions

-

We are a leading AI technology company.

-
-

Email: contact@techcorp.com

-

Phone: (555) 123-4567

-
- - -""" - -# The output will conform to the CompanyInfo schema +tool = SearchScraperTool(llm_output_schema=ProductInfo) + +# The output will conform to the ProductInfo schema result = tool.invoke({ - "website_html": html_content, - "user_prompt": "Extract the company information" + "user_prompt": "What are the key features and pricing of ChatGPT Plus?" }) print(result) # { -# "name": "TechCorp Solutions", -# "description": "We are a leading AI technology company.", -# "email": "contact@techcorp.com", -# "phone": "(555) 123-4567" +# "name": "ChatGPT Plus", +# "features": [ +# "GPT-4 access", +# "Faster response speed", +# ... +# ], +# "pricing": { +# "amount": 20, +# "currency": "USD", +# "period": "monthly" +# }, +# "reference_urls": [ +# "https://openai.com/chatgpt", +# ... +# ] # } ```
diff --git a/examples/agent_example.py b/examples/agent_example.py index 9e61fba..5f461c8 100644 --- a/examples/agent_example.py +++ b/examples/agent_example.py @@ -11,7 +11,7 @@ from langchain_scrapegraph.tools import ( GetCreditsTool, - LocalScraperTool, + SearchScraperTool, SmartScraperTool, ) @@ -20,8 +20,8 @@ # Initialize the tools tools = [ SmartScraperTool(), - LocalScraperTool(), GetCreditsTool(), + SearchScraperTool(), ] # Create the prompt template diff --git a/examples/localscraper_tool.py b/examples/localscraper_tool.py deleted file mode 100644 index a8df8ee..0000000 --- a/examples/localscraper_tool.py +++ /dev/null @@ -1,28 +0,0 @@ -from scrapegraph_py.logger import sgai_logger - -from langchain_scrapegraph.tools import LocalScraperTool - -sgai_logger.set_logging(level="INFO") - -# Will automatically get SGAI_API_KEY from environment -tool = LocalScraperTool() - -# Example website and prompt -html_content = """ - - -

Company Name

-

We are a technology company focused on AI solutions.

-
-

Email: contact@example.com

-

Phone: (555) 123-4567

-
- - -""" -user_prompt = "Make a summary of the webpage and extract the email and phone number" - -# Use the tool -result = tool.invoke({"website_html": html_content, "user_prompt": user_prompt}) - -print(result) diff --git a/examples/localscraper_tool_schema.py b/examples/localscraper_tool_schema.py deleted file mode 100644 index 85f3ab9..0000000 --- a/examples/localscraper_tool_schema.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List - -from pydantic import BaseModel, Field -from scrapegraph_py.logger import sgai_logger - -from langchain_scrapegraph.tools import LocalScraperTool - - -class WebsiteInfo(BaseModel): - title: str = Field(description="The main title of the webpage") - description: str = Field(description="The main description or first paragraph") - urls: List[str] = Field(description="The URLs inside the webpage") - - -sgai_logger.set_logging(level="INFO") - -# Initialize with Pydantic model class -tool = LocalScraperTool(llm_output_schema=WebsiteInfo) - -# Example website and prompt -html_content = """ - - -

Company Name

-

We are a technology company focused on AI solutions.

-
-

Email: contact@example.com

-

Phone: (555) 123-4567

-
- - -""" -user_prompt = "Make a summary of the webpage and extract the email and phone number" - -# Use the tool -result = tool.invoke({"website_html": html_content, "user_prompt": user_prompt}) - -print(result) diff --git a/examples/searchscraper_tool.py b/examples/searchscraper_tool.py new file mode 100644 index 0000000..a14d562 --- /dev/null +++ b/examples/searchscraper_tool.py @@ -0,0 +1,16 @@ +from scrapegraph_py.logger import sgai_logger + +from langchain_scrapegraph.tools import SearchScraperTool + +sgai_logger.set_logging(level="INFO") + +# Will automatically get SGAI_API_KEY from environment +tool = SearchScraperTool() + +# Example prompt +user_prompt = "What are the key features and pricing of ChatGPT Plus?" + +# Use the tool +result = tool.invoke({"user_prompt": user_prompt}) + +print("\nResult:", result) diff --git a/examples/searchscraper_tool_schema.py b/examples/searchscraper_tool_schema.py new file mode 100644 index 0000000..9ada05e --- /dev/null +++ b/examples/searchscraper_tool_schema.py @@ -0,0 +1,41 @@ +from typing import Dict, List + +from pydantic import BaseModel, Field +from scrapegraph_py.logger import sgai_logger + +from langchain_scrapegraph.tools import SearchScraperTool + + +class Feature(BaseModel): + name: str = Field(description="Name of the feature") + description: str = Field(description="Description of the feature") + + +class PricingPlan(BaseModel): + name: str = Field(description="Name of the pricing plan") + price: Dict[str, str] = Field( + description="Price details including amount, currency, and period" + ) + features: List[str] = Field(description="List of features included in the plan") + + +class ProductInfo(BaseModel): + name: str = Field(description="Name of the product") + description: str = Field(description="Description of the product") + features: List[Feature] = Field(description="List of product features") + pricing: Dict[str, List[PricingPlan]] = Field(description="Pricing information") + reference_urls: 
List[str] = Field(description="Source URLs for the information") + + +sgai_logger.set_logging(level="INFO") + +# Initialize with Pydantic model class +tool = SearchScraperTool(llm_output_schema=ProductInfo) + +# Example prompt +user_prompt = "What are the key features and pricing of ChatGPT Plus?" + +# Use the tool - output will conform to ProductInfo schema +result = tool.invoke({"user_prompt": user_prompt}) + +print("\nResult:", result) diff --git a/examples/smartscraper_tool_schema.py b/examples/smartscraper_tool_schema.py index bded746..3220881 100644 --- a/examples/smartscraper_tool_schema.py +++ b/examples/smartscraper_tool_schema.py @@ -17,10 +17,31 @@ class WebsiteInfo(BaseModel): # Initialize with Pydantic model class tool = SmartScraperTool(llm_output_schema=WebsiteInfo) -# Example website and prompt +# Example 1: Using website URL website_url = "https://www.example.com" user_prompt = "Extract info about the website" -# Use the tool - output will conform to WebsiteInfo schema -result = tool.invoke({"website_url": website_url, "user_prompt": user_prompt}) -print(result) +# Use the tool with URL +result_url = tool.invoke({"website_url": website_url, "user_prompt": user_prompt}) +print("\nResult from URL:", result_url) + +# Example 2: Using HTML content directly +html_content = """ + + +

Example Domain

+

This domain is for use in illustrative examples.

+ More information... + + +""" + +# Use the tool with HTML content +result_html = tool.invoke( + { + "website_url": website_url, # Still required but will be overridden + "website_html": html_content, + "user_prompt": user_prompt, + } +) +print("\nResult from HTML:", result_html) diff --git a/langchain_scrapegraph/tools/__init__.py b/langchain_scrapegraph/tools/__init__.py index a61f301..aa9e0b2 100644 --- a/langchain_scrapegraph/tools/__init__.py +++ b/langchain_scrapegraph/tools/__init__.py @@ -1,6 +1,6 @@ from .credits import GetCreditsTool -from .localscraper import LocalScraperTool from .markdownify import MarkdownifyTool +from .searchscraper import SearchScraperTool from .smartscraper import SmartScraperTool -__all__ = ["SmartScraperTool", "GetCreditsTool", "MarkdownifyTool", "LocalScraperTool"] +__all__ = ["SmartScraperTool", "GetCreditsTool", "MarkdownifyTool", "SearchScraperTool"] diff --git a/langchain_scrapegraph/tools/localscraper.py b/langchain_scrapegraph/tools/searchscraper.py similarity index 55% rename from langchain_scrapegraph/tools/localscraper.py rename to langchain_scrapegraph/tools/searchscraper.py index 926d6fd..26e3b22 100644 --- a/langchain_scrapegraph/tools/localscraper.py +++ b/langchain_scrapegraph/tools/searchscraper.py @@ -10,15 +10,14 @@ from scrapegraph_py import Client -class LocalscraperInput(BaseModel): +class SearchScraperInput(BaseModel): user_prompt: str = Field( - description="Prompt describing what to extract from the webpage and how to structure the output" + description="Prompt describing what information to search for and extract from the web" ) - website_html: str = Field(description="HTML of the webpage to extract data from") -class LocalScraperTool(BaseTool): - """Tool for extracting structured data from a local HTML file using ScrapeGraph AI. +class SearchScraperTool(BaseTool): + """Tool for searching and extracting structured data from the web using ScrapeGraph AI. 
Setup: Install ``langchain-scrapegraph`` python package: @@ -43,68 +42,60 @@ class LocalScraperTool(BaseTool): Instantiate: .. code-block:: python - from langchain_scrapegraph.tools import LocalScraperTool + from langchain_scrapegraph.tools import SearchScraperTool # Will automatically get SGAI_API_KEY from environment - tool = LocalScraperTool() + tool = SearchScraperTool() # Or provide API key directly - tool = LocalScraperTool(api_key="your-api-key") + tool = SearchScraperTool(api_key="your-api-key") # Optionally, you can provide an output schema: from pydantic import BaseModel, Field + from typing import Any, Dict, List - class CompanyInfo(BaseModel): - name: str = Field(description="Company name") - description: str = Field(description="Company description") - email: str = Field(description="Contact email") + class ProductInfo(BaseModel): + name: str = Field(description="Product name") + features: List[str] = Field(description="List of product features") + pricing: Dict[str, Any] = Field(description="Pricing information") - tool_with_schema = LocalScraperTool(llm_output_schema=CompanyInfo) + tool_with_schema = SearchScraperTool(llm_output_schema=ProductInfo) Use the tool: .. code-block:: python - html_content = ''' - - -

Company Name

-

We are a technology company focused on AI solutions.

-
-

Email: contact@example.com

-

Phone: (555) 123-4567

-
- - - ''' - result = tool.invoke({ - "user_prompt": "Extract company description and contact info", - "website_html": html_content + "user_prompt": "What are the key features and pricing of ChatGPT Plus?" }) print(result) - # Without schema: - # { - # "description": "We are a technology company focused on AI solutions", - # "contact": { - # "email": "contact@example.com", - # "phone": "(555) 123-4567" - # } - # } - # - # With CompanyInfo schema: # { - # "name": "Company Name", - # "description": "We are a technology company focused on AI solutions", - # "email": "contact@example.com" + # "product": { + # "name": "ChatGPT Plus", + # "description": "Premium version of ChatGPT...", + # ... + # }, + # "features": [...], + # "pricing": {...}, + # "reference_urls": [ + # "https://openai.com/chatgpt", + # ... + # ] # } + + Async usage: + .. code-block:: python + + result = await tool.ainvoke({ + "user_prompt": "What are the key features of Product X?" + }) """ - name: str = "LocalScraper" + name: str = "SearchScraper" description: str = ( - "Useful when you need to extract structured data from a HTML webpage, applying also some reasoning using LLM, by providing an HTML string and an extraction prompt" + "Useful when you need to search and extract structured information from the web about a specific topic or query" ) - args_schema: Type[BaseModel] = LocalscraperInput + args_schema: Type[BaseModel] = SearchScraperInput return_direct: bool = True client: Optional[Client] = None api_key: str @@ -124,23 +115,20 @@ def __init__(self, **data: Any): def _run( self, user_prompt: str, - website_html: str, run_manager: Optional[CallbackManagerForToolRun] = None, ) -> dict: - """Use the tool to extract data from a website.""" + """Use the tool to search and extract data from the web.""" if not self.client: raise ValueError("Client not initialized") if self.llm_output_schema is None: - response = self.client.localscraper( - website_html=website_html, + response = 
self.client.searchscraper( user_prompt=user_prompt, ) elif isinstance(self.llm_output_schema, type) and issubclass( self.llm_output_schema, BaseModel ): - response = self.client.localscraper( - website_html=website_html, + response = self.client.searchscraper( user_prompt=user_prompt, output_schema=self.llm_output_schema, ) @@ -152,12 +140,10 @@ def _run( async def _arun( self, user_prompt: str, - website_html: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool asynchronously.""" return self._run( user_prompt, - website_html, run_manager=run_manager.get_sync() if run_manager else None, ) diff --git a/langchain_scrapegraph/tools/smartscraper.py b/langchain_scrapegraph/tools/smartscraper.py index 7b07915..60707ec 100644 --- a/langchain_scrapegraph/tools/smartscraper.py +++ b/langchain_scrapegraph/tools/smartscraper.py @@ -15,6 +15,10 @@ class SmartScraperInput(BaseModel): description="Prompt describing what to extract from the webpage and how to structure the output" ) website_url: str = Field(description="Url of the webpage to extract data from") + website_html: Optional[str] = Field( + default=None, + description="Optional HTML content to process instead of fetching from website_url", + ) class SmartScraperTool(BaseTool): @@ -63,11 +67,27 @@ class WebsiteInfo(BaseModel): Use the tool: .. code-block:: python + # Using website URL result = tool.invoke({ "user_prompt": "Extract the main heading and first paragraph", "website_url": "https://example.com" }) + # Using HTML content directly + html_content = ''' + + +

Example Domain

+

This domain is for use in illustrative examples...

+ + + ''' + result = tool.invoke({ + "user_prompt": "Extract the main heading and first paragraph", + "website_url": "https://example.com", + "website_html": html_content # This will override website_url + }) + print(result) # Without schema: # { @@ -115,6 +135,7 @@ def _run( self, user_prompt: str, website_url: str, + website_html: Optional[str] = None, run_manager: Optional[CallbackManagerForToolRun] = None, ) -> dict: """Use the tool to extract data from a website.""" @@ -125,6 +146,7 @@ def _run( response = self.client.smartscraper( website_url=website_url, user_prompt=user_prompt, + website_html=website_html, ) elif isinstance(self.llm_output_schema, type) and issubclass( self.llm_output_schema, BaseModel @@ -132,6 +154,7 @@ def _run( response = self.client.smartscraper( website_url=website_url, user_prompt=user_prompt, + website_html=website_html, output_schema=self.llm_output_schema, ) else: @@ -143,11 +166,13 @@ async def _arun( self, user_prompt: str, website_url: str, + website_html: Optional[str] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool asynchronously.""" return self._run( user_prompt, website_url, + website_html=website_html, run_manager=run_manager.get_sync() if run_manager else None, ) diff --git a/pyproject.toml b/pyproject.toml index 99a268d..acfbddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-scrapegraph" -version = "1.2.0" +version = "1.3.0b1" description = "Library for extracting structured data from websites using ScrapeGraphAI" authors = ["Marco Perini ", "Marco Vinciguerra ", "Lorenzo Padoan "] license = "MIT" diff --git a/tests/integration_tests/test_tools.py b/tests/integration_tests/test_tools.py index 13fbf9d..9ed2a29 100644 --- a/tests/integration_tests/test_tools.py +++ b/tests/integration_tests/test_tools.py @@ -15,7 +15,6 @@ from langchain_scrapegraph.tools import ( GetCreditsTool, - LocalScraperTool, MarkdownifyTool, 
SmartScraperTool, ) @@ -76,34 +75,3 @@ def tool_constructor_params(self) -> dict: @property def tool_invoke_params_example(self) -> dict: return {"website_url": "https://example.com"} - - -class TestLocalScraperToolIntegration(ToolsIntegrationTests): - @property - def tool_constructor(self) -> Type[LocalScraperTool]: - return LocalScraperTool - - @property - def tool_constructor_params(self) -> dict: - api_key = os.getenv("SGAI_API_KEY") - if not api_key: - pytest.skip("SGAI_API_KEY environment variable not set") - return {"api_key": api_key} - - @property - def tool_invoke_params_example(self) -> dict: - return { - "user_prompt": "Make a summary and extract contact info", - "website_html": """ - - -

Company Name

-

We are a technology company focused on AI solutions.

-
-

Email: contact@example.com

-

Phone: (555) 123-4567

-
- - - """, - } diff --git a/tests/unit_tests/mocks.py b/tests/unit_tests/mocks.py index 740b0d2..7c7f1e9 100644 --- a/tests/unit_tests/mocks.py +++ b/tests/unit_tests/mocks.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Type from langchain_core.tools import BaseTool from pydantic import BaseModel, Field @@ -9,8 +9,25 @@ def __init__(self, api_key: str = None, *args, **kwargs): """Initialize with mock methods that return proper response structures""" self._api_key = api_key - def smartscraper(self, website_url: str, user_prompt: str) -> dict: + def smartscraper( + self, website_url: str, user_prompt: str, website_html: str = None + ) -> dict: """Mock smartscraper method""" + # If website_html is provided, use it to determine the response + if website_html and "

Test

" in website_html: + return { + "request_id": "test-id", + "status": "completed", + "website_url": website_url, + "user_prompt": user_prompt, + "result": { + "main_heading": "Test", + "first_paragraph": "Test paragraph", + }, + "error": "", + } + + # Default response for URL-based requests return { "request_id": "test-id", "status": "completed", @@ -23,6 +40,32 @@ def smartscraper(self, website_url: str, user_prompt: str) -> dict: "error": "", } + def searchscraper(self, user_prompt: str) -> dict: + """Mock searchscraper method""" + return { + "request_id": "test-id", + "status": "completed", + "user_prompt": user_prompt, + "result": { + "product": {"name": "Test Product", "description": "Test description"}, + "features": [{"name": "Feature 1", "description": "Description 1"}], + "pricing": { + "plans": [ + { + "name": "Basic Plan", + "price": { + "amount": "10", + "currency": "USD", + "period": "monthly", + }, + } + ] + }, + }, + "reference_urls": ["https://example.com/test"], + "error": "", + } + def get_credits(self) -> dict: """Mock get_credits method""" return {"remaining_credits": 50, "total_credits_used": 543} @@ -37,19 +80,6 @@ def markdownify(self, website_url: str) -> dict: "error": "", } - def localscraper(self, website_html: str, user_prompt: str) -> dict: - """Mock localscraper method""" - return { - "request_id": "test-id", - "status": "completed", - "user_prompt": user_prompt, - "result": { - "summary": "This is a technology company", - "contact": {"email": "contact@example.com", "phone": "(555) 123-4567"}, - }, - "error": "", - } - def close(self) -> None: """Mock close method""" pass @@ -60,13 +90,12 @@ class MockSmartScraperInput(BaseModel): website_url: str = Field(description="Test URL") -class MockMarkdownifyInput(BaseModel): - website_url: str = Field(description="Test URL") +class MockSearchScraperInput(BaseModel): + user_prompt: str = Field(description="Test prompt") -class MockLocalScraperInput(BaseModel): - user_prompt: str = 
Field(description="Test prompt") - website_html: str = Field(description="Test HTML") +class MockMarkdownifyInput(BaseModel): + website_url: str = Field(description="Test URL") class MockSmartScraperTool(BaseTool): @@ -80,6 +109,22 @@ def _run(self, **kwargs: Any) -> Dict: return {"main_heading": "Test", "first_paragraph": "Test"} +class MockSearchScraperTool(BaseTool): + name: str = "SearchScraper" + description: str = "Test description" + args_schema: type[BaseModel] = MockSearchScraperInput + client: Optional[MockClient] = None + api_key: str + llm_output_schema: Optional[Type[BaseModel]] = None + + def _run(self, **kwargs: Any) -> Dict: + return { + "product": {"name": "Test Product", "description": "Test description"}, + "features": [{"name": "Feature 1", "description": "Description 1"}], + "reference_urls": ["https://example.com/test"], + } + + class MockGetCreditsTool(BaseTool): name: str = "GetCredits" description: str = "Test description" @@ -99,17 +144,3 @@ class MockMarkdownifyTool(BaseTool): def _run(self, **kwargs: Any) -> str: return "# Example Domain\n\nTest paragraph" - - -class MockLocalScraperTool(BaseTool): - name: str = "LocalScraper" - description: str = "Test description" - args_schema: type[BaseModel] = MockLocalScraperInput - client: Optional[MockClient] = None - api_key: str - - def _run(self, **kwargs: Any) -> Dict: - return { - "summary": "This is a technology company", - "contact": {"email": "contact@example.com", "phone": "(555) 123-4567"}, - } diff --git a/tests/unit_tests/test_tools.py b/tests/unit_tests/test_tools.py index 2ac0876..ccddd92 100644 --- a/tests/unit_tests/test_tools.py +++ b/tests/unit_tests/test_tools.py @@ -5,15 +5,15 @@ from langchain_scrapegraph.tools import ( GetCreditsTool, - LocalScraperTool, MarkdownifyTool, + SearchScraperTool, SmartScraperTool, ) from tests.unit_tests.mocks import ( MockClient, MockGetCreditsTool, - MockLocalScraperTool, MockMarkdownifyTool, + MockSearchScraperTool, MockSmartScraperTool, ) @@ 
-36,49 +36,90 @@ def tool_invoke_params_example(self) -> dict: } -class TestGetCreditsToolUnit(ToolsUnitTests): +class TestSmartScraperToolCustom: + def test_invoke_with_html(self): + """Test invoking the tool with HTML content.""" + with patch("langchain_scrapegraph.tools.smartscraper.Client", MockClient): + tool = MockSmartScraperTool(api_key="sgai-test-api-key") + result = tool.invoke( + { + "user_prompt": "Extract the main heading", + "website_url": "https://example.com", + "website_html": "

Test

", + } + ) + assert isinstance(result, dict) + assert "main_heading" in result + assert result["main_heading"] == "Test" + + +class TestSearchScraperToolUnit(ToolsUnitTests): @property - def tool_constructor(self) -> Type[GetCreditsTool]: - return MockGetCreditsTool + def tool_constructor(self) -> Type[SearchScraperTool]: + return MockSearchScraperTool @property def tool_constructor_params(self) -> dict: - with patch("langchain_scrapegraph.tools.credits.Client", MockClient): + with patch("langchain_scrapegraph.tools.searchscraper.Client", MockClient): return {"api_key": "sgai-test-api-key"} @property def tool_invoke_params_example(self) -> dict: - return {} + return { + "user_prompt": "What are the key features of Product X?", + } -class TestMarkdownifyToolUnit(ToolsUnitTests): +class TestSearchScraperToolCustom: + def test_invoke_with_schema(self): + """Test invoking the tool with a schema.""" + from typing import List + + from pydantic import BaseModel, Field + + class TestSchema(BaseModel): + product: dict = Field(description="Product information") + features: List[dict] = Field(description="List of features") + reference_urls: List[str] = Field(description="Reference URLs") + + with patch("langchain_scrapegraph.tools.searchscraper.Client", MockClient): + tool = MockSearchScraperTool(api_key="sgai-test-api-key") + tool.llm_output_schema = TestSchema + result = tool.invoke( + {"user_prompt": "What are the key features of Product X?"} + ) + assert isinstance(result, dict) + assert "product" in result + assert "features" in result + assert "reference_urls" in result + assert isinstance(result["reference_urls"], list) + + +class TestGetCreditsToolUnit(ToolsUnitTests): @property - def tool_constructor(self) -> Type[MarkdownifyTool]: - return MockMarkdownifyTool + def tool_constructor(self) -> Type[GetCreditsTool]: + return MockGetCreditsTool @property def tool_constructor_params(self) -> dict: - with patch("langchain_scrapegraph.tools.markdownify.Client", 
MockClient): + with patch("langchain_scrapegraph.tools.credits.Client", MockClient): return {"api_key": "sgai-test-api-key"} @property def tool_invoke_params_example(self) -> dict: - return {"website_url": "https://example.com"} + return {} -class TestLocalScraperToolUnit(ToolsUnitTests): +class TestMarkdownifyToolUnit(ToolsUnitTests): @property - def tool_constructor(self) -> Type[LocalScraperTool]: - return MockLocalScraperTool + def tool_constructor(self) -> Type[MarkdownifyTool]: + return MockMarkdownifyTool @property def tool_constructor_params(self) -> dict: - with patch("langchain_scrapegraph.tools.localscraper.Client", MockClient): + with patch("langchain_scrapegraph.tools.markdownify.Client", MockClient): return {"api_key": "sgai-test-api-key"} @property def tool_invoke_params_example(self) -> dict: - return { - "user_prompt": "Make a summary and extract contact info", - "website_html": "

Test

", - } + return {"website_url": "https://example.com"}