@@ -400,23 +400,26 @@ def _enforce_nodes(
400
400
For each valid node, filter out properties not present in the schema.
401
401
Remove a node if it ends up with no valid properties.
402
402
"""
403
+ if self .enforce_schema != SchemaEnforcementMode .STRICT :
404
+ return extracted_nodes
405
+
403
406
valid_nodes = []
404
- if self . enforce_schema == SchemaEnforcementMode . STRICT :
405
- for node in extracted_nodes :
406
- if node . label in schema .entities :
407
- schema_entity = schema . entities [ node . label ]
408
- filtered_props = self . _enforce_properties (
409
- node . properties ,
410
- schema_entity [ " properties" ] )
411
- if filtered_props :
412
- # keep node only if it has at least one valid property
413
- new_node = Neo4jNode (
414
- id = node .id ,
415
- label = node .label ,
416
- properties = filtered_props ,
417
- embedding_properties = node .embedding_properties ,
418
- )
419
- valid_nodes . append ( new_node )
407
+
408
+ for node in extracted_nodes :
409
+ schema_entity = schema .entities . get ( node . label )
410
+ if not schema_entity :
411
+ continue
412
+ allowed_props = schema_entity . get ( " properties" , {})
413
+ filtered_props = self . _enforce_properties ( node . properties , allowed_props )
414
+ if filtered_props :
415
+ valid_nodes . append (
416
+ Neo4jNode (
417
+ id = node .id ,
418
+ label = node .label ,
419
+ properties = filtered_props ,
420
+ embedding_properties = node .embedding_properties ,
421
+ )
422
+ )
420
423
421
424
return valid_nodes
422
425
@@ -433,31 +436,43 @@ def _enforce_relationships(
433
436
and start/end nodes are in filtered nodes (i.e., kept after node enforcement).
434
437
For each valid relationship, filter out properties not present in the schema.
435
438
"""
439
+ if self .enforce_schema != SchemaEnforcementMode .STRICT :
440
+ return extracted_relationships
441
+
436
442
valid_rels = []
437
- if self .enforce_schema == SchemaEnforcementMode .STRICT :
438
- valid_nodes = {node .id : node .label for node in filtered_nodes }
439
- for rel in extracted_relationships :
440
- # keep relationship if it conforms with the schema
441
- if rel .type in schema .relations :
442
- if (rel .start_node_id in valid_nodes and
443
- rel .end_node_id in valid_nodes ):
444
- start_node_label = valid_nodes [rel .start_node_id ]
445
- end_node_label = valid_nodes [rel .end_node_id ]
446
- if (not schema .potential_schema or
447
- (start_node_label , rel .type , end_node_label ) in
448
- schema .potential_schema ):
449
- schema_relation = schema .relations [rel .type ]
450
- filtered_props = self ._enforce_properties (
451
- rel .properties ,
452
- schema_relation ["properties" ])
453
- new_rel = Neo4jRelationship (
454
- start_node_id = rel .start_node_id ,
455
- end_node_id = rel .end_node_id ,
456
- type = rel .type ,
457
- properties = filtered_props ,
458
- embedding_properties = rel .embedding_properties ,
459
- )
460
- valid_rels .append (new_rel )
443
+
444
+ valid_nodes = {node .id : node .label for node in filtered_nodes }
445
+
446
+ potential_schema = schema .potential_schema
447
+
448
+ for rel in extracted_relationships :
449
+ schema_relation = schema .relations .get (rel .type )
450
+ if not schema_relation :
451
+ continue
452
+
453
+ if (rel .start_node_id not in valid_nodes or
454
+ rel .end_node_id not in valid_nodes ):
455
+ continue
456
+
457
+ start_label = valid_nodes [rel .start_node_id ]
458
+ end_label = valid_nodes [rel .end_node_id ]
459
+
460
+ if (potential_schema and
461
+ (start_label , rel .type , end_label ) not in potential_schema ):
462
+ continue
463
+
464
+ allowed_props = schema_relation .get ("properties" , {})
465
+ filtered_props = self ._enforce_properties (rel .properties , allowed_props )
466
+
467
+ valid_rels .append (
468
+ Neo4jRelationship (
469
+ start_node_id = rel .start_node_id ,
470
+ end_node_id = rel .end_node_id ,
471
+ type = rel .type ,
472
+ properties = filtered_props ,
473
+ embedding_properties = rel .embedding_properties ,
474
+ )
475
+ )
461
476
462
477
return valid_rels
463
478
0 commit comments