Skip to content

Commit fd8d022

Browse files
Invert rel direction
1 parent 86503b4 commit fd8d022

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

docs/source/user_guide_kg_builder.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,7 @@ This behaviour can be changed by using the `enforce_schema` flag in the `LLMEnti
847847
848848
In this scenario, any extracted node/relation/property that is not part of the provided schema will be pruned.
849849
Any relation whose start node or end node does not conform to the provided tuple in `potential_schema` will be pruned.
850+
If a relation start/end nodes are valid but the direction is incorrect, the latter will be inverted.
850851
If a node is left with no properties, it will be also pruned.
851852

852853
.. warning::

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def _enforce_relationships(
435435
Keep only those whose types are in schema, start/end node conform to schema,
436436
and start/end nodes are in filtered nodes (i.e., kept after node enforcement).
437437
For each valid relationship, filter out properties not present in the schema.
438+
If a relationship direct is incorrect, invert it.
438439
"""
439440
if self.enforce_schema != SchemaEnforcementMode.STRICT:
440441
return extracted_relationships
@@ -457,17 +458,22 @@ def _enforce_relationships(
457458
start_label = valid_nodes[rel.start_node_id]
458459
end_label = valid_nodes[rel.end_node_id]
459460

460-
if (potential_schema and
461-
(start_label, rel.type, end_label) not in potential_schema):
462-
continue
461+
tuple_valid = True
462+
if potential_schema:
463+
tuple_valid = (start_label, rel.type, end_label) in potential_schema
464+
reverse_tuple_valid = ((end_label, rel.type, start_label) in
465+
potential_schema)
466+
467+
if not tuple_valid and not reverse_tuple_valid:
468+
continue
463469

464470
allowed_props = schema_relation.get("properties", [])
465471
filtered_props = self._enforce_properties(rel.properties, allowed_props)
466472

467473
valid_rels.append(
468474
Neo4jRelationship(
469-
start_node_id=rel.start_node_id,
470-
end_node_id=rel.end_node_id,
475+
start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id,
476+
end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id,
471477
type=rel.type,
472478
properties=filtered_props,
473479
embedding_properties=rel.embedding_properties,

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,38 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes():
481481
assert len(result.relationships) == 0
482482

483483

484+
@pytest.mark.asyncio
485+
async def test_extractor_schema_enforcement_inverted_relation_direction():
486+
llm = MagicMock(spec=LLMInterface)
487+
llm.ainvoke.return_value = LLMResponse(
488+
content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},'
489+
'{"id":"2","label":"City","properties":{"name":"London"}}],'
490+
'"relationships":[{"start_node_id":"2","end_node_id":"1",'
491+
'"type":"LIVES_IN","properties":{}}]}'
492+
)
493+
494+
extractor = LLMEntityRelationExtractor(llm=llm,
495+
create_lexical_graph=False,
496+
enforce_schema=SchemaEnforcementMode.STRICT)
497+
498+
schema = SchemaConfig(
499+
entities={"Person": {"label": "Person",
500+
"properties": [{"name": "name", "type": "STRING"}]},
501+
"City": {"label": "City",
502+
"properties": [{"name": "name", "type": "STRING"}]}},
503+
relations={"LIVES_IN": {"label": "LIVES_IN"}},
504+
potential_schema=[("Person", "LIVES_IN", "City")])
505+
506+
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
507+
508+
result: Neo4jGraph = await extractor.run(chunks, schema=schema)
509+
510+
assert len(result.nodes) == 2
511+
assert len(result.relationships) == 1
512+
assert result.relationships[0].start_node_id.split(":")[1] == "1"
513+
assert result.relationships[0].end_node_id.split(":")[1] == "2"
514+
515+
484516
def test_fix_invalid_json_empty_result() -> None:
485517
json_string = "invalid json"
486518

0 commit comments

Comments
 (0)