From 85e9d1fac21812474720baf965b8387b10119292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Thu, 27 Feb 2025 14:49:49 +0100 Subject: [PATCH] fix: config loading after module reload --- CHANGELOG.md | 1 + .../experimental/pipeline/component.py | 3 +- .../pipeline/config/object_config.py | 6 ++-- src/neo4j_graphrag/utils/validation.py | 23 ++++++++++++++- .../pipeline/config/test_object_config.py | 28 +++++++++++++++++++ 5 files changed, 57 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8123110bf..09d264afc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -309,3 +309,4 @@ - Updated documentation to include new custom exceptions. - Improved the use of Pydantic for input data validation for retriever objects. +- Fixed config loading after module reload (usage in jupyter notebooks) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 84cd5bc0c..90877738f 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.utils.validation import issubclass_safe class DataModel(BaseModel): @@ -52,7 +53,7 @@ def __new__( f"The run method return type must be annotated in {name}" ) # the type hint must be a subclass of DataModel - if not issubclass(return_model, DataModel): + if not issubclass_safe(return_model, DataModel): raise PipelineDefinitionError( f"The run method must return a subclass of DataModel in {name}" ) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py index deeee2003..95d69888d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py @@ -56,6 +56,8 @@ ParamConfig, ) from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.utils.validation import issubclass_safe + logger = logging.getLogger(__name__) @@ -131,9 +133,9 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> T: self._global_data = resolved_data or {} logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}") if self.class_ is None: - raise ValueError(f"`class_` is not required to parse object {self}") + raise ValueError(f"`class_` is required to parse object {self}") klass = self._get_class(self.class_, self.get_module()) - if not issubclass(klass, self.get_interface()): + if not issubclass_safe(klass, self.get_interface()): raise ValueError( f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'" ) diff --git a/src/neo4j_graphrag/utils/validation.py b/src/neo4j_graphrag/utils/validation.py index e86f7588a..f1cd4660f 100644 --- a/src/neo4j_graphrag/utils/validation.py +++ b/src/neo4j_graphrag/utils/validation.py @@ -14,7 +14,8 @@ # limitations under the License. from __future__ import annotations -from typing import Optional +import importlib +from typing import Optional, Tuple, Union, cast, Type def validate_search_query_input( @@ -22,3 +23,23 @@ def validate_search_query_input( ) -> None: if not (bool(query_vector) ^ bool(query_text)): raise ValueError("You must provide exactly one of query_vector or query_text.") + + +def issubclass_safe( + cls: Type[object], class_or_tuple: Union[Type[object], Tuple[Type[object]]] +) -> bool: + if isinstance(class_or_tuple, tuple): + return any(issubclass_safe(cls, base) for base in class_or_tuple) + + if issubclass(cls, class_or_tuple): + return True + + # Handle case where module was reloaded + cls_module = importlib.import_module(cls.__module__) + # Get the latest version of the base class from the module + latest_base = getattr(cls_module, class_or_tuple.__name__, None) + latest_base = cast(Union[tuple[Type[object], ...], Type[object]], latest_base) + if issubclass(cls, latest_base): + return True + + return False diff --git a/tests/unit/experimental/pipeline/config/test_object_config.py b/tests/unit/experimental/pipeline/config/test_object_config.py index baa3c7672..c39dcb808 100644 --- a/tests/unit/experimental/pipeline/config/test_object_config.py +++ b/tests/unit/experimental/pipeline/config/test_object_config.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import sys +from abc import ABC +from typing import ClassVar from unittest.mock import patch import neo4j @@ -58,6 +62,30 @@ def test_get_class_wrong_path() -> None: c._get_class("MyClass") +class _MyClass: + def __init__(self, param: str) -> None: + self.param = param + + +class _MyInterface(ABC): ... + + +def test_parse_after_module_reload() -> None: + class MyClassConfig(ObjectConfig[_MyClass]): + DEFAULT_MODULE: ClassVar[str] = __name__ + INTERFACE: ClassVar[type] = _MyClass + + param_value = "value" + config = MyClassConfig.model_validate( + {"class_": f"{__name__}.{_MyClass.__name__}", "params_": {"param": param_value}} + ) + importlib.reload(sys.modules[__name__]) + + my_obj = config.parse() + assert isinstance(my_obj, _MyClass) + assert my_obj.param == param_value + + def test_neo4j_driver_config() -> None: config = Neo4jDriverConfig.model_validate( {