Skip to content

Commit d3dee75

Browse files
Improve code for more clarity
1 parent ea5e5a9 commit d3dee75

File tree

2 files changed

+59
-44
lines changed

2 files changed

+59
-44
lines changed

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 55 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -400,23 +400,26 @@ def _enforce_nodes(
400400
For each valid node, filter out properties not present in the schema.
401401
Remove a node if it ends up with no valid properties.
402402
"""
403+
if self.enforce_schema != SchemaEnforcementMode.STRICT:
404+
return extracted_nodes
405+
403406
valid_nodes = []
404-
if self.enforce_schema == SchemaEnforcementMode.STRICT:
405-
for node in extracted_nodes:
406-
if node.label in schema.entities:
407-
schema_entity = schema.entities[node.label]
408-
filtered_props = self._enforce_properties(
409-
node.properties,
410-
schema_entity["properties"])
411-
if filtered_props:
412-
# keep node only if it has at least one valid property
413-
new_node = Neo4jNode(
414-
id=node.id,
415-
label=node.label,
416-
properties=filtered_props,
417-
embedding_properties=node.embedding_properties,
418-
)
419-
valid_nodes.append(new_node)
407+
408+
for node in extracted_nodes:
409+
schema_entity = schema.entities.get(node.label)
410+
if not schema_entity:
411+
continue
412+
allowed_props = schema_entity.get("properties", {})
413+
filtered_props = self._enforce_properties(node.properties, allowed_props)
414+
if filtered_props:
415+
valid_nodes.append(
416+
Neo4jNode(
417+
id=node.id,
418+
label=node.label,
419+
properties=filtered_props,
420+
embedding_properties=node.embedding_properties,
421+
)
422+
)
420423

421424
return valid_nodes
422425

@@ -433,31 +436,43 @@ def _enforce_relationships(
433436
and start/end nodes are in filtered nodes (i.e., kept after node enforcement).
434437
For each valid relationship, filter out properties not present in the schema.
435438
"""
439+
if self.enforce_schema != SchemaEnforcementMode.STRICT:
440+
return extracted_relationships
441+
436442
valid_rels = []
437-
if self.enforce_schema == SchemaEnforcementMode.STRICT:
438-
valid_nodes = {node.id: node.label for node in filtered_nodes}
439-
for rel in extracted_relationships:
440-
# keep relationship if it conforms with the schema
441-
if rel.type in schema.relations:
442-
if (rel.start_node_id in valid_nodes and
443-
rel.end_node_id in valid_nodes):
444-
start_node_label = valid_nodes[rel.start_node_id]
445-
end_node_label = valid_nodes[rel.end_node_id]
446-
if (not schema.potential_schema or
447-
(start_node_label, rel.type, end_node_label) in
448-
schema.potential_schema):
449-
schema_relation = schema.relations[rel.type]
450-
filtered_props = self._enforce_properties(
451-
rel.properties,
452-
schema_relation["properties"])
453-
new_rel = Neo4jRelationship(
454-
start_node_id=rel.start_node_id,
455-
end_node_id=rel.end_node_id,
456-
type=rel.type,
457-
properties=filtered_props,
458-
embedding_properties=rel.embedding_properties,
459-
)
460-
valid_rels.append(new_rel)
443+
444+
valid_nodes = {node.id: node.label for node in filtered_nodes}
445+
446+
potential_schema = schema.potential_schema
447+
448+
for rel in extracted_relationships:
449+
schema_relation = schema.relations.get(rel.type)
450+
if not schema_relation:
451+
continue
452+
453+
if (rel.start_node_id not in valid_nodes or
454+
rel.end_node_id not in valid_nodes):
455+
continue
456+
457+
start_label = valid_nodes[rel.start_node_id]
458+
end_label = valid_nodes[rel.end_node_id]
459+
460+
if (potential_schema and
461+
(start_label, rel.type, end_label) not in potential_schema):
462+
continue
463+
464+
allowed_props = schema_relation.get("properties", {})
465+
filtered_props = self._enforce_properties(rel.properties, allowed_props)
466+
467+
valid_rels.append(
468+
Neo4jRelationship(
469+
start_node_id=rel.start_node_id,
470+
end_node_id=rel.end_node_id,
471+
type=rel.type,
472+
properties=filtered_props,
473+
embedding_properties=rel.embedding_properties,
474+
)
475+
)
461476

462477
return valid_rels
463478

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,7 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props():
349349
create_lexical_graph=False,
350350
enforce_schema=SchemaEnforcementMode.STRICT)
351351

352-
schema = SchemaConfig(entities={"Person": {"label": "Person", "properties": []}},
352+
schema = SchemaConfig(entities={"Person": {"label": "Person"}},
353353
relations={},
354354
potential_schema=[])
355355

@@ -378,7 +378,7 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types():
378378
schema = SchemaConfig(
379379
entities={"Person": {"label": "Person",
380380
"properties": [{"name": "name", "type": "STRING"}]}},
381-
relations={"LIKES": {"label": "LIKES", "properties": []}},
381+
relations={"LIKES": {"label": "LIKES"}},
382382
potential_schema=[])
383383

384384
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
@@ -409,7 +409,7 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node()
409409
"properties": [{"name": "name", "type": "STRING"}]},
410410
"City": {"label": "City",
411411
"properties": [{"name": "name", "type": "STRING"}]}},
412-
relations={"LIVES_IN": {"label": "LIVES_IN", "properties": []}},
412+
relations={"LIVES_IN": {"label": "LIVES_IN"}},
413413
potential_schema=[("Person", "LIVES_IN", "City")])
414414

415415
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
@@ -470,7 +470,7 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes():
470470
schema = SchemaConfig(
471471
entities={"Person": {"label": "Person",
472472
"properties": [{"name": "name", "type": "STRING"}]}},
473-
relations={"LIKES": {"label": "LIKES", "properties": []}},
473+
relations={"LIKES": {"label": "LIKES"}},
474474
potential_schema=[("Person", "LIKES", "Person")])
475475

476476
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])

0 commit comments

Comments
 (0)