Skip to content

Commit bbf066b

Browse files
Merge branch 'main' into kg-builder-strict-mode
2 parents 1dc0c42 + c96af5c commit bbf066b

File tree

19 files changed

+1048
-275
lines changed

19 files changed

+1048
-275
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Added
66

77
- Added optional schema enforcement as a validation layer after entity and relation extraction.
8+
- Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever.
89

910
## 1.5.0
1011

@@ -310,3 +311,4 @@
310311

311312
- Updated documentation to include new custom exceptions.
312313
- Improved the use of Pydantic for input data validation for retriever objects.
314+
- Fixed config loading after module reload (usage in Jupyter notebooks).

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,13 @@ When you're finished with your changes, create a pull request (PR) using the fol
371371

372372
## 🧪 Tests
373373

374+
To be able to run all tests, all extra packages need to be installed.
375+
This is achieved by:
376+
377+
```bash
378+
poetry install --all-extras
379+
```
380+
374381
### Unit Tests
375382

376383
Install the project dependencies then run the following command to run the unit tests locally:

docs/source/api.rst

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,16 @@ Errors
445445

446446
* :class:`neo4j_graphrag.exceptions.LLMGenerationError`
447447

448+
* :class:`neo4j_graphrag.exceptions.SchemaValidationError`
449+
450+
* :class:`neo4j_graphrag.exceptions.PdfLoaderError`
451+
452+
* :class:`neo4j_graphrag.exceptions.PromptMissingPlaceholderError`
453+
454+
* :class:`neo4j_graphrag.exceptions.InvalidHybridSearchRankerError`
455+
456+
* :class:`neo4j_graphrag.exceptions.SearchQueryParseError`
457+
448458
* :class:`neo4j_graphrag.experimental.pipeline.exceptions.PipelineDefinitionError`
449459

450460
* :class:`neo4j_graphrag.experimental.pipeline.exceptions.PipelineMissingDependencyError`
@@ -559,6 +569,41 @@ LLMGenerationError
559569
:show-inheritance:
560570

561571

572+
SchemaValidationError
573+
=====================
574+
575+
.. autoclass:: neo4j_graphrag.exceptions.SchemaValidationError
576+
:show-inheritance:
577+
578+
579+
PdfLoaderError
580+
==============
581+
582+
.. autoclass:: neo4j_graphrag.exceptions.PdfLoaderError
583+
:show-inheritance:
584+
585+
586+
PromptMissingPlaceholderError
587+
=============================
588+
589+
.. autoclass:: neo4j_graphrag.exceptions.PromptMissingPlaceholderError
590+
:show-inheritance:
591+
592+
593+
InvalidHybridSearchRankerError
594+
==============================
595+
596+
.. autoclass:: neo4j_graphrag.exceptions.InvalidHybridSearchRankerError
597+
:show-inheritance:
598+
599+
600+
SearchQueryParseError
601+
=====================
602+
603+
.. autoclass:: neo4j_graphrag.exceptions.SearchQueryParseError
604+
:show-inheritance:
605+
606+
562607
PipelineDefinitionError
563608
=======================
564609

docs/source/user_guide_rag.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,10 @@ it can be queried using the following:
223223
.. code:: python
224224
225225
from neo4j_graphrag.llm import OllamaLLM
226-
llm = OllamaLLM(model_name="orca-mini")
226+
llm = OllamaLLM(
227+
model_name="orca-mini",
228+
# host="...", # when using a remote server
229+
)
227230
llm.invoke("say something")
228231
229232

examples/customize/embeddings/ollama_embeddings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
embeder = OllamaEmbeddings(
88
model="<model_name>",
9+
# host="...", # if using a remote server
910
)
1011
res = embeder.embed_query("my question")
1112
print(res[:10])

examples/customize/llms/ollama_llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
llm = OllamaLLM(
88
model_name="<model_name>",
9+
# host="...", # if using a remote server
910
)
1011
res: LLMResponse = llm.invoke("What is the additive color model?")
1112
print(res.content)

poetry.lock

Lines changed: 650 additions & 218 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ python = "^3.9.0"
3232
neo4j = "^5.17.0"
3333
pydantic = "^2.6.3"
3434
fsspec = "^2024.9.0"
35-
pypdf = "^4.3.1"
36-
json-repair = "^0.30.2"
35+
pypdf = "^5.1.0"
36+
json-repair = "^0.39.1"
3737
pyyaml = "^6.0.2"
3838
types-pyyaml = "^6.0.12.20240917"
3939
# optional deps
@@ -48,7 +48,7 @@ google-cloud-aiplatform = {version = "^1.66.0", optional = true }
4848
cohere = {version = "^5.9.0", optional = true}
4949
mistralai = {version = "^1.0.3", optional = true}
5050
qdrant-client = {version = "^1.11.3", optional = true}
51-
llama-index = {version = "^0.10.55", optional = true }
51+
llama-index = {version = "^0.12.0", optional = true }
5252
openai = {version = "^1.51.1", optional = true }
5353
anthropic = { version = "^0.36.0", optional = true}
5454
sentence-transformers = {version = "^3.0.0", optional = true }

