Skip to content

Commit 9b8b8e8

Browse files
authored
Strict mode behavior (#334)
* Strict mode: if node/relationship is in schema but no properties are defined, do not filter allowed properties * Ruff * Deal with missing relationships, add tests * Update CHANGELOG and doc
1 parent 5a7925d commit 9b8b8e8

File tree

5 files changed

+98
-10
lines changed

5 files changed

+98
-10
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44

55
### Added
66

7-
- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default.
7+
- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default.
88

99
### Fixed
1010

1111
- Fixed a bug where `spacy` and `rapidfuzz` needed to be installed even if not using the relevant entity resolvers.
1212

13+
### Changed
14+
15+
- Strict mode in `SimpleKGPipeline`: now properties and relationships are pruned only if they are defined in the input schema.
16+
1317

1418
## 1.7.0
1519

docs/source/user_guide_kg_builder.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,13 @@ Any relation whose start node or end node does not conform to the provided tuple
901901
If a relation start/end nodes are valid but the direction is incorrect, the latter will be inverted.
902902
If a node is left with no properties, it will be also pruned.
903903

904+
.. note::
905+
906+
If the input schema lacks a certain type of information, pruning is skipped.
907+
For example, if an entity is defined only by a label and has no properties,
908+
property pruning is not performed and all properties returned by the LLM are kept.
909+
910+
904911
.. warning::
905912

906913
Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied.

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,13 @@ def _enforce_nodes(
403403
schema_entity = schema.entities.get(node.label)
404404
if not schema_entity:
405405
continue
406-
allowed_props = schema_entity.get("properties", [])
407-
filtered_props = self._enforce_properties(node.properties, allowed_props)
406+
allowed_props = schema_entity.get("properties")
407+
if allowed_props:
408+
filtered_props = self._enforce_properties(
409+
node.properties, allowed_props
410+
)
411+
else:
412+
filtered_props = node.properties
408413
if filtered_props:
409414
valid_nodes.append(
410415
Neo4jNode(
@@ -434,16 +439,17 @@ def _enforce_relationships(
434439
if self.enforce_schema != SchemaEnforcementMode.STRICT:
435440
return extracted_relationships
436441

442+
if schema.relations is None:
443+
return extracted_relationships
444+
437445
valid_rels = []
438446

439447
valid_nodes = {node.id: node.label for node in filtered_nodes}
440448

441449
potential_schema = schema.potential_schema
442450

443451
for rel in extracted_relationships:
444-
schema_relation = (
445-
schema.relations.get(rel.type) if schema.relations else None
446-
)
452+
schema_relation = schema.relations.get(rel.type)
447453
if not schema_relation:
448454
continue
449455

@@ -468,8 +474,11 @@ def _enforce_relationships(
468474
if not tuple_valid and not reverse_tuple_valid:
469475
continue
470476

471-
allowed_props = schema_relation.get("properties", [])
472-
filtered_props = self._enforce_properties(rel.properties, allowed_props)
477+
allowed_props = schema_relation.get("properties")
478+
if allowed_props:
479+
filtered_props = self._enforce_properties(rel.properties, allowed_props)
480+
else:
481+
filtered_props = rel.properties
473482

474483
valid_rels.append(
475484
Neo4jRelationship(

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class SchemaConfig(DataModel):
109109
@model_validator(mode="before")
110110
def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]:
111111
entities = data.get("entities", {}).keys()
112-
relations = data.get("relations", {}).keys()
112+
relations = (data.get("relations") or {}).keys()
113113
potential_schema = data.get("potential_schema", [])
114114

115115
if potential_schema:

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> No
374374

375375
result: Neo4jGraph = await extractor.run(chunks, schema=schema)
376376

377-
assert len(result.nodes) == 0
377+
assert len(result.nodes) == 1
378378

379379

380380
@pytest.mark.asyncio
@@ -564,6 +564,74 @@ async def test_extractor_schema_enforcement_inverted_relation_direction() -> Non
564564
assert result.relationships[0].end_node_id.split(":")[1] == "2"
565565

566566

567+
@pytest.mark.asyncio
568+
async def test_extractor_schema_enforcement_none_relationships_in_schema() -> None:
569+
llm = MagicMock(spec=LLMInterface)
570+
llm.ainvoke.return_value = LLMResponse(
571+
content='{"nodes":[{"id":"1","label":"Person","properties":'
572+
'{"name":"Alice"}},{"id":"2","label":"Person","properties":'
573+
'{"name":"Bob"}}],'
574+
'"relationships":[{"start_node_id":"1","end_node_id":"2",'
575+
'"type":"FRIENDS_WITH","properties":{}}]}'
576+
)
577+
578+
extractor = LLMEntityRelationExtractor(
579+
llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT
580+
)
581+
582+
schema = SchemaConfig(
583+
entities={
584+
"Person": {
585+
"label": "Person",
586+
"properties": [{"name": "name", "type": "STRING"}],
587+
}
588+
},
589+
relations=None,
590+
potential_schema=None,
591+
)
592+
593+
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
594+
595+
result: Neo4jGraph = await extractor.run(chunks, schema=schema)
596+
597+
assert len(result.nodes) == 2
598+
assert len(result.relationships) == 1
599+
assert result.relationships[0].type == "FRIENDS_WITH"
600+
601+
602+
@pytest.mark.asyncio
603+
async def test_extractor_schema_enforcement_empty_relationships_in_schema() -> None:
604+
llm = MagicMock(spec=LLMInterface)
605+
llm.ainvoke.return_value = LLMResponse(
606+
content='{"nodes":[{"id":"1","label":"Person","properties":'
607+
'{"name":"Alice"}},{"id":"2","label":"Person","properties":'
608+
'{"name":"Bob"}}],'
609+
'"relationships":[{"start_node_id":"1","end_node_id":"2",'
610+
'"type":"FRIENDS_WITH","properties":{}}]}'
611+
)
612+
613+
extractor = LLMEntityRelationExtractor(
614+
llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT
615+
)
616+
617+
schema = SchemaConfig(
618+
entities={
619+
"Person": {
620+
"label": "Person",
621+
"properties": [{"name": "name", "type": "STRING"}],
622+
}
623+
},
624+
relations={},
625+
potential_schema=None,
626+
)
627+
628+
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
629+
630+
result: Neo4jGraph = await extractor.run(chunks, schema=schema)
631+
632+
assert len(result.relationships) == 0
633+
634+
567635
def test_fix_invalid_json_empty_result() -> None:
568636
json_string = "invalid json"
569637

0 commit comments

Comments
 (0)