Skip to content

Commit 6170e61

Browse files
Code cleanups
1 parent 4b32e45 commit 6170e61

File tree

3 files changed

+54
-53
lines changed

3 files changed

+54
-53
lines changed

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 16 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -405,8 +405,9 @@ def _enforce_nodes(
405405
for node in extracted_nodes:
406406
if node.label in schema.entities:
407407
schema_entity = schema.entities[node.label]
408-
filtered_props = self._enforce_properties(node.properties,
409-
schema_entity)
408+
filtered_props = self._enforce_properties(
409+
node.properties,
410+
schema_entity["properties"])
410411
if filtered_props:
411412
# keep node only if it has at least one valid property
412413
new_node = Neo4jNode(
@@ -416,8 +417,7 @@ def _enforce_nodes(
416417
embedding_properties=node.embedding_properties,
417418
)
418419
valid_nodes.append(new_node)
419-
# elif self.enforce_schema == SchemaEnforcementMode.OPEN:
420-
# future logic
420+
421421
return valid_nodes
422422

423423
def _enforce_relationships(
@@ -435,22 +435,21 @@ def _enforce_relationships(
435435
"""
436436
valid_rels = []
437437
if self.enforce_schema == SchemaEnforcementMode.STRICT:
438-
valid_node_ids = {node.id for node in filtered_nodes}
438+
valid_nodes = {node.id: node.label for node in filtered_nodes}
439439
for rel in extracted_relationships:
440440
# keep relationship if it conforms with the schema
441441
if rel.type in schema.relations:
442-
if (rel.start_node_id in valid_node_ids and
443-
rel.end_node_id in valid_node_ids):
444-
start_node_label = self._get_node_label(rel.start_node_id,
445-
filtered_nodes)
446-
end_node_label = self._get_node_label(rel.end_node_id,
447-
filtered_nodes)
442+
if (rel.start_node_id in valid_nodes and
443+
rel.end_node_id in valid_nodes):
444+
start_node_label = valid_nodes[rel.start_node_id]
445+
end_node_label = valid_nodes[rel.end_node_id]
448446
if (not schema.potential_schema or
449447
(start_node_label, rel.type, end_node_label) in
450448
schema.potential_schema):
451449
schema_relation = schema.relations[rel.type]
452-
filtered_props = self._enforce_properties(rel.properties,
453-
schema_relation)
450+
filtered_props = self._enforce_properties(
451+
rel.properties,
452+
schema_relation["properties"])
454453
new_rel = Neo4jRelationship(
455454
start_node_id=rel.start_node_id,
456455
end_node_id=rel.end_node_id,
@@ -459,35 +458,22 @@ def _enforce_relationships(
459458
embedding_properties=rel.embedding_properties,
460459
)
461460
valid_rels.append(new_rel)
462-
# elif self.enforce_schema == SchemaEnforcementMode.OPEN:
463-
# future logic
461+
464462
return valid_rels
465463

466464
def _enforce_properties(
467465
self,
468466
properties: Dict[str, Any],
469-
valid_properties: Dict[str, Any]
467+
valid_properties: List[Dict[str, Any]]
470468
) -> Dict[str, Any]:
471469
"""
472470
Filter properties.
473471
Keep only those that exist in schema (i.e., valid properties).
474472
"""
473+
valid_prop_names = {prop["name"] for prop in valid_properties}
475474
return {
476475
key: value
477476
for key, value in properties.items()
478-
if key in valid_properties
477+
if key in valid_prop_names
479478
}
480479

481-
def _get_node_label(
482-
self,
483-
node_id: str,
484-
nodes: List[Neo4jNode]
485-
) -> str:
486-
"""
487-
Given a list of nodes, get the label of the node whose id matches the provided
488-
node id.
489-
"""
490-
for node in nodes:
491-
if node.id == node_id:
492-
return node.label
493-
return ""

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -176,5 +176,3 @@ class GraphResult(DataModel):
176176
class SchemaEnforcementMode(str, Enum):
177177
NONE = "none"
178178
STRICT = "strict"
179-
# future possibility: OPEN = "open" -> ensure conformance of nodes/props/rels that
180-
# were listed in the schema but leave room for extras

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 38 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -243,9 +243,11 @@ async def test_extractor_no_schema_enforcement() -> None:
243243
create_lexical_graph=False,
244244
enforce_schema=SchemaEnforcementMode.NONE)
245245

246-
schema = SchemaConfig(entities={"Person": {"name": "STRING"}},
247-
relations={},
248-
potential_schema=[])
246+
schema = SchemaConfig(
247+
entities={"Person": {"label": "Person",
248+
"properties": [{"name": "name", "type": "STRING"}]}},
249+
relations={},
250+
potential_schema=[])
249251

250252
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
251253

@@ -290,9 +292,11 @@ async def test_extractor_schema_enforcement_invalid_nodes():
290292
create_lexical_graph=False,
291293
enforce_schema=SchemaEnforcementMode.STRICT)
292294

293-
schema = SchemaConfig(entities={"Person": {"name": "STRING"}},
294-
relations={},
295-
potential_schema=[])
295+
schema = SchemaConfig(
296+
entities={"Person": {"label": "Person",
297+
"properties": [{"name": "name", "type": "STRING"}]}},
298+
relations={},
299+
potential_schema=[])
296300

297301
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
298302

@@ -316,9 +320,12 @@ async def test_extraction_schema_enforcement_invalid_node_properties():
316320
create_lexical_graph=False,
317321
enforce_schema=SchemaEnforcementMode.STRICT)
318322

319-
schema = SchemaConfig(entities={"Person": {"name": str, "age": int}},
320-
relations={},
321-
potential_schema=[])
323+
schema = SchemaConfig(
324+
entities={"Person": {"label": "Person",
325+
"properties": [{"name": "name", "type": "STRING"},
326+
{"name": "age", "type": "INTEGER"}]}},
327+
relations={},
328+
potential_schema=[])
322329

323330
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
324331

@@ -342,7 +349,7 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props():
342349
create_lexical_graph=False,
343350
enforce_schema=SchemaEnforcementMode.STRICT)
344351

345-
schema = SchemaConfig(entities={"Person": {}},
352+
schema = SchemaConfig(entities={"Person": {"label": "Person", "properties": []}},
346353
relations={},
347354
potential_schema=[])
348355

@@ -368,9 +375,11 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types():
368375
create_lexical_graph=False,
369376
enforce_schema=SchemaEnforcementMode.STRICT)
370377

371-
schema = SchemaConfig(entities={"Person": {"name": str}},
372-
relations={"LIKES": {}},
373-
potential_schema=[])
378+
schema = SchemaConfig(
379+
entities={"Person": {"label": "Person",
380+
"properties": [{"name": "name", "type": "STRING"}]}},
381+
relations={"LIKES": {"label": "LIKES", "properties": []}},
382+
potential_schema=[])
374383

375384
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
376385

@@ -395,9 +404,13 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node()
395404
create_lexical_graph=False,
396405
enforce_schema=SchemaEnforcementMode.STRICT)
397406

398-
schema = SchemaConfig(entities={"Person": {"name": str}, "City": {"name": str}},
399-
relations={"LIVES_IN": {}},
400-
potential_schema=[("Person", "LIVES_IN", "City")])
407+
schema = SchemaConfig(
408+
entities={"Person": {"label": "Person",
409+
"properties": [{"name": "name", "type": "STRING"}]},
410+
"City": {"label": "City",
411+
"properties": [{"name": "name", "type": "STRING"}]}},
412+
relations={"LIVES_IN": {"label": "LIVES_IN", "properties": []}},
413+
potential_schema=[("Person", "LIVES_IN", "City")])
401414

402415
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
403416

@@ -422,8 +435,10 @@ async def test_extractor_schema_enforcement_invalid_relation_properties():
422435
enforce_schema=SchemaEnforcementMode.STRICT)
423436

424437
schema = SchemaConfig(
425-
entities={"Person": {"name": str}},
426-
relations={"LIKES": {"strength": str}},
438+
entities={"Person": {"label": "Person",
439+
"properties": [{"name": "name", "type": "STRING"}]}},
440+
relations={"LIKES": {"label": "LIKES",
441+
"properties": [{"name": "strength", "type": "STRING"}]}},
427442
potential_schema=[]
428443
)
429444

@@ -452,9 +467,11 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes():
452467
create_lexical_graph=False,
453468
enforce_schema=SchemaEnforcementMode.STRICT)
454469

455-
schema = SchemaConfig(entities={"Person": {"name": str}},
456-
relations={"LIKES": {}},
457-
potential_schema=[("Person", "LIKES", "Person")])
470+
schema = SchemaConfig(
471+
entities={"Person": {"label": "Person",
472+
"properties": [{"name": "name", "type": "STRING"}]}},
473+
relations={"LIKES": {"label": "LIKES", "properties": []}},
474+
potential_schema=[("Person", "LIKES", "Person")])
458475

459476
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
460477

0 commit comments

Comments
 (0)