src/neo4j_graphrag/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,7 @@ class PromptMissingPlaceholderError(Neo4jGraphRagError):
128128

129129
class InvalidHybridSearchRankerError(Neo4jGraphRagError):
130130
"""Exception raised when an invalid ranker type for Hybrid Search is provided."""
131+
132+
133+
class SearchQueryParseError(Neo4jGraphRagError):
134+
"""Exception raised when there is a query parse error in the text search string."""

src/neo4j_graphrag/experimental/pipeline/component.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pydantic import BaseModel
2222

2323
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
24+
from neo4j_graphrag.utils.validation import issubclass_safe
2425

2526

2627
class DataModel(BaseModel):
@@ -52,7 +53,7 @@ def __new__(
5253
f"The run method return type must be annotated in {name}"
5354
)
5455
# the type hint must be a subclass of DataModel
55-
if not issubclass(return_model, DataModel):
56+
if not issubclass_safe(return_model, DataModel):
5657
raise PipelineDefinitionError(
5758
f"The run method must return a subclass of DataModel in {name}"
5859
)

src/neo4j_graphrag/experimental/pipeline/config/object_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
ParamConfig,
5757
)
5858
from neo4j_graphrag.llm import LLMInterface
59+
from neo4j_graphrag.utils.validation import issubclass_safe
60+
5961

6062
logger = logging.getLogger(__name__)
6163

@@ -131,9 +133,9 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> T:
131133
self._global_data = resolved_data or {}
132134
logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}")
133135
if self.class_ is None:
134-
raise ValueError(f"`class_` is not required to parse object {self}")
136+
raise ValueError(f"`class_` is required to parse object {self}")
135137
klass = self._get_class(self.class_, self.get_module())
136-
if not issubclass(klass, self.get_interface()):
138+
if not issubclass_safe(klass, self.get_interface()):
137139
raise ValueError(
138140
f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'"
139141
)

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,11 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None:
9595
async with asyncio.Lock():
9696
current_status = await self.get_status_for_component(task_name)
9797
if status == current_status:
98-
raise PipelineStatusUpdateError(f"Status is already '{status}'")
99-
if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
100-
raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
98+
raise PipelineStatusUpdateError(f"Status is already {status}")
99+
if status not in current_status.possible_next_status():
100+
raise PipelineStatusUpdateError(
101+
f"Can't go from {current_status} to {status}"
102+
)
101103
return await self.pipeline.store.add_status_for_component(
102104
self.run_id, task_name, status.value
103105
)

src/neo4j_graphrag/experimental/pipeline/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ class RunStatus(enum.Enum):
5454
RUNNING = "RUNNING"
5555
DONE = "DONE"
5656

57+
def possible_next_status(self) -> list[RunStatus]:
    """Return the statuses that may directly follow this one.

    UNKNOWN can only move to RUNNING and RUNNING can only move to DONE.
    Every other status, DONE included, has no valid successor, so an
    empty list is returned for it.
    """
    # Transition table: current status -> statuses reachable in one step.
    allowed_transitions = {
        RunStatus.UNKNOWN: [RunStatus.RUNNING],
        RunStatus.RUNNING: [RunStatus.DONE],
    }
    return allowed_transitions.get(self, [])
65+
5766

5867
class RunResult(BaseModel):
5968
status: RunStatus = RunStatus.DONE

