From b552aa902fb4f5052468148851434062d8e74b94 Mon Sep 17 00:00:00 2001
From: souvik03-136 <66234771+souvik03-136@users.noreply.github.com>
Date: Mon, 28 Apr 2025 21:09:13 +0530
Subject: [PATCH] feat: enhance error handling and validation across utility
modules
- Add Pydantic models for state validation in code_error_analysis.py and code_error_correction.py
- Implement comprehensive key existence checks to prevent KeyError exceptions
- Create custom exception hierarchy for better error management
- Add improved PDF detection with regex pattern matching in research_web.py
- Implement input validation for all public functions
- Add detailed error messages and type hints
---
scrapegraphai/utils/code_error_analysis.py | 284 +++++++++--
scrapegraphai/utils/code_error_correction.py | 261 ++++++++--
scrapegraphai/utils/research_web.py | 496 +++++++++++++++----
3 files changed, 838 insertions(+), 203 deletions(-)
diff --git a/scrapegraphai/utils/code_error_analysis.py b/scrapegraphai/utils/code_error_analysis.py
index 673c0dfe..f0642cac 100644
--- a/scrapegraphai/utils/code_error_analysis.py
+++ b/scrapegraphai/utils/code_error_analysis.py
@@ -12,8 +12,9 @@
"""
import json
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+from pydantic import BaseModel, Field, validator
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
@@ -25,7 +26,77 @@
)
-def syntax_focused_analysis(state: dict, llm_model) -> str:
+class AnalysisError(Exception):
+ """Base exception for code analysis errors."""
+ pass
+
+
+class InvalidStateError(AnalysisError):
+ """Exception raised when state dictionary is missing required keys."""
+ pass
+
+
+class CodeAnalysisState(BaseModel):
+ """Base model for code analysis state validation."""
+ generated_code: str = Field(..., description="The generated code to analyze")
+ errors: Dict[str, Any] = Field(..., description="Dictionary containing error information")
+
+ @validator('errors')
+ def validate_errors(cls, v):
+ """Ensure errors dictionary has expected structure."""
+ if not isinstance(v, dict):
+ raise ValueError("errors must be a dictionary")
+ return v
+
+
+class ExecutionAnalysisState(CodeAnalysisState):
+ """Model for execution analysis state validation."""
+ html_code: Optional[str] = Field(None, description="HTML code if available")
+ html_analysis: Optional[str] = Field(None, description="Analysis of HTML code")
+
+ @validator('errors')
+ def validate_execution_errors(cls, v):
+ """Ensure errors dictionary contains execution key."""
+ super().validate_errors(v)
+ if 'execution' not in v:
+ raise ValueError("errors dictionary must contain 'execution' key")
+ return v
+
+
+class ValidationAnalysisState(CodeAnalysisState):
+ """Model for validation analysis state validation."""
+ json_schema: Dict[str, Any] = Field(..., description="JSON schema for validation")
+ execution_result: Any = Field(..., description="Result of code execution")
+
+ @validator('errors')
+ def validate_validation_errors(cls, v):
+ """Ensure errors dictionary contains validation key."""
+ super().validate_errors(v)
+ if 'validation' not in v:
+ raise ValueError("errors dictionary must contain 'validation' key")
+ return v
+
+
+def get_optimal_analysis_template(error_type: str) -> str:
+ """
+ Returns the optimal prompt template based on the error type.
+
+ Args:
+ error_type (str): Type of error to analyze.
+
+ Returns:
+ str: The prompt template text.
+ """
+ template_registry = {
+ "syntax": TEMPLATE_SYNTAX_ANALYSIS,
+ "execution": TEMPLATE_EXECUTION_ANALYSIS,
+ "validation": TEMPLATE_VALIDATION_ANALYSIS,
+ "semantic": TEMPLATE_SEMANTIC_ANALYSIS,
+ }
+ return template_registry.get(error_type, TEMPLATE_SYNTAX_ANALYSIS)
+
+
+def syntax_focused_analysis(state: Dict[str, Any], llm_model) -> str:
"""
Analyzes the syntax errors in the generated code.
@@ -35,17 +106,48 @@ def syntax_focused_analysis(state: dict, llm_model) -> str:
Returns:
str: The result of the syntax error analysis.
+
+ Raises:
+ InvalidStateError: If state is missing required keys.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'print("Hello World")',
+ 'errors': {'syntax': 'Missing parenthesis'}
+ }
+ >>> analysis = syntax_focused_analysis(state, mock_llm)
"""
- prompt = PromptTemplate(
- template=TEMPLATE_SYNTAX_ANALYSIS, input_variables=["generated_code", "errors"]
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {"generated_code": state["generated_code"], "errors": state["errors"]["syntax"]}
- )
+ try:
+ # Validate state using Pydantic model
+ validated_state = CodeAnalysisState(
+ generated_code=state.get("generated_code", ""),
+ errors=state.get("errors", {})
+ )
+
+ # Check if syntax errors exist
+ if "syntax" not in validated_state.errors:
+ raise InvalidStateError("No syntax errors found in state dictionary")
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_analysis_template("syntax"),
+ input_variables=["generated_code", "errors"]
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated state
+ return chain.invoke({
+ "generated_code": validated_state.generated_code,
+ "errors": validated_state.errors["syntax"]
+ })
+
+ except KeyError as e:
+ raise InvalidStateError(f"Missing required key in state dictionary: {e}")
+ except Exception as e:
+ raise AnalysisError(f"Syntax analysis failed: {str(e)}")
-def execution_focused_analysis(state: dict, llm_model) -> str:
+def execution_focused_analysis(state: Dict[str, Any], llm_model) -> str:
"""
Analyzes the execution errors in the generated code and HTML code.
@@ -55,23 +157,50 @@ def execution_focused_analysis(state: dict, llm_model) -> str:
Returns:
str: The result of the execution error analysis.
- """
- prompt = PromptTemplate(
- template=TEMPLATE_EXECUTION_ANALYSIS,
- input_variables=["generated_code", "errors", "html_code", "html_analysis"],
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {
- "generated_code": state["generated_code"],
- "errors": state["errors"]["execution"],
- "html_code": state["html_code"],
- "html_analysis": state["html_analysis"],
+
+ Raises:
+ InvalidStateError: If state is missing required keys.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'print(x)',
+ 'errors': {'execution': 'NameError: name "x" is not defined'},
+ 'html_code': '<div>Test</div>',
+ 'html_analysis': 'Valid HTML'
}
- )
+ >>> analysis = execution_focused_analysis(state, mock_llm)
+ """
+ try:
+ # Validate state using Pydantic model
+ validated_state = ExecutionAnalysisState(
+ generated_code=state.get("generated_code", ""),
+ errors=state.get("errors", {}),
+ html_code=state.get("html_code", ""),
+ html_analysis=state.get("html_analysis", "")
+ )
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_analysis_template("execution"),
+ input_variables=["generated_code", "errors", "html_code", "html_analysis"],
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated state
+ return chain.invoke({
+ "generated_code": validated_state.generated_code,
+ "errors": validated_state.errors["execution"],
+ "html_code": validated_state.html_code,
+ "html_analysis": validated_state.html_analysis,
+ })
+
+ except KeyError as e:
+ raise InvalidStateError(f"Missing required key in state dictionary: {e}")
+ except Exception as e:
+ raise AnalysisError(f"Execution analysis failed: {str(e)}")
-def validation_focused_analysis(state: dict, llm_model) -> str:
+def validation_focused_analysis(state: Dict[str, Any], llm_model) -> str:
"""
Analyzes the validation errors in the generated code based on a JSON schema.
@@ -82,24 +211,51 @@ def validation_focused_analysis(state: dict, llm_model) -> str:
Returns:
str: The result of the validation error analysis.
- """
- prompt = PromptTemplate(
- template=TEMPLATE_VALIDATION_ANALYSIS,
- input_variables=["generated_code", "errors", "json_schema", "execution_result"],
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {
- "generated_code": state["generated_code"],
- "errors": state["errors"]["validation"],
- "json_schema": state["json_schema"],
- "execution_result": state["execution_result"],
+
+ Raises:
+ InvalidStateError: If state is missing required keys.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'return {"name": "John"}',
+ 'errors': {'validation': 'Missing required field: age'},
+ 'json_schema': {'required': ['name', 'age']},
+ 'execution_result': {'name': 'John'}
}
- )
+ >>> analysis = validation_focused_analysis(state, mock_llm)
+ """
+ try:
+ # Validate state using Pydantic model
+ validated_state = ValidationAnalysisState(
+ generated_code=state.get("generated_code", ""),
+ errors=state.get("errors", {}),
+ json_schema=state.get("json_schema", {}),
+ execution_result=state.get("execution_result", {})
+ )
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_analysis_template("validation"),
+ input_variables=["generated_code", "errors", "json_schema", "execution_result"],
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated state
+ return chain.invoke({
+ "generated_code": validated_state.generated_code,
+ "errors": validated_state.errors["validation"],
+ "json_schema": validated_state.json_schema,
+ "execution_result": validated_state.execution_result,
+ })
+
+ except KeyError as e:
+ raise InvalidStateError(f"Missing required key in state dictionary: {e}")
+ except Exception as e:
+ raise AnalysisError(f"Validation analysis failed: {str(e)}")
def semantic_focused_analysis(
- state: dict, comparison_result: Dict[str, Any], llm_model
+ state: Dict[str, Any], comparison_result: Dict[str, Any], llm_model
) -> str:
"""
Analyzes the semantic differences in the generated code based on a comparison result.
@@ -112,16 +268,48 @@ def semantic_focused_analysis(
Returns:
str: The result of the semantic error analysis.
+
+ Raises:
+ InvalidStateError: If state or comparison_result is missing required keys.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'def add(a, b): return a + b'
+ }
+ >>> comparison_result = {
+ 'differences': ['Missing docstring', 'No type hints'],
+ 'explanation': 'The code is missing documentation'
+ }
+ >>> analysis = semantic_focused_analysis(state, comparison_result, mock_llm)
"""
- prompt = PromptTemplate(
- template=TEMPLATE_SEMANTIC_ANALYSIS,
- input_variables=["generated_code", "differences", "explanation"],
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {
- "generated_code": state["generated_code"],
+ try:
+ # Validate state using Pydantic model
+ validated_state = CodeAnalysisState(
+ generated_code=state.get("generated_code", ""),
+ errors=state.get("errors", {})
+ )
+
+ # Validate comparison_result
+ if "differences" not in comparison_result:
+ raise InvalidStateError("comparison_result missing 'differences' key")
+ if "explanation" not in comparison_result:
+ raise InvalidStateError("comparison_result missing 'explanation' key")
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_analysis_template("semantic"),
+ input_variables=["generated_code", "differences", "explanation"],
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated inputs
+ return chain.invoke({
+ "generated_code": validated_state.generated_code,
"differences": json.dumps(comparison_result["differences"], indent=2),
"explanation": comparison_result["explanation"],
- }
- )
+ })
+
+ except KeyError as e:
+ raise InvalidStateError(f"Missing required key: {e}")
+ except Exception as e:
+ raise AnalysisError(f"Semantic analysis failed: {str(e)}")
\ No newline at end of file
diff --git a/scrapegraphai/utils/code_error_correction.py b/scrapegraphai/utils/code_error_correction.py
index e73237ad..b3838422 100644
--- a/scrapegraphai/utils/code_error_correction.py
+++ b/scrapegraphai/utils/code_error_correction.py
@@ -11,7 +11,10 @@
"""
import json
+from typing import Any, Dict, Optional
+from functools import lru_cache
+from pydantic import BaseModel, Field, validator
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
@@ -23,7 +26,57 @@
)
-def syntax_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
+class CodeGenerationError(Exception):
+ """Base exception for code generation errors."""
+ pass
+
+
+class InvalidCorrectionStateError(CodeGenerationError):
+ """Exception raised when state dictionary is missing required keys."""
+ pass
+
+
+class CorrectionState(BaseModel):
+ """Base model for code correction state validation."""
+ generated_code: str = Field(..., description="The original generated code to correct")
+
+ class Config:
+ extra = "allow"
+
+
+class ValidationCorrectionState(CorrectionState):
+ """Model for validation correction state validation."""
+ json_schema: Dict[str, Any] = Field(..., description="JSON schema for validation")
+
+
+class SemanticCorrectionState(CorrectionState):
+ """Model for semantic correction state validation."""
+ execution_result: Any = Field(..., description="Result of code execution")
+ reference_answer: Any = Field(..., description="Reference answer for comparison")
+
+
+@lru_cache(maxsize=32)
+def get_optimal_correction_template(error_type: str) -> str:
+ """
+ Returns the optimal prompt template for code correction based on the error type.
+ Results are cached for performance.
+
+ Args:
+ error_type (str): Type of error to correct.
+
+ Returns:
+ str: The prompt template text.
+ """
+ template_registry = {
+ "syntax": TEMPLATE_SYNTAX_CODE_GENERATION,
+ "execution": TEMPLATE_EXECUTION_CODE_GENERATION,
+ "validation": TEMPLATE_VALIDATION_CODE_GENERATION,
+ "semantic": TEMPLATE_SEMANTIC_CODE_GENERATION,
+ }
+ return template_registry.get(error_type, TEMPLATE_SYNTAX_CODE_GENERATION)
+
+
+def syntax_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str:
"""
Generates corrected code based on syntax error analysis.
@@ -34,18 +87,46 @@ def syntax_focused_code_generation(state: dict, analysis: str, llm_model) -> str
Returns:
str: The corrected code.
+
+ Raises:
+ InvalidCorrectionStateError: If state is missing required keys.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'print("Hello World"'
+ }
+ >>> analysis = "Missing closing parenthesis in print statement"
+ >>> corrected_code = syntax_focused_code_generation(state, analysis, mock_llm)
"""
- prompt = PromptTemplate(
- template=TEMPLATE_SYNTAX_CODE_GENERATION,
- input_variables=["analysis", "generated_code"],
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {"analysis": analysis, "generated_code": state["generated_code"]}
- )
+ try:
+ # Validate state using Pydantic model
+ validated_state = CorrectionState(
+ generated_code=state.get("generated_code", "")
+ )
+
+ if not analysis or not isinstance(analysis, str):
+ raise InvalidCorrectionStateError("Analysis must be a non-empty string")
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_correction_template("syntax"),
+ input_variables=["analysis", "generated_code"],
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated state
+ return chain.invoke({
+ "analysis": analysis,
+ "generated_code": validated_state.generated_code
+ })
+
+ except KeyError as e:
+ raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}")
+ except Exception as e:
+ raise CodeGenerationError(f"Syntax code generation failed: {str(e)}")
-def execution_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
+def execution_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str:
"""
Generates corrected code based on execution error analysis.
@@ -56,18 +137,46 @@ def execution_focused_code_generation(state: dict, analysis: str, llm_model) ->
Returns:
str: The corrected code.
+
+ Raises:
+ InvalidCorrectionStateError: If state is missing required keys or analysis is invalid.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'print(x)'
+ }
+ >>> analysis = "Variable 'x' is not defined before use"
+ >>> corrected_code = execution_focused_code_generation(state, analysis, mock_llm)
"""
- prompt = PromptTemplate(
- template=TEMPLATE_EXECUTION_CODE_GENERATION,
- input_variables=["analysis", "generated_code"],
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {"analysis": analysis, "generated_code": state["generated_code"]}
- )
+ try:
+ # Validate state using Pydantic model
+ validated_state = CorrectionState(
+ generated_code=state.get("generated_code", "")
+ )
+
+ if not analysis or not isinstance(analysis, str):
+ raise InvalidCorrectionStateError("Analysis must be a non-empty string")
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_correction_template("execution"),
+ input_variables=["analysis", "generated_code"],
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated state
+ return chain.invoke({
+ "analysis": analysis,
+ "generated_code": validated_state.generated_code
+ })
+
+ except KeyError as e:
+ raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}")
+ except Exception as e:
+ raise CodeGenerationError(f"Execution code generation failed: {str(e)}")
-def validation_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
+def validation_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str:
"""
Generates corrected code based on validation error analysis.
@@ -78,22 +187,49 @@ def validation_focused_code_generation(state: dict, analysis: str, llm_model) ->
Returns:
str: The corrected code.
+
+ Raises:
+ InvalidCorrectionStateError: If state is missing required keys or analysis is invalid.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'return {"name": "John"}',
+ 'json_schema': {'required': ['name', 'age']}
+ }
+ >>> analysis = "The output JSON is missing the required 'age' field"
+ >>> corrected_code = validation_focused_code_generation(state, analysis, mock_llm)
"""
- prompt = PromptTemplate(
- template=TEMPLATE_VALIDATION_CODE_GENERATION,
- input_variables=["analysis", "generated_code", "json_schema"],
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {
+ try:
+ # Validate state using Pydantic model
+ validated_state = ValidationCorrectionState(
+ generated_code=state.get("generated_code", ""),
+ json_schema=state.get("json_schema", {})
+ )
+
+ if not analysis or not isinstance(analysis, str):
+ raise InvalidCorrectionStateError("Analysis must be a non-empty string")
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_correction_template("validation"),
+ input_variables=["analysis", "generated_code", "json_schema"],
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated state
+ return chain.invoke({
"analysis": analysis,
- "generated_code": state["generated_code"],
- "json_schema": state["json_schema"],
- }
- )
+ "generated_code": validated_state.generated_code,
+ "json_schema": validated_state.json_schema,
+ })
+
+ except KeyError as e:
+ raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}")
+ except Exception as e:
+ raise CodeGenerationError(f"Validation code generation failed: {str(e)}")
-def semantic_focused_code_generation(state: dict, analysis: str, llm_model) -> str:
+def semantic_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str:
"""
Generates corrected code based on semantic error analysis.
@@ -104,22 +240,51 @@ def semantic_focused_code_generation(state: dict, analysis: str, llm_model) -> s
Returns:
str: The corrected code.
+
+ Raises:
+ InvalidCorrectionStateError: If state is missing required keys or analysis is invalid.
+
+ Example:
+ >>> state = {
+ 'generated_code': 'def add(a, b): return a + b',
+ 'execution_result': {'result': 3},
+ 'reference_answer': {'result': 3, 'documentation': 'Adds two numbers'}
+ }
+ >>> analysis = "The code is missing documentation"
+ >>> corrected_code = semantic_focused_code_generation(state, analysis, mock_llm)
"""
- prompt = PromptTemplate(
- template=TEMPLATE_SEMANTIC_CODE_GENERATION,
- input_variables=[
- "analysis",
- "generated_code",
- "generated_result",
- "reference_result",
- ],
- )
- chain = prompt | llm_model | StrOutputParser()
- return chain.invoke(
- {
+ try:
+ # Validate state using Pydantic model
+ validated_state = SemanticCorrectionState(
+ generated_code=state.get("generated_code", ""),
+ execution_result=state.get("execution_result", {}),
+ reference_answer=state.get("reference_answer", {})
+ )
+
+ if not analysis or not isinstance(analysis, str):
+ raise InvalidCorrectionStateError("Analysis must be a non-empty string")
+
+ # Create prompt template and chain
+ prompt = PromptTemplate(
+ template=get_optimal_correction_template("semantic"),
+ input_variables=[
+ "analysis",
+ "generated_code",
+ "generated_result",
+ "reference_result",
+ ],
+ )
+ chain = prompt | llm_model | StrOutputParser()
+
+ # Execute chain with validated state
+ return chain.invoke({
"analysis": analysis,
- "generated_code": state["generated_code"],
- "generated_result": json.dumps(state["execution_result"], indent=2),
- "reference_result": json.dumps(state["reference_answer"], indent=2),
- }
- )
+ "generated_code": validated_state.generated_code,
+ "generated_result": json.dumps(validated_state.execution_result, indent=2),
+ "reference_result": json.dumps(validated_state.reference_answer, indent=2),
+ })
+
+ except KeyError as e:
+ raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}")
+ except Exception as e:
+ raise CodeGenerationError(f"Semantic code generation failed: {str(e)}")
\ No newline at end of file
diff --git a/scrapegraphai/utils/research_web.py b/scrapegraphai/utils/research_web.py
index b9721306..195e11ca 100644
--- a/scrapegraphai/utils/research_web.py
+++ b/scrapegraphai/utils/research_web.py
@@ -1,27 +1,161 @@
"""
-research_web module
+research_web module for web searching across different search engines with improved
+error handling, validation, and security features.
"""
import re
-from typing import List
+import random
+import time
+from typing import List, Dict, Union, Optional
+from functools import wraps
import requests
from bs4 import BeautifulSoup
+from pydantic import BaseModel, Field, validator
from langchain_community.tools import DuckDuckGoSearchResults
+class ResearchWebError(Exception):
+ """Base exception for research web errors."""
+ pass
+
+
+class SearchConfigError(ResearchWebError):
+ """Exception raised when search configuration is invalid."""
+ pass
+
+
+class SearchRequestError(ResearchWebError):
+ """Exception raised when search request fails."""
+ pass
+
+
+class ProxyConfig(BaseModel):
+ """Model for proxy configuration validation."""
+ server: str = Field(..., description="Proxy server address including port")
+ username: Optional[str] = Field(None, description="Username for proxy authentication")
+ password: Optional[str] = Field(None, description="Password for proxy authentication")
+
+
+class SearchConfig(BaseModel):
+ """Model for search configuration validation."""
+ query: str = Field(..., description="Search query")
+ search_engine: str = Field("duckduckgo", description="Search engine to use")
+ max_results: int = Field(10, description="Maximum number of results to return")
+ port: Optional[int] = Field(8080, description="Port for SearXNG")
+ timeout: int = Field(10, description="Request timeout in seconds")
+ proxy: Optional[Union[str, Dict, ProxyConfig]] = Field(None, description="Proxy configuration")
+ serper_api_key: Optional[str] = Field(None, description="API key for Serper")
+ region: Optional[str] = Field(None, description="Country/region code")
+ language: str = Field("en", description="Language code")
+
+ @validator('search_engine')
+ def validate_search_engine(cls, v):
+ """Validate search engine."""
+ valid_engines = {"duckduckgo", "bing", "searxng", "serper"}
+ if v.lower() not in valid_engines:
+ raise ValueError(f"Search engine must be one of: {', '.join(valid_engines)}")
+ return v.lower()
+
+ @validator('query')
+ def validate_query(cls, v):
+ """Validate search query."""
+ if not v or not isinstance(v, str):
+ raise ValueError("Query must be a non-empty string")
+ return v
+
+ @validator('max_results')
+ def validate_max_results(cls, v):
+ """Validate max results."""
+ if v < 1 or v > 100:
+ raise ValueError("max_results must be between 1 and 100")
+ return v
+
+
+# Define advanced PDF detection regex
+PDF_REGEX = re.compile(r'\.pdf(#.*)?(\?.*)?$', re.IGNORECASE)
+
+
+# Rate limiting decorator
+def rate_limited(calls: int, period: int = 60):
+ """
+ Decorator to limit the rate of function calls.
+
+ Args:
+ calls (int): Maximum number of calls allowed in the period.
+ period (int): Time period in seconds.
+
+ Returns:
+ Callable: Decorated function with rate limiting.
+ """
+ min_interval = period / float(calls)
+ last_called = [0.0]
+
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ elapsed = time.time() - last_called[0]
+ wait_time = min_interval - elapsed
+ if wait_time > 0:
+ time.sleep(wait_time)
+ result = func(*args, **kwargs)
+ last_called[0] = time.time()
+ return result
+ return wrapper
+ return decorator
+
+
+def sanitize_search_query(query: str) -> str:
+ """
+ Sanitizes search query to prevent injection attacks.
+
+ Args:
+ query (str): The search query.
+
+ Returns:
+ str: Sanitized query.
+ """
+ # Remove potential command injection characters
+ sanitized = re.sub(r'[;&|`$()\[\]{}<>]', '', query)
+ # Trim whitespace
+ sanitized = sanitized.strip()
+ return sanitized
+
+
+# List of user agents for rotation
+USER_AGENTS = [
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15",
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0",
+ "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36",
+ "Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1"
+]
+
+
+def get_random_user_agent() -> str:
+ """
+ Returns a random user agent from the list.
+
+ Returns:
+ str: Random user agent string.
+ """
+ return random.choice(USER_AGENTS)
+
+
+@rate_limited(calls=10, period=60)
def search_on_web(
query: str,
search_engine: str = "duckduckgo",
max_results: int = 10,
port: int = 8080,
timeout: int = 10,
- proxy: str | dict = None,
- serper_api_key: str = None,
- region: str = None,
+ proxy: Optional[Union[str, Dict, ProxyConfig]] = None,
+ serper_api_key: Optional[str] = None,
+ region: Optional[str] = None,
language: str = "en",
) -> List[str]:
- """Search web function with improved error handling and validation
+ """
+ Search web function with improved error handling, validation, and security features.
Args:
query (str): Search query
@@ -29,147 +163,295 @@ def search_on_web(
max_results (int): Maximum number of results to return
port (int): Port for SearXNG
timeout (int): Request timeout in seconds
- proxy (str | dict): Proxy configuration
+ proxy (str | dict | ProxyConfig): Proxy configuration
serper_api_key (str): API key for Serper
region (str): Country/region code (e.g., 'mx' for Mexico)
language (str): Language code (e.g., 'es' for Spanish)
+
+ Returns:
+ List[str]: List of URLs from search results
+
+ Raises:
+ SearchConfigError: If search configuration is invalid
+ SearchRequestError: If search request fails
+ TimeoutError: If search request times out
"""
-
- # Input validation
- if not query or not isinstance(query, str):
- raise ValueError("Query must be a non-empty string")
-
- search_engine = search_engine.lower()
- valid_engines = {"duckduckgo", "bing", "searxng", "serper"}
- if search_engine not in valid_engines:
- raise ValueError(f"Search engine must be one of: {', '.join(valid_engines)}")
-
- # Format proxy once
- formatted_proxy = None
- if proxy:
- formatted_proxy = format_proxy(proxy)
-
try:
+ # Sanitize query for security
+ sanitized_query = sanitize_search_query(query)
+
+ # Validate search configuration
+ config = SearchConfig(
+ query=sanitized_query,
+ search_engine=search_engine,
+ max_results=max_results,
+ port=port,
+ timeout=timeout,
+ proxy=proxy,
+ serper_api_key=serper_api_key,
+ region=region,
+ language=language
+ )
+
+ # Format proxy once
+ formatted_proxy = None
+ if config.proxy:
+ formatted_proxy = format_proxy(config.proxy)
+
results = []
- if search_engine == "duckduckgo":
+ if config.search_engine == "duckduckgo":
# Create a DuckDuckGo search object with max_results
- research = DuckDuckGoSearchResults(max_results=max_results)
+ research = DuckDuckGoSearchResults(max_results=config.max_results)
# Run the search
- res = research.run(query)
+ res = research.run(config.query)
# Extract URLs using regex
results = re.findall(r"https?://[^\s,\]]+", res)
- elif search_engine == "bing":
- results = _search_bing(query, max_results, timeout, formatted_proxy)
-
- elif search_engine == "searxng":
- results = _search_searxng(query, max_results, port, timeout)
-
- elif search_engine == "serper":
- results = _search_serper(query, max_results, serper_api_key, timeout)
+ elif config.search_engine == "bing":
+ results = _search_bing(
+ config.query,
+ config.max_results,
+ config.timeout,
+ formatted_proxy
+ )
+
+ elif config.search_engine == "searxng":
+ results = _search_searxng(
+ config.query,
+ config.max_results,
+ config.port,
+ config.timeout
+ )
+
+ elif config.search_engine == "serper":
+ results = _search_serper(
+ config.query,
+ config.max_results,
+ config.serper_api_key,
+ config.timeout
+ )
return filter_pdf_links(results)
except requests.Timeout:
raise TimeoutError(f"Search request timed out after {timeout} seconds")
except requests.RequestException as e:
- raise RuntimeError(f"Search request failed: {str(e)}")
+ raise SearchRequestError(f"Search request failed: {str(e)}")
+ except ValueError as e:
+ raise SearchConfigError(f"Invalid search configuration: {str(e)}")
def _search_bing(
    query: str, max_results: int, timeout: int, proxy: Optional[str] = None
) -> List[str]:
    """
    Helper function for Bing search with improved error handling.

    Args:
        query (str): Search query
        max_results (int): Maximum number of results to return
        timeout (int): Request timeout in seconds
        proxy (str, optional): Proxy configuration

    Returns:
        List[str]: List of URLs from search results

    Raises:
        requests.Timeout: Propagated unchanged so the caller's
            ``except requests.Timeout`` -> ``TimeoutError`` translation
            keeps working.
        SearchRequestError: If the HTTP request fails for any other reason.
    """
    headers = {"User-Agent": get_random_user_agent()}
    params = {"q": query, "count": max_results}
    proxies = {"http": proxy, "https": proxy} if proxy else None

    try:
        response = requests.get(
            "https://www.bing.com/search",
            params=params,
            headers=headers,
            proxies=proxies,
            timeout=timeout,
        )
        response.raise_for_status()
    except requests.Timeout:
        # Do NOT wrap timeouts: the caller catches requests.Timeout and
        # converts it to TimeoutError; wrapping it here would break that.
        raise
    except requests.RequestException as e:
        # Narrow catch: a bare `except Exception` would also hide
        # programming errors. Chain the cause for debuggability.
        raise SearchRequestError(f"Bing search failed: {str(e)}") from e

    soup = BeautifulSoup(response.text, "html.parser")
    results = []
    # Organic results live under li.b_algo; the h2 > a anchor holds the URL.
    for link in soup.select("li.b_algo h2 a"):
        url = link.get("href")
        if url and url.startswith("http"):
            results.append(url)
            if len(results) >= max_results:
                break

    return results
+
+
def _search_searxng(
    query: str, max_results: int, port: int, timeout: int
) -> List[str]:
    """
    Helper function for SearXNG search.

    Args:
        query (str): Search query
        max_results (int): Maximum number of results to return
        port (int): Port for the local SearXNG instance
        timeout (int): Request timeout in seconds

    Returns:
        List[str]: List of URLs from search results

    Raises:
        requests.Timeout: Propagated unchanged for the caller's
            timeout handling.
        SearchRequestError: If the request fails or returns invalid JSON.
    """
    headers = {"User-Agent": get_random_user_agent()}
    params = {
        "q": query,
        "format": "json",
        "categories": "general",
        "language": "en",
        "time_range": "",
        "engines": "duckduckgo,bing,brave",
        # NOTE(review): SearXNG's JSON API does not document a "results"
        # size parameter — verify; the slice below enforces the cap anyway.
        "results": max_results,
    }

    try:
        response = requests.get(
            f"http://localhost:{port}/search",
            params=params,
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()
        json_data = response.json()
    except requests.Timeout:
        # Let the caller translate timeouts; do not wrap them.
        raise
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body from response.json().
        raise SearchRequestError(f"SearXNG search failed: {str(e)}") from e

    # Skip malformed entries instead of crashing on a missing "url" key.
    urls = [result["url"] for result in json_data.get("results", []) if "url" in result]
    return urls[:max_results]
def _search_serper(
    query: str, max_results: int, api_key: str, timeout: int
) -> List[str]:
    """
    Helper function for Serper search (Google results via the Serper API).

    Args:
        query (str): Search query
        max_results (int): Maximum number of results to return
        api_key (str): API key for Serper
        timeout (int): Request timeout in seconds

    Returns:
        List[str]: List of URLs from search results

    Raises:
        SearchConfigError: If no API key is provided.
        requests.Timeout: Propagated unchanged for the caller's
            timeout handling.
        SearchRequestError: If the request fails or returns invalid JSON.
    """
    if not api_key:
        raise SearchConfigError("Serper API key is required")

    headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
    data = {"q": query, "num": max_results}

    try:
        response = requests.post(
            "https://google.serper.dev/search",
            json=data,
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()
        json_data = response.json()
    except requests.Timeout:
        # Do not wrap timeouts: the caller converts requests.Timeout
        # into TimeoutError.
        raise
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON response body.
        raise SearchRequestError(f"Serper search failed: {str(e)}") from e

    results = []
    # Only organic results carry the target URLs (under "link").
    for item in json_data.get("organic", []):
        if "link" in item:
            results.append(item["link"])
            if len(results) >= max_results:
                break

    return results
+
+
def format_proxy(proxy_config: Union[str, Dict, ProxyConfig]) -> str:
    """
    Format proxy configuration into a string.

    Args:
        proxy_config: Proxy configuration as string, dict, or ProxyConfig

    Returns:
        str: Formatted proxy string

    Raises:
        TypeError: If proxy_config is not a str, dict, or ProxyConfig.
    """
    if isinstance(proxy_config, str):
        # Already formatted, e.g. "https://username:password@ip:port".
        return proxy_config

    if isinstance(proxy_config, dict):
        # Validate the dict through the Pydantic model; invalid/missing
        # fields surface as a validation error.
        proxy_config = ProxyConfig(**proxy_config)
    elif not isinstance(proxy_config, ProxyConfig):
        # Preserve the historical contract instead of failing later with
        # an obscure AttributeError on .username.
        raise TypeError("Proxy should be a dictionary or a string.")

    # Embed credentials only when both username and password are present.
    if proxy_config.username and proxy_config.password:
        auth = f"{proxy_config.username}:{proxy_config.password}@"
        return f"http://{auth}{proxy_config.server}"

    return f"http://{proxy_config.server}"
+
+
def filter_pdf_links(urls: List[str]) -> List[str]:
    """
    Filter out PDF links from search results.

    Args:
        urls (List[str]): List of URLs

    Returns:
        List[str]: Filtered list of URLs without PDFs
    """
    # Keep only URLs that the module-level PDF pattern does not match.
    return list(filter(lambda link: not PDF_REGEX.search(link), urls))
def verify_request_signature(request_data: Dict, signature: str, secret_key: str) -> bool:
    """
    Verify the signature of an incoming request.

    Args:
        request_data (Dict): Request data to verify
        signature (str): Provided signature
        secret_key (str): Secret key for verification

    Returns:
        bool: True if signature is valid, False otherwise
    """
    import hashlib
    import hmac
    import json

    # Canonicalize the payload so signer and verifier serialize identically.
    serialized = json.dumps(request_data, sort_keys=True)

    # Recompute the HMAC-SHA256 signature over the canonical payload.
    expected = hmac.new(
        secret_key.encode(), serialized.encode(), hashlib.sha256
    ).hexdigest()

    # Constant-time comparison prevents timing attacks on the signature.
    return hmac.compare_digest(expected, signature)