diff --git a/scrapegraphai/graphs/csv_scraper_multi_graph.py b/scrapegraphai/graphs/csv_scraper_multi_graph.py index 59e84783..67498475 100644 --- a/scrapegraphai/graphs/csv_scraper_multi_graph.py +++ b/scrapegraphai/graphs/csv_scraper_multi_graph.py @@ -2,9 +2,10 @@ CSVScraperMultiGraph Module """ -from copy import copy, deepcopy from typing import List, Optional from pydantic import BaseModel + + from .base_graph import BaseGraph from .abstract_graph import AbstractGraph from .csv_scraper_graph import CSVScraperGraph @@ -12,6 +13,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class CSVScraperMultiGraph(AbstractGraph): """ @@ -46,10 +48,7 @@ def __init__(self, prompt: str, source: List[str], self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/json_scraper_multi_graph.py b/scrapegraphai/graphs/json_scraper_multi_graph.py index 42d2232e..c72d8afd 100644 --- a/scrapegraphai/graphs/json_scraper_multi_graph.py +++ b/scrapegraphai/graphs/json_scraper_multi_graph.py @@ -2,9 +2,10 @@ JSONScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel + from .base_graph import BaseGraph from .abstract_graph import AbstractGraph from .json_scraper_graph import JSONScraperGraph @@ -12,6 +13,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class JSONScraperMultiGraph(AbstractGraph): """ @@ -45,10 +47,7 @@ def __init__(self, prompt: str, source: List[str], config: dict, schema: Optiona self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/markdown_scraper_multi_graph.py b/scrapegraphai/graphs/markdown_scraper_multi_graph.py index 9796c11a..772eebe6 100644 --- a/scrapegraphai/graphs/markdown_scraper_multi_graph.py +++ b/scrapegraphai/graphs/markdown_scraper_multi_graph.py @@ -12,6 +12,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class MDScraperMultiGraph(AbstractGraph): """ @@ -42,11 +43,7 @@ class MDScraperMultiGraph(AbstractGraph): """ def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) - + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/omni_search_graph.py b/scrapegraphai/graphs/omni_search_graph.py index b6f6df59..c005dbac 100644 --- a/scrapegraphai/graphs/omni_search_graph.py +++ b/scrapegraphai/graphs/omni_search_graph.py @@ -2,7 +2,7 @@ OmniSearchGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import Optional from pydantic import BaseModel @@ -15,6 +15,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class OmniSearchGraph(AbstractGraph): @@ -48,10 +49,7 @@ def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/pdf_scraper_multi_graph.py b/scrapegraphai/graphs/pdf_scraper_multi_graph.py index a7386267..06da6944 100644 --- a/scrapegraphai/graphs/pdf_scraper_multi_graph.py +++ b/scrapegraphai/graphs/pdf_scraper_multi_graph.py @@ -2,7 +2,7 @@ PdfScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel from .base_graph import BaseGraph @@ -12,6 +12,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class PdfScraperMultiGraph(AbstractGraph): """ @@ -44,10 +45,7 @@ class PdfScraperMultiGraph(AbstractGraph): def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/script_creator_multi_graph.py b/scrapegraphai/graphs/script_creator_multi_graph.py index 969ba722..b2ea8465 100644 --- a/scrapegraphai/graphs/script_creator_multi_graph.py +++ b/scrapegraphai/graphs/script_creator_multi_graph.py @@ -2,7 +2,6 @@ ScriptCreatorMultiGraph Module """ -from copy import copy, deepcopy from typing import List, Optional from pydantic import BaseModel @@ -15,6 +14,7 @@ GraphIteratorNode, MergeGeneratedScriptsNode ) +from ..utils.copy import safe_deepcopy class ScriptCreatorMultiGraph(AbstractGraph): """ @@ -47,10 +47,7 @@ def __init__(self, prompt: str, source: List[str], config: dict, schema: Optiona self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) super().__init__(prompt, config, source, schema) diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index 080aaf19..d27e7186 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -2,7 +2,7 @@ SearchGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import Optional, List from pydantic import BaseModel @@ -15,6 +15,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class SearchGraph(AbstractGraph): """ @@ -47,10 +48,7 @@ class SearchGraph(AbstractGraph): def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None): self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) self.considered_urls = [] # New attribute to store URLs diff --git a/scrapegraphai/graphs/smart_scraper_multi_graph.py b/scrapegraphai/graphs/smart_scraper_multi_graph.py index 66d53851..82585cf0 100644 --- a/scrapegraphai/graphs/smart_scraper_multi_graph.py +++ b/scrapegraphai/graphs/smart_scraper_multi_graph.py @@ -2,7 +2,7 @@ SmartScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel @@ -14,6 +14,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class SmartScraperMultiGraph(AbstractGraph): """ @@ -48,10 +49,7 @@ def __init__(self, prompt: str, source: List[str], self.max_results = config.get("max_results", 3) - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/graphs/xml_scraper_multi_graph.py b/scrapegraphai/graphs/xml_scraper_multi_graph.py index 8050d50c..493d12ca 100644 --- a/scrapegraphai/graphs/xml_scraper_multi_graph.py +++ b/scrapegraphai/graphs/xml_scraper_multi_graph.py @@ -2,7 +2,7 @@ XMLScraperMultiGraph Module """ -from copy import copy, deepcopy +from copy import deepcopy from typing import List, Optional from pydantic import BaseModel @@ -14,6 +14,7 @@ GraphIteratorNode, MergeAnswersNode ) +from ..utils.copy import safe_deepcopy class XMLScraperMultiGraph(AbstractGraph): """ @@ -46,10 +47,7 @@ class XMLScraperMultiGraph(AbstractGraph): def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None): - if all(isinstance(value, str) for value in config.values()): - self.copy_config = copy(config) - else: - self.copy_config = deepcopy(config) + self.copy_config = safe_deepcopy(config) self.copy_schema = deepcopy(schema) diff --git a/scrapegraphai/utils/copy.py b/scrapegraphai/utils/copy.py new file mode 100644 index 00000000..2defbfa3 --- /dev/null +++ b/scrapegraphai/utils/copy.py @@ -0,0 +1,75 @@ +import copy +from typing import Any, Dict, Optional +from pydantic.v1 import BaseModel + +class DeepCopyError(Exception): + """Custom exception raised when an object cannot be deep-copied.""" + pass + +def safe_deepcopy(obj: Any) -> Any: + """ + Attempts to create a deep copy of the object using `copy.deepcopy` + whenever possible. If that fails, it falls back to custom deep copy + logic. If that also fails, it raises a `DeepCopyError`. + + Args: + obj (Any): The object to be copied, which can be of any type. + + Returns: + Any: A deep copy of the object if possible; otherwise, a shallow + copy if deep copying fails; if neither is possible, the original + object is returned. + Raises: + DeepCopyError: If the object cannot be deep-copied or shallow-copied. + """ + + try: + + # Try to use copy.deepcopy first + return copy.deepcopy(obj) + except (TypeError, AttributeError) as e: + # If deepcopy fails, handle specific types manually + + # Handle dictionaries + if isinstance(obj, dict): + new_obj = {} + + for k, v in obj.items(): + new_obj[k] = safe_deepcopy(v) + return new_obj + + # Handle lists + elif isinstance(obj, list): + new_obj = [] + + for v in obj: + new_obj.append(safe_deepcopy(v)) + return new_obj + + # Handle tuples (immutable, but might contain mutable objects) + elif isinstance(obj, tuple): + new_obj = tuple(safe_deepcopy(v) for v in obj) + + return new_obj + + # Handle frozensets (immutable, but might contain mutable objects) + elif isinstance(obj, frozenset): + new_obj = frozenset(safe_deepcopy(v) for v in obj) + return new_obj + + # Handle objects with attributes + elif hasattr(obj, "__dict__"): + # If an object cannot be deep copied, then the sub-properties of \ + # the object will not be analyzed and shallow copy will be used directly. + try: + return copy.copy(obj) + except (TypeError, AttributeError): + raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e + + + # Attempt shallow copy as a fallback + try: + return copy.copy(obj) + except (TypeError, AttributeError): + raise DeepCopyError(f"Cannot deep copy the object of type {type(obj)}") from e + diff --git a/tests/utils/copy_utils_test.py b/tests/utils/copy_utils_test.py new file mode 100644 index 00000000..90c85d34 --- /dev/null +++ b/tests/utils/copy_utils_test.py @@ -0,0 +1,186 @@ +import copy +import pytest + +# Assuming the custom_deepcopy function is imported or defined above this line +from scrapegraphai.utils.copy import DeepCopyError, safe_deepcopy +from pydantic.v1 import BaseModel + +class PydantObject(BaseModel): + value: int + +class NormalObject: + def __init__(self, value): + self.value = value + self.nested = [1, 2, 3] + + +class NonDeepcopyable: + def __init__(self, value): + self.value = value + + def __deepcopy__(self, memo): + raise TypeError("Forcing shallow copy fallback") + + +class WithoutDict: + __slots__ = ["value"] + + def __init__(self, value): + self.value = value + + def __deepcopy__(self, memo): + raise TypeError("Forcing shallow copy fallback") + + def __copy__(self): + return self + + +class NonCopyableObject: + __slots__ = ["value"] + + def __init__(self, value): + self.value = value + + def __deepcopy__(self, memo): + raise TypeError("fail deep copy ") + + def __copy__(self): + raise TypeError("fail shallow copy") + + +def test_deepcopy_simple_dict(): + original = {"a": 1, "b": 2, "c": [3, 4, 5]} + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + assert copy_obj["c"] is not original["c"] + + +def test_deepcopy_simple_list(): + original = [1, 2, 3, [4, 5]] + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + assert copy_obj[3] is not original[3] + + +def test_deepcopy_with_tuple(): + original = (1, 2, [3, 4]) + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + assert copy_obj[2] is not original[2] + + +def test_deepcopy_with_frozenset(): + original = frozenset([1, 2, 3, (4, 5)]) + copy_obj = safe_deepcopy(original) + assert copy_obj == original + assert copy_obj is not original + + +def test_deepcopy_with_object(): + original = NormalObject(10) + copy_obj = safe_deepcopy(original) + assert copy_obj.value == original.value + assert copy_obj is not original + assert copy_obj.nested is not original.nested + + +def test_deepcopy_with_custom_deepcopy_fallback(): + original = {"origin": NormalObject(10)} + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + assert copy_obj["origin"].value == original["origin"].value + + +def test_shallow_copy_fallback(): + original = {"origin": NonDeepcopyable(10)} + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + assert copy_obj["origin"].value == original["origin"].value + + +def test_circular_reference(): + original = [] + original.append(original) + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + assert copy_obj[0] is copy_obj + + + + +def test_deepcopy_object_without_dict(): + original = {"origin": WithoutDict(10)} + copy_obj = safe_deepcopy(original) + assert copy_obj["origin"].value == original["origin"].value + assert copy_obj is not original + assert copy_obj["origin"] is original["origin"] + assert ( + hasattr(copy_obj["origin"], "__dict__") is False + ) # Ensure __dict__ is not present + + original = [WithoutDict(10)] + copy_obj = safe_deepcopy(original) + assert copy_obj[0].value == original[0].value + assert copy_obj is not original + assert copy_obj[0] is original[0] + + original = (WithoutDict(10),) + copy_obj = safe_deepcopy(original) + assert copy_obj[0].value == original[0].value + assert copy_obj is not original + assert copy_obj[0] is original[0] + + original_item = WithoutDict(10) + original = set([original_item]) + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + copy_obj_item = copy_obj.pop() + assert copy_obj_item.value == original_item.value + assert copy_obj_item is original_item + + original_item = WithoutDict(10) + original = frozenset([original_item]) + copy_obj = safe_deepcopy(original) + assert copy_obj is not original + copy_obj_item = list(copy_obj)[0] + assert copy_obj_item.value == original_item.value + assert copy_obj_item is original_item + +def test_unhandled_type(): + with pytest.raises(DeepCopyError): + original = {"origin": NonCopyableObject(10)} + copy_obj = safe_deepcopy(original) + +def test_client(): + llm_instance_config = { + "model": "moonshot-v1-8k", + "base_url": "https://api.moonshot.cn/v1", + "moonshot_api_key": "xxx", + } + + from langchain_community.chat_models.moonshot import MoonshotChat + + llm_model_instance = MoonshotChat(**llm_instance_config) + copy_obj = safe_deepcopy(llm_model_instance) + + assert copy_obj + assert hasattr(copy_obj, 'callbacks') + +def test_circular_reference_in_dict(): + original = {} + original['self'] = original # Create a circular reference + copy_obj = safe_deepcopy(original) + + # Check that the copy is a different object + assert copy_obj is not original + # Check that the circular reference is maintained in the copy + assert copy_obj['self'] is copy_obj + +def test_with_pydantic(): + original = PydantObject(value=1) + copy_obj = safe_deepcopy(original) + assert copy_obj.value == original.value + assert copy_obj is not original