From b552aa902fb4f5052468148851434062d8e74b94 Mon Sep 17 00:00:00 2001 From: souvik03-136 <66234771+souvik03-136@users.noreply.github.com> Date: Mon, 28 Apr 2025 21:09:13 +0530 Subject: [PATCH] feat: enhance error handling and validation across utility modules - Add Pydantic models for state validation in code_error_analysis.py and code_error_correction.py - Implement comprehensive key existence checks to prevent KeyError exceptions - Create custom exception hierarchy for better error management - Add improved PDF detection with regex pattern matching in research_web.py - Implement input validation for all public functions - Add detailed error messages and type hints --- scrapegraphai/utils/code_error_analysis.py | 284 +++++++++-- scrapegraphai/utils/code_error_correction.py | 261 ++++++++-- scrapegraphai/utils/research_web.py | 496 +++++++++++++++---- 3 files changed, 838 insertions(+), 203 deletions(-) diff --git a/scrapegraphai/utils/code_error_analysis.py b/scrapegraphai/utils/code_error_analysis.py index 673c0dfe..f0642cac 100644 --- a/scrapegraphai/utils/code_error_analysis.py +++ b/scrapegraphai/utils/code_error_analysis.py @@ -12,8 +12,9 @@ """ import json -from typing import Any, Dict +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field, validator from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser @@ -25,7 +26,77 @@ ) -def syntax_focused_analysis(state: dict, llm_model) -> str: +class AnalysisError(Exception): + """Base exception for code analysis errors.""" + pass + + +class InvalidStateError(AnalysisError): + """Exception raised when state dictionary is missing required keys.""" + pass + + +class CodeAnalysisState(BaseModel): + """Base model for code analysis state validation.""" + generated_code: str = Field(..., description="The generated code to analyze") + errors: Dict[str, Any] = Field(..., description="Dictionary containing error information") + + @validator('errors') + def validate_errors(cls, v): + """Ensure errors dictionary has expected structure.""" + if not isinstance(v, dict): + raise ValueError("errors must be a dictionary") + return v + + +class ExecutionAnalysisState(CodeAnalysisState): + """Model for execution analysis state validation.""" + html_code: Optional[str] = Field(None, description="HTML code if available") + html_analysis: Optional[str] = Field(None, description="Analysis of HTML code") + + @validator('errors') + def validate_execution_errors(cls, v): + """Ensure errors dictionary contains execution key.""" + super().validate_errors(v) + if 'execution' not in v: + raise ValueError("errors dictionary must contain 'execution' key") + return v + + +class ValidationAnalysisState(CodeAnalysisState): + """Model for validation analysis state validation.""" + json_schema: Dict[str, Any] = Field(..., description="JSON schema for validation") + execution_result: Any = Field(..., description="Result of code execution") + + @validator('errors') + def validate_validation_errors(cls, v): + """Ensure errors dictionary contains validation key.""" + super().validate_errors(v) + if 'validation' not in v: + raise ValueError("errors dictionary must contain 'validation' key") + return v + + +def get_optimal_analysis_template(error_type: str) -> str: + """ + Returns the optimal prompt template based on the error type. + + Args: + error_type (str): Type of error to analyze. + + Returns: + str: The prompt template text. + """ + template_registry = { + "syntax": TEMPLATE_SYNTAX_ANALYSIS, + "execution": TEMPLATE_EXECUTION_ANALYSIS, + "validation": TEMPLATE_VALIDATION_ANALYSIS, + "semantic": TEMPLATE_SEMANTIC_ANALYSIS, + } + return template_registry.get(error_type, TEMPLATE_SYNTAX_ANALYSIS) + + +def syntax_focused_analysis(state: Dict[str, Any], llm_model) -> str: """ Analyzes the syntax errors in the generated code. @@ -35,17 +106,48 @@ def syntax_focused_analysis(state: dict, llm_model) -> str: Returns: str: The result of the syntax error analysis. + + Raises: + InvalidStateError: If state is missing required keys. + + Example: + >>> state = { + 'generated_code': 'print("Hello World")', + 'errors': {'syntax': 'Missing parenthesis'} + } + >>> analysis = syntax_focused_analysis(state, mock_llm) """ - prompt = PromptTemplate( - template=TEMPLATE_SYNTAX_ANALYSIS, input_variables=["generated_code", "errors"] - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - {"generated_code": state["generated_code"], "errors": state["errors"]["syntax"]} - ) + try: + # Validate state using Pydantic model + validated_state = CodeAnalysisState( + generated_code=state.get("generated_code", ""), + errors=state.get("errors", {}) + ) + + # Check if syntax errors exist + if "syntax" not in validated_state.errors: + raise InvalidStateError("No syntax errors found in state dictionary") + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_analysis_template("syntax"), + input_variables=["generated_code", "errors"] + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated state + return chain.invoke({ + "generated_code": validated_state.generated_code, + "errors": validated_state.errors["syntax"] + }) + + except KeyError as e: + raise InvalidStateError(f"Missing required key in state dictionary: {e}") + except Exception as e: + raise AnalysisError(f"Syntax analysis failed: {str(e)}") -def execution_focused_analysis(state: dict, llm_model) -> str: +def execution_focused_analysis(state: Dict[str, Any], llm_model) -> str: """ Analyzes the execution errors in the generated code and HTML code. @@ -55,23 +157,50 @@ def execution_focused_analysis(state: dict, llm_model) -> str: Returns: str: The result of the execution error analysis. - """ - prompt = PromptTemplate( - template=TEMPLATE_EXECUTION_ANALYSIS, - input_variables=["generated_code", "errors", "html_code", "html_analysis"], - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - { - "generated_code": state["generated_code"], - "errors": state["errors"]["execution"], - "html_code": state["html_code"], - "html_analysis": state["html_analysis"], + + Raises: + InvalidStateError: If state is missing required keys. + + Example: + >>> state = { + 'generated_code': 'print(x)', + 'errors': {'execution': 'NameError: name "x" is not defined'}, + 'html_code': '
Test
', + 'html_analysis': 'Valid HTML' } - ) + >>> analysis = execution_focused_analysis(state, mock_llm) + """ + try: + # Validate state using Pydantic model + validated_state = ExecutionAnalysisState( + generated_code=state.get("generated_code", ""), + errors=state.get("errors", {}), + html_code=state.get("html_code", ""), + html_analysis=state.get("html_analysis", "") + ) + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_analysis_template("execution"), + input_variables=["generated_code", "errors", "html_code", "html_analysis"], + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated state + return chain.invoke({ + "generated_code": validated_state.generated_code, + "errors": validated_state.errors["execution"], + "html_code": validated_state.html_code, + "html_analysis": validated_state.html_analysis, + }) + + except KeyError as e: + raise InvalidStateError(f"Missing required key in state dictionary: {e}") + except Exception as e: + raise AnalysisError(f"Execution analysis failed: {str(e)}") -def validation_focused_analysis(state: dict, llm_model) -> str: +def validation_focused_analysis(state: Dict[str, Any], llm_model) -> str: """ Analyzes the validation errors in the generated code based on a JSON schema. @@ -82,24 +211,51 @@ def validation_focused_analysis(state: dict, llm_model) -> str: Returns: str: The result of the validation error analysis. - """ - prompt = PromptTemplate( - template=TEMPLATE_VALIDATION_ANALYSIS, - input_variables=["generated_code", "errors", "json_schema", "execution_result"], - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - { - "generated_code": state["generated_code"], - "errors": state["errors"]["validation"], - "json_schema": state["json_schema"], - "execution_result": state["execution_result"], + + Raises: + InvalidStateError: If state is missing required keys. + + Example: + >>> state = { + 'generated_code': 'return {"name": "John"}', + 'errors': {'validation': 'Missing required field: age'}, + 'json_schema': {'required': ['name', 'age']}, + 'execution_result': {'name': 'John'} } - ) + >>> analysis = validation_focused_analysis(state, mock_llm) + """ + try: + # Validate state using Pydantic model + validated_state = ValidationAnalysisState( + generated_code=state.get("generated_code", ""), + errors=state.get("errors", {}), + json_schema=state.get("json_schema", {}), + execution_result=state.get("execution_result", {}) + ) + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_analysis_template("validation"), + input_variables=["generated_code", "errors", "json_schema", "execution_result"], + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated state + return chain.invoke({ + "generated_code": validated_state.generated_code, + "errors": validated_state.errors["validation"], + "json_schema": validated_state.json_schema, + "execution_result": validated_state.execution_result, + }) + + except KeyError as e: + raise InvalidStateError(f"Missing required key in state dictionary: {e}") + except Exception as e: + raise AnalysisError(f"Validation analysis failed: {str(e)}") def semantic_focused_analysis( - state: dict, comparison_result: Dict[str, Any], llm_model + state: Dict[str, Any], comparison_result: Dict[str, Any], llm_model ) -> str: """ Analyzes the semantic differences in the generated code based on a comparison result. @@ -112,16 +268,48 @@ def semantic_focused_analysis( Returns: str: The result of the semantic error analysis. + + Raises: + InvalidStateError: If state or comparison_result is missing required keys. + + Example: + >>> state = { + 'generated_code': 'def add(a, b): return a + b' + } + >>> comparison_result = { + 'differences': ['Missing docstring', 'No type hints'], + 'explanation': 'The code is missing documentation' + } + >>> analysis = semantic_focused_analysis(state, comparison_result, mock_llm) """ - prompt = PromptTemplate( - template=TEMPLATE_SEMANTIC_ANALYSIS, - input_variables=["generated_code", "differences", "explanation"], - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - { - "generated_code": state["generated_code"], + try: + # Validate state using Pydantic model + validated_state = CodeAnalysisState( + generated_code=state.get("generated_code", ""), + errors=state.get("errors", {}) + ) + + # Validate comparison_result + if "differences" not in comparison_result: + raise InvalidStateError("comparison_result missing 'differences' key") + if "explanation" not in comparison_result: + raise InvalidStateError("comparison_result missing 'explanation' key") + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_analysis_template("semantic"), + input_variables=["generated_code", "differences", "explanation"], + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated inputs + return chain.invoke({ + "generated_code": validated_state.generated_code, "differences": json.dumps(comparison_result["differences"], indent=2), "explanation": comparison_result["explanation"], - } - ) + }) + + except KeyError as e: + raise InvalidStateError(f"Missing required key: {e}") + except Exception as e: + raise AnalysisError(f"Semantic analysis failed: {str(e)}") \ No newline at end of file diff --git a/scrapegraphai/utils/code_error_correction.py b/scrapegraphai/utils/code_error_correction.py index e73237ad..b3838422 100644 --- a/scrapegraphai/utils/code_error_correction.py +++ b/scrapegraphai/utils/code_error_correction.py @@ -11,7 +11,10 @@ """ import json +from typing import Any, Dict, Optional +from functools import lru_cache +from pydantic import BaseModel, Field, validator from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser @@ -23,7 +26,57 @@ ) -def syntax_focused_code_generation(state: dict, analysis: str, llm_model) -> str: +class CodeGenerationError(Exception): + """Base exception for code generation errors.""" + pass + + +class InvalidCorrectionStateError(CodeGenerationError): + """Exception raised when state dictionary is missing required keys.""" + pass + + +class CorrectionState(BaseModel): + """Base model for code correction state validation.""" + generated_code: str = Field(..., description="The original generated code to correct") + + class Config: + extra = "allow" + + +class ValidationCorrectionState(CorrectionState): + """Model for validation correction state validation.""" + json_schema: Dict[str, Any] = Field(..., description="JSON schema for validation") + + +class SemanticCorrectionState(CorrectionState): + """Model for semantic correction state validation.""" + execution_result: Any = Field(..., description="Result of code execution") + reference_answer: Any = Field(..., description="Reference answer for comparison") + + +@lru_cache(maxsize=32) +def get_optimal_correction_template(error_type: str) -> str: + """ + Returns the optimal prompt template for code correction based on the error type. + Results are cached for performance. + + Args: + error_type (str): Type of error to correct. + + Returns: + str: The prompt template text. + """ + template_registry = { + "syntax": TEMPLATE_SYNTAX_CODE_GENERATION, + "execution": TEMPLATE_EXECUTION_CODE_GENERATION, + "validation": TEMPLATE_VALIDATION_CODE_GENERATION, + "semantic": TEMPLATE_SEMANTIC_CODE_GENERATION, + } + return template_registry.get(error_type, TEMPLATE_SYNTAX_CODE_GENERATION) + + +def syntax_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str: """ Generates corrected code based on syntax error analysis. @@ -34,18 +87,46 @@ def syntax_focused_code_generation(state: dict, analysis: str, llm_model) -> str Returns: str: The corrected code. + + Raises: + InvalidCorrectionStateError: If state is missing required keys. + + Example: + >>> state = { + 'generated_code': 'print("Hello World"' + } + >>> analysis = "Missing closing parenthesis in print statement" + >>> corrected_code = syntax_focused_code_generation(state, analysis, mock_llm) """ - prompt = PromptTemplate( - template=TEMPLATE_SYNTAX_CODE_GENERATION, - input_variables=["analysis", "generated_code"], - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - {"analysis": analysis, "generated_code": state["generated_code"]} - ) + try: + # Validate state using Pydantic model + validated_state = CorrectionState( + generated_code=state.get("generated_code", "") + ) + + if not analysis or not isinstance(analysis, str): + raise InvalidCorrectionStateError("Analysis must be a non-empty string") + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_correction_template("syntax"), + input_variables=["analysis", "generated_code"], + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated state + return chain.invoke({ + "analysis": analysis, + "generated_code": validated_state.generated_code + }) + + except KeyError as e: + raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}") + except Exception as e: + raise CodeGenerationError(f"Syntax code generation failed: {str(e)}") -def execution_focused_code_generation(state: dict, analysis: str, llm_model) -> str: +def execution_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str: """ Generates corrected code based on execution error analysis. @@ -56,18 +137,46 @@ def execution_focused_code_generation(state: dict, analysis: str, llm_model) -> Returns: str: The corrected code. + + Raises: + InvalidCorrectionStateError: If state is missing required keys or analysis is invalid. + + Example: + >>> state = { + 'generated_code': 'print(x)' + } + >>> analysis = "Variable 'x' is not defined before use" + >>> corrected_code = execution_focused_code_generation(state, analysis, mock_llm) """ - prompt = PromptTemplate( - template=TEMPLATE_EXECUTION_CODE_GENERATION, - input_variables=["analysis", "generated_code"], - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - {"analysis": analysis, "generated_code": state["generated_code"]} - ) + try: + # Validate state using Pydantic model + validated_state = CorrectionState( + generated_code=state.get("generated_code", "") + ) + + if not analysis or not isinstance(analysis, str): + raise InvalidCorrectionStateError("Analysis must be a non-empty string") + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_correction_template("execution"), + input_variables=["analysis", "generated_code"], + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated state + return chain.invoke({ + "analysis": analysis, + "generated_code": validated_state.generated_code + }) + + except KeyError as e: + raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}") + except Exception as e: + raise CodeGenerationError(f"Execution code generation failed: {str(e)}") -def validation_focused_code_generation(state: dict, analysis: str, llm_model) -> str: +def validation_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str: """ Generates corrected code based on validation error analysis. @@ -78,22 +187,49 @@ def validation_focused_code_generation(state: dict, analysis: str, llm_model) -> Returns: str: The corrected code. + + Raises: + InvalidCorrectionStateError: If state is missing required keys or analysis is invalid. + + Example: + >>> state = { + 'generated_code': 'return {"name": "John"}', + 'json_schema': {'required': ['name', 'age']} + } + >>> analysis = "The output JSON is missing the required 'age' field" + >>> corrected_code = validation_focused_code_generation(state, analysis, mock_llm) """ - prompt = PromptTemplate( - template=TEMPLATE_VALIDATION_CODE_GENERATION, - input_variables=["analysis", "generated_code", "json_schema"], - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - { + try: + # Validate state using Pydantic model + validated_state = ValidationCorrectionState( + generated_code=state.get("generated_code", ""), + json_schema=state.get("json_schema", {}) + ) + + if not analysis or not isinstance(analysis, str): + raise InvalidCorrectionStateError("Analysis must be a non-empty string") + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_correction_template("validation"), + input_variables=["analysis", "generated_code", "json_schema"], + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated state + return chain.invoke({ "analysis": analysis, - "generated_code": state["generated_code"], - "json_schema": state["json_schema"], - } - ) + "generated_code": validated_state.generated_code, + "json_schema": validated_state.json_schema, + }) + + except KeyError as e: + raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}") + except Exception as e: + raise CodeGenerationError(f"Validation code generation failed: {str(e)}") -def semantic_focused_code_generation(state: dict, analysis: str, llm_model) -> str: +def semantic_focused_code_generation(state: Dict[str, Any], analysis: str, llm_model) -> str: """ Generates corrected code based on semantic error analysis. @@ -104,22 +240,51 @@ def semantic_focused_code_generation(state: dict, analysis: str, llm_model) -> s Returns: str: The corrected code. + + Raises: + InvalidCorrectionStateError: If state is missing required keys or analysis is invalid. + + Example: + >>> state = { + 'generated_code': 'def add(a, b): return a + b', + 'execution_result': {'result': 3}, + 'reference_answer': {'result': 3, 'documentation': 'Adds two numbers'} + } + >>> analysis = "The code is missing documentation" + >>> corrected_code = semantic_focused_code_generation(state, analysis, mock_llm) """ - prompt = PromptTemplate( - template=TEMPLATE_SEMANTIC_CODE_GENERATION, - input_variables=[ - "analysis", - "generated_code", - "generated_result", - "reference_result", - ], - ) - chain = prompt | llm_model | StrOutputParser() - return chain.invoke( - { + try: + # Validate state using Pydantic model + validated_state = SemanticCorrectionState( + generated_code=state.get("generated_code", ""), + execution_result=state.get("execution_result", {}), + reference_answer=state.get("reference_answer", {}) + ) + + if not analysis or not isinstance(analysis, str): + raise InvalidCorrectionStateError("Analysis must be a non-empty string") + + # Create prompt template and chain + prompt = PromptTemplate( + template=get_optimal_correction_template("semantic"), + input_variables=[ + "analysis", + "generated_code", + "generated_result", + "reference_result", + ], + ) + chain = prompt | llm_model | StrOutputParser() + + # Execute chain with validated state + return chain.invoke({ "analysis": analysis, - "generated_code": state["generated_code"], - "generated_result": json.dumps(state["execution_result"], indent=2), - "reference_result": json.dumps(state["reference_answer"], indent=2), - } - ) + "generated_code": validated_state.generated_code, + "generated_result": json.dumps(validated_state.execution_result, indent=2), + "reference_result": json.dumps(validated_state.reference_answer, indent=2), + }) + + except KeyError as e: + raise InvalidCorrectionStateError(f"Missing required key in state dictionary: {e}") + except Exception as e: + raise CodeGenerationError(f"Semantic code generation failed: {str(e)}") \ No newline at end of file diff --git a/scrapegraphai/utils/research_web.py b/scrapegraphai/utils/research_web.py index b9721306..195e11ca 100644 --- a/scrapegraphai/utils/research_web.py +++ b/scrapegraphai/utils/research_web.py @@ -1,27 +1,161 @@ """ -research_web module +research_web module for web searching across different search engines with improved +error handling, validation, and security features. """ import re -from typing import List +import random +import time +from typing import List, Dict, Union, Optional +from functools import wraps import requests from bs4 import BeautifulSoup +from pydantic import BaseModel, Field, validator from langchain_community.tools import DuckDuckGoSearchResults +class ResearchWebError(Exception): + """Base exception for research web errors.""" + pass + + +class SearchConfigError(ResearchWebError): + """Exception raised when search configuration is invalid.""" + pass + + +class SearchRequestError(ResearchWebError): + """Exception raised when search request fails.""" + pass + + +class ProxyConfig(BaseModel): + """Model for proxy configuration validation.""" + server: str = Field(..., description="Proxy server address including port") + username: Optional[str] = Field(None, description="Username for proxy authentication") + password: Optional[str] = Field(None, description="Password for proxy authentication") + + +class SearchConfig(BaseModel): + """Model for search configuration validation.""" + query: str = Field(..., description="Search query") + search_engine: str = Field("duckduckgo", description="Search engine to use") + max_results: int = Field(10, description="Maximum number of results to return") + port: Optional[int] = Field(8080, description="Port for SearXNG") + timeout: int = Field(10, description="Request timeout in seconds") + proxy: Optional[Union[str, Dict, ProxyConfig]] = Field(None, description="Proxy configuration") + serper_api_key: Optional[str] = Field(None, description="API key for Serper") + region: Optional[str] = Field(None, description="Country/region code") + language: str = Field("en", description="Language code") + + @validator('search_engine') + def validate_search_engine(cls, v): + """Validate search engine.""" + valid_engines = {"duckduckgo", "bing", "searxng", "serper"} + if v.lower() not in valid_engines: + raise ValueError(f"Search engine must be one of: {', '.join(valid_engines)}") + return v.lower() + + @validator('query') + def validate_query(cls, v): + """Validate search query.""" + if not v or not isinstance(v, str): + raise ValueError("Query must be a non-empty string") + return v + + @validator('max_results') + def validate_max_results(cls, v): + """Validate max results.""" + if v < 1 or v > 100: + raise ValueError("max_results must be between 1 and 100") + return v + + +# Define advanced PDF detection regex +PDF_REGEX = re.compile(r'\.pdf(#.*)?(\?.*)?$', re.IGNORECASE) + + +# Rate limiting decorator +def rate_limited(calls: int, period: int = 60): + """ + Decorator to limit the rate of function calls. + + Args: + calls (int): Maximum number of calls allowed in the period. + period (int): Time period in seconds. + + Returns: + Callable: Decorated function with rate limiting. + """ + min_interval = period / float(calls) + last_called = [0.0] + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + elapsed = time.time() - last_called[0] + wait_time = min_interval - elapsed + if wait_time > 0: + time.sleep(wait_time) + result = func(*args, **kwargs) + last_called[0] = time.time() + return result + return wrapper + return decorator + + +def sanitize_search_query(query: str) -> str: + """ + Sanitizes search query to prevent injection attacks. + + Args: + query (str): The search query. + + Returns: + str: Sanitized query. + """ + # Remove potential command injection characters + sanitized = re.sub(r'[;&|`$()\[\]{}<>]', '', query) + # Trim whitespace + sanitized = sanitized.strip() + return sanitized + + +# List of user agents for rotation +USER_AGENTS = [ + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0", + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36", + "Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1" +] + + +def get_random_user_agent() -> str: + """ + Returns a random user agent from the list. + + Returns: + str: Random user agent string. + """ + return random.choice(USER_AGENTS) + + +@rate_limited(calls=10, period=60) def search_on_web( query: str, search_engine: str = "duckduckgo", max_results: int = 10, port: int = 8080, timeout: int = 10, - proxy: str | dict = None, - serper_api_key: str = None, - region: str = None, + proxy: Optional[Union[str, Dict, ProxyConfig]] = None, + serper_api_key: Optional[str] = None, + region: Optional[str] = None, language: str = "en", ) -> List[str]: - """Search web function with improved error handling and validation + """ + Search web function with improved error handling, validation, and security features. Args: query (str): Search query @@ -29,147 +163,295 @@ def search_on_web( max_results (int): Maximum number of results to return port (int): Port for SearXNG timeout (int): Request timeout in seconds - proxy (str | dict): Proxy configuration + proxy (str | dict | ProxyConfig): Proxy configuration serper_api_key (str): API key for Serper region (str): Country/region code (e.g., 'mx' for Mexico) language (str): Language code (e.g., 'es' for Spanish) + + Returns: + List[str]: List of URLs from search results + + Raises: + SearchConfigError: If search configuration is invalid + SearchRequestError: If search request fails + TimeoutError: If search request times out """ - - # Input validation - if not query or not isinstance(query, str): - raise ValueError("Query must be a non-empty string") - - search_engine = search_engine.lower() - valid_engines = {"duckduckgo", "bing", "searxng", "serper"} - if search_engine not in valid_engines: - raise ValueError(f"Search engine must be one of: {', '.join(valid_engines)}") - - # Format proxy once - formatted_proxy = None - if proxy: - formatted_proxy = format_proxy(proxy) - try: + # Sanitize query for security + sanitized_query = sanitize_search_query(query) + + # Validate search configuration + config = SearchConfig( + query=sanitized_query, + search_engine=search_engine, + max_results=max_results, + port=port, + timeout=timeout, + proxy=proxy, + serper_api_key=serper_api_key, + region=region, + language=language + ) + + # Format proxy once + formatted_proxy = None + if config.proxy: + formatted_proxy = format_proxy(config.proxy) + results = [] - if search_engine == "duckduckgo": + if config.search_engine == "duckduckgo": # Create a DuckDuckGo search object with max_results - research = DuckDuckGoSearchResults(max_results=max_results) + research = DuckDuckGoSearchResults(max_results=config.max_results) # Run the search - res = research.run(query) + res = research.run(config.query) # Extract URLs using regex results = re.findall(r"https?://[^\s,\]]+", res) - elif search_engine == "bing": - results = _search_bing(query, max_results, timeout, formatted_proxy) - - elif search_engine == "searxng": - results = _search_searxng(query, max_results, port, timeout) - - elif search_engine == "serper": - results = _search_serper(query, max_results, serper_api_key, timeout) + elif config.search_engine == "bing": + results = _search_bing( + config.query, + config.max_results, + config.timeout, + formatted_proxy + ) + + elif config.search_engine == "searxng": + results = _search_searxng( + config.query, + config.max_results, + config.port, + config.timeout + ) + + elif config.search_engine == "serper": + results = _search_serper( + config.query, + config.max_results, + config.serper_api_key, + config.timeout + ) return filter_pdf_links(results) except requests.Timeout: raise TimeoutError(f"Search request timed out after {timeout} seconds") except requests.RequestException as e: - raise RuntimeError(f"Search request failed: {str(e)}") + raise SearchRequestError(f"Search request failed: {str(e)}") + except ValueError as e: + raise SearchConfigError(f"Invalid search configuration: {str(e)}") def _search_bing( - query: str, max_results: int, timeout: int, proxy: str = None + query: str, max_results: int, timeout: int, proxy: Optional[str] = None ) -> List[str]: - """Helper function for Bing search""" + """ + Helper function for Bing search with improved error handling. + + Args: + query (str): Search query + max_results (int): Maximum number of results to return + timeout (int): Request timeout in seconds + proxy (str, optional): Proxy configuration + + Returns: + List[str]: List of URLs from search results + """ headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + "User-Agent": get_random_user_agent() } - search_url = f"https://www.bing.com/search?q={query}" - + + params = { + "q": query, + "count": max_results + } + proxies = {"http": proxy, "https": proxy} if proxy else None - response = requests.get( - search_url, headers=headers, timeout=timeout, proxies=proxies - ) - response.raise_for_status() - - soup = BeautifulSoup(response.text, "html.parser") - return [ - result.find("a")["href"] - for result in soup.find_all("li", class_="b_algo", limit=max_results) - ] - - -def _search_searxng(query: str, max_results: int, port: int, timeout: int) -> List[str]: - """Helper function for SearXNG search""" - url = f"http://localhost:{port}/search" + + try: + response = requests.get( + "https://www.bing.com/search", + params=params, + headers=headers, + proxies=proxies, + timeout=timeout + ) + response.raise_for_status() + + soup = BeautifulSoup(response.text, "html.parser") + results = [] + + # Extract URLs from Bing search results + for link in soup.select("li.b_algo h2 a"): + url = link.get("href") + if url and url.startswith("http"): + results.append(url) + if len(results) >= max_results: + break + + return results + except Exception as e: + raise SearchRequestError(f"Bing search failed: {str(e)}") + + +def _search_searxng( + query: str, max_results: int, port: int, timeout: int +) -> List[str]: + """ + Helper function for SearXNG search. + + Args: + query (str): Search query + max_results (int): Maximum number of results to return + port (int): Port for SearXNG + timeout (int): Request timeout in seconds + + Returns: + List[str]: List of URLs from search results + """ + headers = { + "User-Agent": get_random_user_agent() + } + params = { "q": query, "format": "json", - "engines": "google,duckduckgo,brave,qwant,bing", + "categories": "general", + "language": "en", + "time_range": "", + "engines": "duckduckgo,bing,brave", + "results": max_results } - response = requests.get(url, params=params, timeout=timeout) - response.raise_for_status() - return [ - result["url"] for result in response.json().get("results", [])[:max_results] - ] + + try: + response = requests.get( + f"http://localhost:{port}/search", + params=params, + headers=headers, + timeout=timeout + ) + response.raise_for_status() + + json_data = response.json() + results = [result["url"] for result in json_data.get("results", [])] + return results[:max_results] + except Exception as e: + raise SearchRequestError(f"SearXNG search failed: {str(e)}") def _search_serper( - query: str, max_results: int, serper_api_key: str, timeout: int + query: str, max_results: int, api_key: str, timeout: int ) -> List[str]: - """Helper function for Serper API to get Google search results""" - if not serper_api_key: - raise ValueError("API key is required for Serper API") - - url = "https://google.serper.dev/search" - payload = {"q": query, "num": max_results} - - headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"} - + """ + Helper function for Serper search. + + Args: + query (str): Search query + max_results (int): Maximum number of results to return + api_key (str): API key for Serper + timeout (int): Request timeout in seconds + + Returns: + List[str]: List of URLs from search results + """ + if not api_key: + raise SearchConfigError("Serper API key is required") + + headers = { + "X-API-KEY": api_key, + "Content-Type": "application/json" + } + + data = { + "q": query, + "num": max_results + } + try: response = requests.post( - url, + "https://google.serper.dev/search", + json=data, headers=headers, - json=payload, # requests will handle JSON serialization - timeout=timeout, + timeout=timeout ) response.raise_for_status() - - # Extract only the organic search results - results = response.json() - organic_results = results.get("organic", []) - urls = [result.get("link") for result in organic_results if result.get("link")] - - return urls[:max_results] - - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Serper API request failed: {str(e)}") - - -def format_proxy(proxy): - if isinstance(proxy, dict): - server = proxy.get("server") - username = proxy.get("username") - password = proxy.get("password") - - if all([username, password, server]): - proxy_url = f"http://{username}:{password}@{server}" - return proxy_url - else: - raise ValueError("Proxy dictionary is missing required fields.") - elif isinstance(proxy, str): - return proxy # "https://username:password@ip:port" - else: - raise TypeError("Proxy should be a dictionary or a string.") + + json_data = response.json() + results = [] + + # Extract organic search results + for item in json_data.get("organic", []): + if "link" in item: + results.append(item["link"]) + if len(results) >= max_results: + break + + return results + except Exception as e: + raise SearchRequestError(f"Serper search failed: {str(e)}") + + +def format_proxy(proxy_config: Union[str, Dict, ProxyConfig]) -> str: + """ + Format proxy configuration into a string. + + Args: + proxy_config: Proxy configuration as string, dict, or ProxyConfig + + Returns: + str: Formatted proxy string + """ + if isinstance(proxy_config, str): + return proxy_config + + if isinstance(proxy_config, dict): + proxy_config = ProxyConfig(**proxy_config) + + # Format proxy with authentication if provided + if proxy_config.username and proxy_config.password: + auth = f"{proxy_config.username}:{proxy_config.password}@" + return f"http://{auth}{proxy_config.server}" + + return f"http://{proxy_config.server}" + + +def filter_pdf_links(urls: List[str]) -> List[str]: + """ + Filter out PDF links from search results. + + Args: + urls (List[str]): List of URLs + + Returns: + List[str]: Filtered list of URLs without PDFs + """ + return [url for url in urls if not PDF_REGEX.search(url)] -def filter_pdf_links(links: List[str]) -> List[str]: +def verify_request_signature(request_data: Dict, signature: str, secret_key: str) -> bool: """ - Filters out any links that point to PDF files. - + Verify the signature of an incoming request. + Args: - links (List[str]): A list of URLs as strings. - + request_data (Dict): Request data to verify + signature (str): Provided signature + secret_key (str): Secret key for verification + Returns: - List[str]: A list of URLs excluding any that end with '.pdf'. + bool: True if signature is valid, False otherwise """ - return [link for link in links if not link.lower().endswith(".pdf")] + import hmac + import hashlib + import json + + # Sort keys for consistent serialization + data_string = json.dumps(request_data, sort_keys=True) + + # Create HMAC signature + computed_signature = hmac.new( + secret_key.encode(), + data_string.encode(), + hashlib.sha256 + ).hexdigest() + + # Compare signatures using constant-time comparison to prevent timing attacks + return hmac.compare_digest(computed_signature, signature) \ No newline at end of file