Skip to content

Commit 6820700

Browse files
committed
fix: config loading after module reload
1 parent 75ef7e4 commit 6820700

File tree

5 files changed

+56
-4
lines changed

5 files changed

+56
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,4 @@
306306

307307
- Updated documentation to include new custom exceptions.
308308
- Improved the use of Pydantic for input data validation for retriever objects.
309+
- Fixed config loading after module reload (usage in jupyter notebooks)

src/neo4j_graphrag/experimental/pipeline/component.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pydantic import BaseModel
2222

2323
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
24+
from neo4j_graphrag.utils.validation import issubclass_safe
2425

2526

2627
class DataModel(BaseModel):
@@ -52,7 +53,7 @@ def __new__(
5253
f"The run method return type must be annotated in {name}"
5354
)
5455
# the type hint must be a subclass of DataModel
55-
if not issubclass(return_model, DataModel):
56+
if not issubclass_safe(return_model, DataModel):
5657
raise PipelineDefinitionError(
5758
f"The run method must return a subclass of DataModel in {name}"
5859
)

src/neo4j_graphrag/experimental/pipeline/config/object_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
ParamConfig,
5757
)
5858
from neo4j_graphrag.llm import LLMInterface
59+
from neo4j_graphrag.utils.validation import issubclass_safe
60+
5961

6062
logger = logging.getLogger(__name__)
6163

@@ -131,9 +133,9 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> T:
131133
self._global_data = resolved_data or {}
132134
logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}")
133135
if self.class_ is None:
134-
raise ValueError(f"`class_` is not required to parse object {self}")
136+
raise ValueError(f"`class_` is required to parse object {self}")
135137
klass = self._get_class(self.class_, self.get_module())
136-
if not issubclass(klass, self.get_interface()):
138+
if not issubclass_safe(klass, self.get_interface()):
137139
raise ValueError(
138140
f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'"
139141
)

src/neo4j_graphrag/utils/validation.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,29 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Optional
17+
import importlib
18+
from typing import Optional, Tuple, Type, Union
1819

1920

2021
def validate_search_query_input(
2122
query_text: Optional[str] = None, query_vector: Optional[list[float]] = None
2223
) -> None:
2324
if not (bool(query_vector) ^ bool(query_text)):
2425
raise ValueError("You must provide exactly one of query_vector or query_text.")
26+
27+
28+
def issubclass_safe(cls: Type, class_or_tuple: Union[Type, Tuple[Type]]) -> bool:
29+
if isinstance(class_or_tuple, tuple):
30+
return any(issubclass_safe(cls, base) for base in class_or_tuple)
31+
32+
if issubclass(cls, class_or_tuple):
33+
return True
34+
35+
# Handle case where module was reloaded
36+
cls_module = importlib.import_module(cls.__module__)
37+
# Get the latest version of the base class from the module
38+
latest_base = getattr(cls_module, class_or_tuple.__name__, None)
39+
if issubclass(cls, latest_base):
40+
return True
41+
42+
return False

tests/unit/experimental/pipeline/config/test_object_config.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import importlib
16+
import sys
17+
from abc import ABC
18+
from typing import ClassVar
1519
from unittest.mock import patch
1620

1721
import neo4j
@@ -58,6 +62,32 @@ def test_get_class_wrong_path() -> None:
5862
c._get_class("MyClass")
5963

6064

65+
class _MyClass:
66+
def __init__(self, param: str) -> None:
67+
self.param = param
68+
69+
70+
class _MyInterface(ABC): ...
71+
72+
73+
def test_parse_after_module_reload() -> None:
74+
class MyClassConfig(ObjectConfig[_MyClass]):
75+
DEFAULT_MODULE: ClassVar[str] = __name__
76+
INTERFACE: ClassVar[type] = _MyClass
77+
78+
param_value = "value"
79+
config = {
80+
"class_": f"{__name__}.{_MyClass.__name__}",
81+
"params_": {"param": param_value},
82+
}
83+
config = MyClassConfig(**config)
84+
importlib.reload(sys.modules[__name__])
85+
86+
my_obj = config.parse()
87+
assert isinstance(my_obj, _MyClass)
88+
assert my_obj.param == param_value
89+
90+
6191
def test_neo4j_driver_config() -> None:
6292
config = Neo4jDriverConfig.model_validate(
6393
{

0 commit comments

Comments
 (0)