src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def get_search_results(
211211
for point in points:
212212
assert point.payload is not None
213213
result_tuples.append(
214-
[f"{point.payload[self.id_property_external]}", point.score]
214+
[point.payload.get(self.id_property_external, point.id), point.score]
215215
)
216216

217217
search_query = get_match_query(

src/neo4j_graphrag/retrievers/hybrid.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
EmbeddingRequiredError,
2727
RetrieverInitializationError,
2828
SearchValidationError,
29+
SearchQueryParseError,
2930
)
3031
from neo4j_graphrag.neo4j_queries import get_search_query
3132
from neo4j_graphrag.retrievers.base import Retriever
@@ -218,12 +219,19 @@ def get_search_results(
218219
logger.debug("HybridRetriever Cypher parameters: %s", sanitized_parameters)
219220
logger.debug("HybridRetriever Cypher query: %s", search_query)
220221

221-
records, _, _ = self.driver.execute_query(
222-
search_query,
223-
parameters,
224-
database_=self.neo4j_database,
225-
routing_=neo4j.RoutingControl.READ,
226-
)
222+
try:
223+
records, _, _ = self.driver.execute_query(
224+
search_query,
225+
parameters,
226+
database_=self.neo4j_database,
227+
routing_=neo4j.RoutingControl.READ,
228+
)
229+
except neo4j.exceptions.ClientError as e:
230+
if "org.apache.lucene.queryparser.classic.ParseException" in str(e):
231+
raise SearchQueryParseError(
232+
f"Invalid Lucene query generated from query_text: {query_text}"
233+
) from e
234+
raise
227235
return RawSearchResult(
228236
records=records,
229237
)
@@ -395,12 +403,19 @@ def get_search_results(
395403
logger.debug("HybridRetriever Cypher parameters: %s", sanitized_parameters)
396404
logger.debug("HybridRetriever Cypher query: %s", search_query)
397405

398-
records, _, _ = self.driver.execute_query(
399-
search_query,
400-
parameters,
401-
database_=self.neo4j_database,
402-
routing_=neo4j.RoutingControl.READ,
403-
)
406+
try:
407+
records, _, _ = self.driver.execute_query(
408+
search_query,
409+
parameters,
410+
database_=self.neo4j_database,
411+
routing_=neo4j.RoutingControl.READ,
412+
)
413+
except neo4j.exceptions.ClientError as e:
414+
if "org.apache.lucene.queryparser.classic.ParseException" in str(e):
415+
raise SearchQueryParseError(
416+
f"Invalid Lucene query generated from query_text: {query_text}"
417+
) from e
418+
raise
404419
return RawSearchResult(
405420
records=records,
406421
)

src/neo4j_graphrag/utils/validation.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,32 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Optional
17+
import importlib
18+
from typing import Optional, Tuple, Union, cast, Type
1819

1920

2021
def validate_search_query_input(
2122
query_text: Optional[str] = None, query_vector: Optional[list[float]] = None
2223
) -> None:
2324
if not (bool(query_vector) ^ bool(query_text)):
2425
raise ValueError("You must provide exactly one of query_vector or query_text.")
26+
27+
28+
def issubclass_safe(
    cls: Type[object],
    class_or_tuple: Union[Type[object], Tuple[Type[object], ...]],
) -> bool:
    """Return whether ``cls`` is a subclass of ``class_or_tuple``, tolerating module reloads.

    A plain ``issubclass`` check is tried first. If it fails, the module that
    defines ``cls`` is looked up again and searched for an attribute with the
    same name as the expected base: after ``importlib.reload`` the caller may
    still hold the *old* base class object while ``cls`` derives from the
    freshly created one.

    Args:
        cls: The class to test.
        class_or_tuple: The expected base class, or a tuple of candidate
            base classes (any match is sufficient).

    Returns:
        True if ``cls`` derives from the base (or from the same-named class
        currently exposed by ``cls``'s module), False otherwise.
    """
    if isinstance(class_or_tuple, tuple):
        return any(issubclass_safe(cls, base) for base in class_or_tuple)

    if issubclass(cls, class_or_tuple):
        return True

    # Handle case where module was reloaded: fetch the module currently
    # registered under the class's module name.
    try:
        cls_module = importlib.import_module(cls.__module__)
    except ImportError:
        # The defining module cannot be resolved, so there is no newer
        # version of the base class to compare against.
        return False

    # Get the latest version of the base class from the module. The module
    # may not expose that name at all (the base usually lives elsewhere), or
    # the name may be bound to something that is not a class; passing either
    # to issubclass() would raise TypeError instead of answering "no".
    latest_base = getattr(cls_module, class_or_tuple.__name__, None)
    if not isinstance(latest_base, type):
        return False

    return issubclass(cls, latest_base)

tests/unit/experimental/pipeline/config/test_object_config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import importlib
16+
import sys
17+
from abc import ABC
18+
from typing import ClassVar
1519
from unittest.mock import patch
1620

1721
import neo4j
@@ -58,6 +62,30 @@ def test_get_class_wrong_path() -> None:
5862
c._get_class("MyClass")
5963

6064

65+
class _MyClass:
66+
def __init__(self, param: str) -> None:
67+
self.param = param
68+
69+
70+
class _MyInterface(ABC): ...
71+
72+
73+
def test_parse_after_module_reload() -> None:
74+
class MyClassConfig(ObjectConfig[_MyClass]):
75+
DEFAULT_MODULE: ClassVar[str] = __name__
76+
INTERFACE: ClassVar[type] = _MyClass
77+
78+
param_value = "value"
79+
config = MyClassConfig.model_validate(
80+
{"class_": f"{__name__}.{_MyClass.__name__}", "params_": {"param": param_value}}
81+
)
82+
importlib.reload(sys.modules[__name__])
83+
84+
my_obj = config.parse()
85+
assert isinstance(my_obj, _MyClass)
86+
assert my_obj.param == param_value
87+
88+
6189
def test_neo4j_driver_config() -> None:
6290
config = Neo4jDriverConfig.model_validate(
6391
{

0 commit comments

Comments
 (0)