diff --git a/docs/source/api.rst b/docs/source/api.rst index 55a5d1cc4..7ef687275 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -77,6 +77,12 @@ SchemaBuilder .. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaBuilder :members: run +ConstraintProcessor +------------------- + +.. autoclass:: neo4j_graphrag.experimental.components.schema.ConstraintProcessor + :members: _process_constraints_against_schema + SchemaFromTextExtractor ----------------------- @@ -475,6 +481,8 @@ Errors * :class:`neo4j_graphrag.exceptions.SchemaValidationError` + * :class:`neo4j_graphrag.exceptions.SchemaDatabaseConflictError` + * :class:`neo4j_graphrag.exceptions.PdfLoaderError` * :class:`neo4j_graphrag.exceptions.PromptMissingPlaceholderError` @@ -604,6 +612,13 @@ SchemaValidationError :show-inheritance: +SchemaDatabaseConflictError +========================= + +.. autoclass:: neo4j_graphrag.exceptions.SchemaDatabaseConflictError + :show-inheritance: + + PdfLoaderError ============== diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 0a299a0a3..be63e968d 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -747,6 +747,46 @@ Optionally, the document and chunk node labels can be configured using a `Lexica Schema Builder ============== +Schema Components Overview +-------------------------- + +The Neo4j GraphRAG library provides two main components for working with graph schemas: + +- **SchemaBuilder**: For building schemas from manually defined node and relationship types +- **SchemaFromTextExtractor**: For automatically extracting schemas from text using LLMs + +Both components now share a unified constraint processing architecture that ensures +consistent behavior when working with Neo4j database constraints. + +**Unified Architecture Benefits:** + +- **Consistency**: Same validation and enhancement logic across all schema components +- **Flexibility**: Choose between strict validation or automatic enhancement +- **Maintainability**: Single source of truth for constraint processing +- **Reliability**: Comprehensive testing ensures robust constraint handling + +**Processing Modes:** + +.. list-table:: Schema Processing Modes + :header-rows: 1 + :widths: 30 35 35 + + * - Mode + - Behavior + - Best For + * - **Validation** (SchemaBuilder default) + - Raises errors for constraint conflicts + - Production schemas, explicit control + * - **Enhancement** (SchemaBuilder optional) + - Automatically resolves constraint conflicts + - Development, flexible user schemas + * - **Enhancement** (SchemaFromTextExtractor always) + - Automatically resolves constraint conflicts + - LLM-generated schemas + +Schema Definition +----------------- + The schema is used to try and ground the LLM to a list of possible node and relationship types of interest. 
So far, schema must be manually created by specifying: @@ -772,16 +812,16 @@ Here is a code block illustrating these concepts: NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), ], @@ -814,6 +854,7 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto # Instantiate the automatic schema extractor component schema_extractor = SchemaFromTextExtractor( + driver=neo4j_driver, # Add driver for constraint processing llm=OpenAILLM( model_name="gpt-4o", model_params={ @@ -841,6 +882,239 @@ You can also save and reload the extracted schema: restored_schema = GraphSchema.from_file("my_schema.json") # or my_schema.yaml +Schema-Database Constraint Validation +===================================== + +When using `SchemaBuilder` with an existing Neo4j database that contains constraints, +the component automatically processes your schema against database constraints. The +component supports two modes of operation: + +- **Validation Mode** (default): Validates schema compatibility and raises errors for conflicts +- **Enhancement Mode**: Automatically modifies schema to resolve conflicts + +.. note:: + + Both `SchemaBuilder` and `SchemaFromTextExtractor` now share the same underlying + constraint processing logic, ensuring consistent behavior across components. + +SchemaBuilder Modes +------------------- + +**Validation Mode (enhancement_mode=False)** + +This is the default mode that validates your user-defined schema against database +constraints and raises explicit errors if conflicts are detected: + +.. code:: python + + from neo4j_graphrag.experimental.components.schema import SchemaBuilder + + # Default validation mode + schema_builder = SchemaBuilder(driver=neo4j_driver, enhancement_mode=False) + + # Will raise SchemaDatabaseConflictError if conflicts are found + schema = await schema_builder.run(node_types=[...]) + +**Enhancement Mode (enhancement_mode=True)** + +This mode automatically modifies your schema to resolve conflicts with database constraints: + +.. code:: python + + from neo4j_graphrag.experimental.components.schema import SchemaBuilder + + # Enhancement mode - automatically fixes conflicts + schema_builder = SchemaBuilder(driver=neo4j_driver, enhancement_mode=True) + + # Will automatically add missing properties and adjust types + schema = await schema_builder.run(node_types=[...]) + +**When to Use Each Mode:** + +- **Validation Mode**: Use when you want explicit control over your schema and prefer + to manually resolve conflicts. Ideal for production schemas where changes should be deliberate. +- **Enhancement Mode**: Use during development or when you want automatic compatibility + with database constraints. Similar to `SchemaFromTextExtractor` behavior. + +Validation Rules +---------------- + +Both modes validate your schema against the following types of database constraints: + +**1. 
Missing Property Conflicts** + +If your schema defines an entity type (node or relationship) but omits properties that +are required by database existence constraints: + +.. code:: python + + # Database has constraint: CREATE CONSTRAINT FOR (p:Person) REQUIRE p.email IS NOT NULL + # But your schema doesn't include the 'email' property + + schema_builder = SchemaBuilder(driver=neo4j_driver, enhancement_mode=False) + + # Validation mode: This will raise SchemaDatabaseConflictError + await schema_builder.run( + node_types=[ + NodeType( + label="Person", + properties=[ + PropertyType(name="name", type="STRING") + # Missing required 'email' property! + ] + ) + ] + ) + + # Enhancement mode: This will automatically add the 'email' property + enhancement_builder = SchemaBuilder(driver=neo4j_driver, enhancement_mode=True) + enhanced_schema = await enhancement_builder.run(node_types=[...]) + + # The 'email' property will be automatically added as required + person_node = enhanced_schema.node_type_from_label("Person") + email_prop = person_node.get_property_by_name("email") + assert email_prop.required == True + +**Error Resolution (Validation Mode):** Add the missing properties to your schema or remove the database constraint. + +**2. Property Type Conflicts** + +If your schema defines property types that conflict with database type constraints: + +.. code:: python + + # Database has constraint: CREATE CONSTRAINT FOR (p:Person) REQUIRE p.age IS :: INTEGER + # But your schema defines 'age' as STRING + + # Validation mode: This will raise SchemaDatabaseConflictError + await schema_builder.run( + node_types=[ + NodeType( + label="Person", + properties=[ + PropertyType(name="age", type="STRING") # Conflicts with INTEGER constraint + ] + ) + ] + ) + + # Enhancement mode: This will automatically update the type to INTEGER + enhanced_schema = await enhancement_builder.run(node_types=[...]) + person_node = enhanced_schema.node_type_from_label("Person") + age_prop = person_node.get_property_by_name("age") + assert age_prop.type == "INTEGER" # Automatically corrected + +**Error Resolution (Validation Mode):** Update property types to match database constraints or remove the database constraint. + +**3. Missing Entity Type Conflicts** + +If database constraints reference entity types not defined in your schema and you've +disabled additional types: + +.. code:: python + + # Database has constraints on 'Company' nodes + # But your schema doesn't include Company and additional_node_types=False + + # Validation mode: This will raise SchemaDatabaseConflictError + await schema_builder.run( + node_types=[NodeType(label="Person")], + additional_node_types=False # Strict mode + ) + + # Enhancement mode: This will automatically add the Company node type + enhanced_schema = await enhancement_builder.run( + node_types=[NodeType(label="Person")], + additional_node_types=True # Allow automatic additions + ) + company_node = enhanced_schema.node_type_from_label("Company") + assert company_node is not None # Automatically added + +**Error Resolution (Validation Mode):** Add the missing entity types to your schema or set ``additional_node_types=True``. + +**4. Additional Properties Conflicts** + +If your entity has ``additional_properties=False`` but database constraints require +properties not in your schema: + +.. 
code:: python + + # Database requires 'email' property via existence constraint + # But your schema has additional_properties=False and doesn't include 'email' + + # Validation mode: This will raise SchemaDatabaseConflictError + await schema_builder.run( + node_types=[ + NodeType( + label="Person", + properties=[PropertyType(name="name", type="STRING")], + additional_properties=False # Strict mode, but missing required 'email' + ) + ] + ) + + # Enhancement mode: Respects additional_properties=False, won't add properties + # Use additional_properties=True to allow automatic property additions + +**Error Resolution (Validation Mode):** Add missing properties to your schema or set ``additional_properties=True``. + +Schema Enhancement Behavior +--------------------------- + +In enhancement mode, the `SchemaBuilder` performs the following automatic modifications: + +1. **Adds Missing Properties**: Creates properties required by database constraints +2. **Updates Property Types**: Adjusts types to match database type constraints +3. **Sets Required Properties**: Marks properties as required for existence constraints +4. **Adds Missing Entity Types**: Creates missing node/relationship types (if allowed) +5. **Respects User Constraints**: Won't add properties if ``additional_properties=False`` + +All enhanced properties include descriptive text indicating they were added due to database constraints. + +Error Handling +-------------- + +In validation mode, all constraint conflicts raise ``SchemaDatabaseConflictError`` with detailed error +messages explaining the conflict and suggesting resolutions: + +.. code:: python + + from neo4j_graphrag.exceptions import SchemaDatabaseConflictError + + try: + schema = await schema_builder.run(node_types=[...]) + except SchemaDatabaseConflictError as e: + print(f"Schema conflict detected: {e}") + # Error message will indicate exactly which properties or types are missing + # and provide suggestions for resolution + +Best Practices +-------------- + +1. **Choose the Right Mode:** + - Use validation mode for production schemas where explicit control is important + - Use enhancement mode for development or when you want automatic compatibility + +2. **Review Database Constraints:** Before defining your schema, review existing + database constraints using: + + .. code:: cypher + + SHOW CONSTRAINTS + +3. **Start Permissive:** Begin with ``additional_node_types=True`` and + ``additional_properties=True`` to allow flexibility during development. + +4. **Iterative Refinement:** In validation mode, use the error messages to iteratively + refine your schema until it's compatible with database constraints. + +5. **Constraint Alignment:** Ensure your schema property types match database + type constraints to avoid conflicts. + +6. **Required Properties:** Include all properties referenced by database existence + constraints in your schema definitions. + + Entity and Relation Extractor ============================= @@ -1169,3 +1443,119 @@ previously created document node: .. code:: python filter_query = "WHERE NOT EXISTS((entity)-[:FROM_DOCUMENT]->(:OldDocument))" + +.. note:: + + The `SchemaBuilder` validates user schemas against database constraints and raises + explicit errors for conflicts. For automatic schema generation from text, see + `SchemaFromTextExtractor Enhancement`_ below. + +**Error Examples:** + +.. 
code:: python + + # This will raise SchemaDatabaseConflictError with detailed message: + # "Database constraint NODE_PROPERTY_EXISTENCE on Person requires properties + # ['email'] that are not defined in user schema. Please add these properties + # to your Person definition or remove the constraint from the database." + + +SchemaFromTextExtractor Enhancement +=================================== + +The `SchemaFromTextExtractor` component automatically **enhances** LLM-generated schemas +to match database constraints. It uses the same underlying constraint processing logic +as `SchemaBuilder` in enhancement mode, but is specifically designed for schemas extracted +from text by Large Language Models. + +.. note:: + + `SchemaFromTextExtractor` always operates in enhancement mode - it never raises + errors for constraint conflicts. Instead, it automatically modifies the schema + to ensure database compatibility. + +**Shared Enhancement Logic:** + +Both `SchemaFromTextExtractor` and `SchemaBuilder` (in enhancement mode) now use the +same constraint processing engine, ensuring consistent behavior: + +1. **Adds Missing Properties**: If database constraints require properties not generated by the LLM +2. **Updates Property Types**: Adjusts property types to match database type constraints +3. **Sets Required Properties**: Marks properties as required when database has existence constraints +4. **Adds Missing Entity Types**: Creates entity types required by constraints (if `additional_*_types=True`) +5. **Respects User Constraints**: Won't add properties if `additional_properties=False` +6. **Graceful Failure**: Returns original schema if enhancement fails + +**Example:** + +.. code:: python + + from neo4j_graphrag.experimental.components.schema import SchemaFromTextExtractor + from your_llm_provider import YourLLM + + # Database has constraint: CREATE CONSTRAINT FOR (p:Person) REQUIRE p.email IS NOT NULL + + extractor = SchemaFromTextExtractor( + driver=neo4j_driver, + llm=YourLLM() + ) + + # LLM generates basic schema from text + text = "John works at Acme Corp. His email is john@acme.com" + + # Enhancement automatically adds missing 'email' property as required + enhanced_schema = await extractor.run(text) + + person_node = enhanced_schema.node_type_from_label("Person") + email_prop = person_node.get_property_by_name("email") + + assert email_prop.required == True # Enhanced due to database constraint + assert "constraint" in email_prop.description.lower() + +**Component Comparison:** + +.. list-table:: Schema Processing Approaches + :header-rows: 1 + :widths: 25 35 40 + + * - Component + - Default Behavior + - Use Case + * - ``SchemaBuilder`` (validation mode) + - Validates user schemas → Raises errors for conflicts + - Production schemas requiring explicit control + * - ``SchemaBuilder`` (enhancement mode) + - Enhances user schemas → Modifies automatically + - Development or flexible user schemas + * - ``SchemaFromTextExtractor`` + - Enhances LLM schemas → Always modifies automatically + - LLM-generated schemas needing database compatibility + +**Configuration Options:** + +The `SchemaFromTextExtractor` respects the same configuration options as `SchemaBuilder`: + +.. 
code:: python + + # Control what types of enhancements are allowed + enhanced_schema = await extractor.run( + text="Your text here", + additional_node_types=True, # Allow adding missing node types + additional_relationship_types=True, # Allow adding missing relationship types + ) + + # For entities with additional_properties=False, no properties will be added + # Use additional_properties=True (default) to allow property additions + +**Architecture Benefits:** + +The shared constraint processing architecture provides: + +- **Consistency**: Same validation logic across all schema components +- **Maintainability**: Single source of truth for constraint handling +- **Flexibility**: Easy to switch between validation and enhancement modes +- **Reliability**: Comprehensive testing covers all constraint scenarios + +This unified approach ensures that whether you're using manually defined schemas with +`SchemaBuilder` or automatically extracted schemas with `SchemaFromTextExtractor`, +you get consistent and reliable constraint processing behavior. diff --git a/examples/customize/build_graph/components/pruners/graph_pruner.py b/examples/customize/build_graph/components/pruners/graph_pruner.py index adf8694a1..40a0f879c 100644 --- a/examples/customize/build_graph/components/pruners/graph_pruner.py +++ b/examples/customize/build_graph/components/pruners/graph_pruner.py @@ -3,16 +3,15 @@ import asyncio from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning -from neo4j_graphrag.experimental.components.schema import ( +from neo4j_graphrag.experimental.components.types import ( GraphSchema, NodeType, PropertyType, RelationshipType, -) -from neo4j_graphrag.experimental.components.types import ( Neo4jGraph, Neo4jNode, Neo4jRelationship, + Neo4jPropertyType, ) graph = Neo4jGraph( @@ -68,24 +67,30 @@ NodeType( label="Person", properties=[ - PropertyType(name="firstName", type="STRING", required=True), - PropertyType(name="lastName", type="STRING", required=True), - PropertyType(name="age", type="INTEGER"), + PropertyType( + name="firstName", type=Neo4jPropertyType.STRING, required=True + ), + PropertyType( + name="lastName", type=Neo4jPropertyType.STRING, required=True + ), + PropertyType(name="age", type=Neo4jPropertyType.INTEGER), ], additional_properties=False, ), NodeType( label="Organization", properties=[ - PropertyType(name="name", type="STRING", required=True), - PropertyType(name="address", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING, required=True), + PropertyType(name="address", type=Neo4jPropertyType.STRING), ], ), ), relationship_types=( RelationshipType( label="WORKS_FOR", - properties=[PropertyType(name="since", type="LOCAL_DATETIME")], + properties=[ + PropertyType(name="since", type=Neo4jPropertyType.LOCAL_DATETIME) + ], ), RelationshipType( label="KNOWS", diff --git a/examples/customize/build_graph/components/schema_builders/schema.py b/examples/customize/build_graph/components/schema_builders/schema.py index 6ca408dee..38dae9475 100644 --- a/examples/customize/build_graph/components/schema_builders/schema.py +++ b/examples/customize/build_graph/components/schema_builders/schema.py @@ -12,46 +12,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio + +import neo4j + from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - NodeType, - PropertyType, - RelationshipType, ) async def main() -> None: - schema_builder = SchemaBuilder() + with neo4j.GraphDatabase.driver( + "bolt://localhost:7687", + auth=("neo4j", "password"), + ) as driver: + schema_builder = SchemaBuilder(driver) + + schema = await schema_builder.run( + node_types=[ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "place_of_birth", "type": "STRING"}, + {"name": "date_of_birth", "type": "DATE"}, + ], + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "country", "type": "STRING"}, + ], + }, + { + "label": "Field", + "properties": [ + {"name": "name", "type": "STRING"}, + ], + }, + ], + relationship_types=[ + "WORKED_ON", + { + "label": "WORKED_FOR", + }, + ], + patterns=[ + ("Person", "WORKED_ON", "Field"), + ("Person", "WORKED_FOR", "Organization"), + ], + ) + print(schema) + - result = await schema_builder.run( - node_types=[ - NodeType( - label="Person", - properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="place_of_birth", type="STRING"), - PropertyType(name="date_of_birth", type="DATE"), - ], - ), - NodeType( - label="Organization", - properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="country", type="STRING"), - ], - ), - ], - relationship_types=[ - RelationshipType( - label="WORKED_ON", - ), - RelationshipType( - label="WORKED_FOR", - ), - ], - patterns=[ - ("Person", "WORKED_ON", "Field"), - ("Person", "WORKED_FOR", "Organization"), - ], - ) - print(result) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/customize/build_graph/components/schema_builders/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py index 947907dff..753819f7c 100644 --- a/examples/customize/build_graph/components/schema_builders/schema_from_text.py +++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py @@ -14,6 +14,8 @@ from neo4j_graphrag.experimental.components.schema import ( SchemaFromTextExtractor, +) +from neo4j_graphrag.experimental.components.types import ( GraphSchema, ) from neo4j_graphrag.llm import OpenAILLM @@ -74,7 +76,7 @@ async def extract_and_save_schema() -> None: try: # Create a SchemaFromTextExtractor component with the default template - schema_extractor = SchemaFromTextExtractor(llm=llm) + schema_extractor = SchemaFromTextExtractor(driver=None, llm=llm) print("Extracting schema from text...") # Extract schema from text diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index ea727fe3c..4bd480d5b 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -25,6 +25,8 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, RelationshipType, ) @@ -87,7 +89,7 @@ async def define_and_run_pipeline( FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False), "splitter", ) - pipe.add_component(SchemaBuilder(), "schema") + pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema") pipe.add_component( LLMEntityRelationExtractor( llm=llm, diff --git 
a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index 3a9e30911..777cad3ea 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -25,9 +25,12 @@ from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, PropertyType, RelationshipType, + Neo4jPropertyType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -63,7 +66,7 @@ async def define_and_run_pipeline( "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") - pipe.add_component(SchemaBuilder(), "schema") + pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema") pipe.add_component( LLMEntityRelationExtractor( llm=llm, @@ -99,22 +102,24 @@ async def define_and_run_pipeline( NodeType( label="Person", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="place_of_birth", type="STRING"), - PropertyType(name="date_of_birth", type="DATE"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType( + name="place_of_birth", type=Neo4jPropertyType.STRING + ), + PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE), ], ), NodeType( label="Organization", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="country", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType(name="country", type=Neo4jPropertyType.STRING), ], ), NodeType( label="Field", properties=[ - PropertyType(name="name", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), ], ), ], diff --git a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py index eda2b4219..490f836e2 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py +++ b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py @@ -28,9 +28,12 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, PropertyType, RelationshipType, + Neo4jPropertyType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -61,7 +64,7 @@ async def define_and_run_pipeline( FixedSizeSplitter(), "splitter", ) - pipe.add_component(SchemaBuilder(), "schema") + pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema") pipe.add_component( LLMEntityRelationExtractor( llm=llm, @@ -96,16 +99,18 @@ async def define_and_run_pipeline( NodeType( label="Person", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="place_of_birth", type="STRING"), - PropertyType(name="date_of_birth", type="DATE"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType( + name="place_of_birth", type=Neo4jPropertyType.STRING + ), + PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE), ], ), NodeType( label="Organization", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="country", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + 
PropertyType(name="country", type=Neo4jPropertyType.STRING), ], ), ], diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py index daaab51a5..1c252ba68 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py @@ -16,9 +16,11 @@ from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, PropertyType, - RelationshipType, + RelationshipType, Neo4jPropertyType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -66,7 +68,7 @@ async def define_and_run_pipeline( "lexical_graph_builder", ) pipe.add_component(Neo4jWriter(neo4j_driver), "lg_writer") - pipe.add_component(SchemaBuilder(), "schema") + pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema") pipe.add_component( LLMEntityRelationExtractor( llm=llm, @@ -122,22 +124,22 @@ async def define_and_run_pipeline( NodeType( label="Person", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="place_of_birth", type="STRING"), - PropertyType(name="date_of_birth", type="DATE"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType(name="place_of_birth", type=Neo4jPropertyType.STRING), + PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE), ], ), NodeType( label="Organization", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="country", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType(name="country", type=Neo4jPropertyType.STRING), ], ), NodeType( label="Field", properties=[ - PropertyType(name="name", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), ], ), ], diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py index b5b6b5273..0aed24b39 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py @@ -18,9 +18,11 @@ from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, PropertyType, - RelationshipType, + RelationshipType, Neo4jPropertyType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -112,7 +114,7 @@ async def read_chunk_and_perform_entity_extraction( pipe = Pipeline() # define the components pipe.add_component(Neo4jChunkReader(neo4j_driver), "reader") - pipe.add_component(SchemaBuilder(), "schema") + pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema") pipe.add_component( LLMEntityRelationExtractor( llm=llm, @@ -142,22 +144,22 @@ async def read_chunk_and_perform_entity_extraction( NodeType( label="Person", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="place_of_birth", 
type="STRING"), - PropertyType(name="date_of_birth", type="DATE"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType(name="place_of_birth", type=Neo4jPropertyType.STRING), + PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE), ], ), NodeType( label="Organization", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="country", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType(name="country", type=Neo4jPropertyType.STRING), ], ), NodeType( label="Field", properties=[ - PropertyType(name="name", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), ], ), ], diff --git a/examples/kg_builder.py b/examples/kg_builder.py index c98f0c069..9d5ba25cb 100644 --- a/examples/kg_builder.py +++ b/examples/kg_builder.py @@ -32,6 +32,8 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, RelationshipType, ) @@ -91,7 +93,7 @@ async def define_and_run_pipeline( pipe.add_component( FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter" ) - pipe.add_component(SchemaBuilder(), "schema") + pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema") pipe.add_component( LLMEntityRelationExtractor( llm=llm, diff --git a/src/neo4j_graphrag/exceptions.py b/src/neo4j_graphrag/exceptions.py index 681b20eec..d8b539ca3 100644 --- a/src/neo4j_graphrag/exceptions.py +++ b/src/neo4j_graphrag/exceptions.py @@ -116,6 +116,12 @@ class SchemaValidationError(Neo4jGraphRagError): pass +class SchemaDatabaseConflictError(SchemaValidationError): + """Exception raised when user schema conflicts with database constraints.""" + + pass + + class SchemaExtractionError(Neo4jGraphRagError): """Exception raised for errors in automatic schema extraction.""" diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index b1ed9bbb1..995505a75 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -25,10 +25,10 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder -from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( DocumentInfo, LexicalGraphConfig, + GraphSchema, Neo4jGraph, TextChunk, TextChunks, diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index d49f48ce0..31e90e783 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -18,7 +18,7 @@ from pydantic import validate_call, BaseModel -from neo4j_graphrag.experimental.components.schema import ( +from neo4j_graphrag.experimental.components.types import ( GraphSchema, PropertyType, NodeType, diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 7e00eb586..bf5eb7405 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -14,325 +14,937 @@ # limitations under the License. 
from __future__ import annotations +import copy import json import logging -import warnings -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence -from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Sequence +import neo4j from pydantic import ( - BaseModel, - PrivateAttr, - model_validator, validate_call, - ConfigDict, ValidationError, ) -from typing_extensions import Self from neo4j_graphrag.exceptions import ( SchemaValidationError, + SchemaDatabaseConflictError, LLMGenerationError, SchemaExtractionError, ) -from neo4j_graphrag.experimental.pipeline.component import Component, DataModel +from neo4j_graphrag.experimental.pipeline.component import Component from neo4j_graphrag.experimental.pipeline.types.schema import ( EntityInputType, RelationInputType, ) from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate from neo4j_graphrag.llm import LLMInterface -from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat +from neo4j_graphrag.experimental.components.types import ( + GraphSchema, + SchemaConstraint, + Neo4jConstraintTypeEnum, + GraphEntityType, + Neo4jPropertyType, + NodeType, + PropertyType, + RelationshipType, +) +from neo4j_graphrag.schema import get_constraints -class PropertyType(BaseModel): +class ConstraintProcessor: """ - Represents a property on a node or relationship in the graph. + Base class for processing database constraints against schemas. + + Provides shared logic for both validation and enhancement modes when working + with Neo4j database constraints. This class handles the core constraint processing + functionality used by both SchemaBuilder and SchemaFromTextExtractor. + + The constraint processor can operate in two modes: + - **Validation Mode**: Validates schemas against constraints and raises errors for conflicts + - **Enhancement Mode**: Automatically modifies schemas to resolve constraint conflicts + + Args: + driver: Neo4j driver instance for database access + neo4j_database: Optional Neo4j database name. If None, uses default database. + + Attributes: + driver: The Neo4j driver instance + neo4j_database: The Neo4j database name """ + + def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None): + """ + Initialize the ConstraintProcessor. + + Args: + driver: Neo4j driver instance for database access + neo4j_database: Optional Neo4j database name. If None, uses default database. + """ + self.driver = driver + self.neo4j_database = neo4j_database - name: str - # See https://neo4j.com/docs/cypher-manual/current/values-and-types/property-structural-constructed/#property-types - type: Literal[ - "BOOLEAN", - "DATE", - "DURATION", - "FLOAT", - "INTEGER", - "LIST", - "LOCAL_DATETIME", - "LOCAL_TIME", - "POINT", - "STRING", - "ZONED_DATETIME", - "ZONED_TIME", - ] - description: str = "" - required: bool = False - - model_config = ConfigDict( - frozen=True, - ) - - -class NodeType(BaseModel): - """ - Represents a possible node in the graph. - """ + def _get_constraints_from_db(self) -> list[SchemaConstraint]: + """ + Retrieve all constraints from the Neo4j database. + + Returns: + List of SchemaConstraint objects representing database constraints. + + Note: + This method uses the get_constraints utility function to fetch constraints + from the database and converts them to SchemaConstraint objects. 
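+
+        Example:
+            A minimal sketch of inspecting the returned constraints (illustrative
+            only; the exact values depend on the constraints defined in the target
+            database)::
+
+                processor = ConstraintProcessor(driver)
+                for constraint in processor._get_constraints_from_db():
+                    print(constraint.type, constraint.label_or_type, constraint.properties)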
+ """ + constraints = get_constraints( + self.driver, database=self.neo4j_database, sanitize=False + ) + return [ + SchemaConstraint.model_validate(c) + for c in constraints + ] - label: str - description: str = "" - properties: list[PropertyType] = [] - additional_properties: bool = True - - @model_validator(mode="before") - @classmethod - def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType: - if isinstance(data, str): - return {"label": data} - return data - - @model_validator(mode="after") - def validate_additional_properties(self) -> Self: - if len(self.properties) == 0 and not self.additional_properties: - raise ValueError( - "Using `additional_properties=False` with no defined " - "properties will cause the model to be pruned during graph cleaning.", + @staticmethod + def _parse_property_type(property_type: str) -> list[Neo4jPropertyType]: + """ + Parse a property type string into a list of Neo4jPropertyType enums. + + Args: + property_type: String representation of property types (e.g., "STRING|INTEGER") + + Returns: + List of Neo4jPropertyType enums. Empty list if property_type is empty or invalid. + + Note: + This method handles pipe-separated type strings and ignores invalid types. + """ + prop_types = [] + if not property_type: + return prop_types + for prop_str in property_type.split("|"): + p = prop_str.strip() + try: + prop = Neo4jPropertyType(p) + prop_types.append(prop) + except ValueError: + pass + return prop_types + + def _infer_property_type_from_constraint( + self, + constraint: SchemaConstraint + ) -> Neo4jPropertyType: + """ + Infer the best property type from a database constraint. + + Args: + constraint: The database constraint to analyze + + Returns: + Neo4jPropertyType: The inferred property type. Defaults to STRING if no + specific type is defined in the constraint. + + Note: + For type constraints, returns the first allowed type. For existence + constraints without type information, defaults to STRING. + """ + if constraint.property_type: + # Use the first type from the constraint + return constraint.property_type[0] + else: + # Default to STRING for existence constraints + return Neo4jPropertyType.STRING + + def _check_missing_properties( + self, + entity_type: GraphEntityType, + constraint: SchemaConstraint + ) -> list[str]: + """ + Check for properties required by constraint but missing from entity. + + Args: + entity_type: The entity type to check + constraint: The database constraint to validate against + + Returns: + List of property names that are required by the constraint but missing + from the entity definition. + """ + missing_props = [] + for prop_name in constraint.properties: + if entity_type.get_property_by_name(prop_name) is None: + missing_props.append(prop_name) + return missing_props + + def _check_property_type_compatibility( + self, + entity_type: GraphEntityType, + constraint: SchemaConstraint + ) -> list[tuple[str, list[Neo4jPropertyType], list[Neo4jPropertyType]]]: + """ + Check for property type conflicts between entity and constraint. + + Args: + entity_type: The entity type to check + constraint: The database constraint to validate against + + Returns: + List of tuples containing (property_name, user_types, database_allowed_types) + for each property that has a type conflict. + + Note: + Only checks type constraints. Returns empty list for constraints without + type information. 
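+
+        Example:
+            Illustrative return value, assuming a ``Person.age`` property declared
+            as STRING while a database type constraint only allows INTEGER::
+
+                [("age", [Neo4jPropertyType.STRING], [Neo4jPropertyType.INTEGER])]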
+ """ + conflicts = [] + if not constraint.property_type: + return conflicts + + for prop_name in constraint.properties: + user_prop = entity_type.get_property_by_name(prop_name) + if user_prop: + user_types = user_prop.type if isinstance(user_prop.type, list) else [user_prop.type] + db_allowed_types = constraint.property_type + + # Check if any user type is allowed by DB + if not any(ut in db_allowed_types for ut in user_types): + conflicts.append((prop_name, user_types, db_allowed_types)) + return conflicts + + def _check_missing_entity_types( + self, + constraints: list[SchemaConstraint], + user_node_labels: set[str], + user_rel_types: set[str], + additional_node_types: bool, + additional_relationship_types: bool + ) -> tuple[set[str], set[str]]: + """ + Check for entity types required by constraints but missing from schema. + + Args: + constraints: List of database constraints + user_node_labels: Set of node labels defined in user schema + user_rel_types: Set of relationship types defined in user schema + additional_node_types: Whether additional node types are allowed + additional_relationship_types: Whether additional relationship types are allowed + + Returns: + Tuple of (missing_node_labels, missing_relationship_types) that are + required by constraints but not defined in the schema and not allowed + by the additional_*_types flags. + """ + missing_node_labels = set() + missing_rel_types = set() + + for constraint in constraints: + if constraint.entity_type == "NODE" and not additional_node_types: + missing_labels = set(constraint.label_or_type) - user_node_labels + missing_node_labels.update(missing_labels) + elif constraint.entity_type == "RELATIONSHIP" and not additional_relationship_types: + missing_types = set(constraint.label_or_type) - user_rel_types + missing_rel_types.update(missing_types) + + return missing_node_labels, missing_rel_types + + def _check_additional_properties_conflicts( + self, + entity_type: GraphEntityType, + constraints: list[SchemaConstraint] + ) -> set[str]: + """ + Check for conflicts when entity has additional_properties=False. + + Args: + entity_type: The entity type to check + constraints: List of database constraints + + Returns: + Set of property names that are required by database constraints but + missing from the entity schema when additional_properties=False. + + Note: + Returns empty set if additional_properties=True, as no conflict exists + in that case. + """ + if entity_type.additional_properties: + return set() # No conflict if additional properties are allowed + + # Find all properties required by DB constraints for this entity + required_by_db = set() + for constraint in constraints: + if (constraint.entity_type == entity_type.entity_type_name and + constraint.label_or_type[0] == entity_type.label): + required_by_db.update(constraint.properties) + + # Check if any DB-required properties are missing from user schema + user_properties = {prop.name for prop in entity_type.properties} + missing_required = required_by_db - user_properties + + return missing_required + + def _process_constraints_against_schema( + self, + schema: GraphSchema, + mode: str = "validate", # "validate" or "enhance" + **kwargs: Any + ) -> GraphSchema: + """ + Process database constraints against schema in either validation or enhancement mode. + + This is the main constraint processing method that handles both validation and + enhancement modes. It coordinates all constraint checking and resolution logic. 
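+
+        Example:
+            A rough usage sketch (illustrative; assumes ``driver`` and ``schema``
+            are already available)::
+
+                processor = ConstraintProcessor(driver)
+                # raises SchemaDatabaseConflictError on conflicts
+                validated = processor._process_constraints_against_schema(
+                    schema, mode="validate"
+                )
+                # or resolve conflicts automatically
+                enhanced = processor._process_constraints_against_schema(
+                    schema, mode="enhance"
+                )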
+ + Args: + schema: The schema to process + mode: Processing mode - "validate" to raise errors, "enhance" to modify schema + **kwargs: Additional configuration parameters including: + - additional_node_types (bool): Whether to allow additional node types + - additional_relationship_types (bool): Whether to allow additional relationship types + + Returns: + GraphSchema: The processed schema. In validation mode, returns the original + schema with safe enhancements (like required=True). In enhancement mode, + returns a modified schema with added properties and entity types. + + Raises: + SchemaDatabaseConflictError: If mode="validate" and conflicts are found + ValueError: If mode is not "validate" or "enhance" + + Note: + This method is the core of the constraint processing system and is used + by both SchemaBuilder and SchemaFromTextExtractor. + """ + constraints = self._get_constraints_from_db() + if not constraints: + return schema # No constraints to process + + # Get configuration flags + additional_node_types = kwargs.get('additional_node_types', True) + additional_relationship_types = kwargs.get('additional_relationship_types', True) + + # Check for conflicts + user_node_labels = {node.label for node in schema.node_types} + user_rel_types = {rel.label for rel in schema.relationship_types} + + # 1. Check missing entity types + missing_node_labels, missing_rel_types = self._check_missing_entity_types( + constraints, user_node_labels, user_rel_types, + additional_node_types, additional_relationship_types + ) + + if mode == "validate": + # Validation mode: raise errors for conflicts + self._validate_missing_entity_types(missing_node_labels, missing_rel_types) + self._validate_entity_constraints(schema.node_types + schema.relationship_types, constraints) + + # Apply only safe enhancements (required=True) + enhanced_entities = self._apply_safe_enhancements( + list(schema.node_types) + list(schema.relationship_types), constraints ) - return self - - -class RelationshipType(BaseModel): - """ - Represents a possible relationship between nodes in the graph. - """ - - label: str - description: str = "" - properties: list[PropertyType] = [] - additional_properties: bool = True - - @model_validator(mode="before") - @classmethod - def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType: - if isinstance(data, str): - return {"label": data} - return data - - @model_validator(mode="after") - def validate_additional_properties(self) -> Self: - if len(self.properties) == 0 and not self.additional_properties: - raise ValueError( - "Using `additional_properties=False` with no defined " - "properties will cause the model to be pruned during graph cleaning.", + + return self._rebuild_schema_with_entities(schema, enhanced_entities, **kwargs) + + elif mode == "enhance": + # Enhancement mode: modify schema to resolve conflicts + return self._enhance_schema_with_constraints(schema, constraints, **kwargs) + + else: + raise ValueError(f"Invalid mode: {mode}. Must be 'validate' or 'enhance'") + + def _validate_missing_entity_types(self, missing_node_labels: set[str], missing_rel_types: set[str]) -> None: + """ + Raise errors for missing entity types in validation mode. 
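+
+        Example:
+            Illustrative call; if the database has a constraint on ``Company``
+            nodes that the user schema does not define (and additional node types
+            are disallowed), this raises ``SchemaDatabaseConflictError``::
+
+                self._validate_missing_entity_types({"Company"}, set())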
+ + Args: + missing_node_labels: Set of node labels required by constraints but missing from schema + missing_rel_types: Set of relationship types required by constraints but missing from schema + + Raises: + SchemaDatabaseConflictError: If any missing entity types are found + """ + if missing_node_labels: + raise SchemaDatabaseConflictError( + f"Database has constraints on node labels {missing_node_labels} " + f"that are not defined in user schema. Please add these node types " + f"or set additional_node_types=True." + ) + if missing_rel_types: + raise SchemaDatabaseConflictError( + f"Database has constraints on relationship types {missing_rel_types} " + f"that are not defined in user schema. Please add these relationship types " + f"or set additional_relationship_types=True." ) - return self - - -class GraphSchema(DataModel): - """This model represents the expected - node and relationship types in the graph. - - It is used both for guiding the LLM in the entity and relation - extraction component, and for cleaning the extracted graph in a - post-processing step. - - .. warning:: - - This model is immutable. - """ - - node_types: Tuple[NodeType, ...] - relationship_types: Tuple[RelationshipType, ...] = tuple() - patterns: Tuple[Tuple[str, str, str], ...] = tuple() - - additional_node_types: bool = True - additional_relationship_types: bool = True - additional_patterns: bool = True - - _node_type_index: dict[str, NodeType] = PrivateAttr() - _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() - - model_config = ConfigDict( - frozen=True, - ) - - @model_validator(mode="after") - def validate_patterns_against_node_and_rel_types(self) -> Self: - self._node_type_index = {node.label: node for node in self.node_types} - self._relationship_type_index = ( - {r.label: r for r in self.relationship_types} - if self.relationship_types - else {} - ) - - relationship_types = self.relationship_types - patterns = self.patterns - if patterns: - if not relationship_types: - raise SchemaValidationError( - "Relationship types must also be provided when using patterns." + def _validate_entity_constraints( + self, + entities: list[GraphEntityType], + constraints: list[SchemaConstraint] + ) -> None: + """ + Validate all entity constraints in validation mode. + + Args: + entities: List of entities (nodes and relationships) to validate + constraints: List of database constraints to validate against + + Raises: + SchemaDatabaseConflictError: If any constraint conflicts are found + + Note: + This method performs comprehensive validation including missing properties, + property type conflicts, and additional_properties conflicts. + """ + for entity in entities: + relevant_constraints = [ + c for c in constraints + if (c.entity_type == entity.entity_type_name and + entity.label in c.label_or_type) + ] + + # Check additional properties conflicts first (more specific error) + missing_required = self._check_additional_properties_conflicts(entity, relevant_constraints) + if missing_required: + raise SchemaDatabaseConflictError( + f"{entity.label} has additional_properties=False but database " + f"constraints require properties {missing_required} not in user schema. " + f"Please add these properties or set additional_properties=True." ) - for entity1, relation, entity2 in patterns: - if entity1 not in self._node_type_index: - raise SchemaValidationError( - f"Node type '{entity1}' is not defined in the provided node_types." 
- ) - if relation not in self._relationship_type_index: - raise SchemaValidationError( - f"Relationship type '{relation}' is not defined in the provided relationship_types." - ) - if entity2 not in self._node_type_index: - raise ValueError( - f"Node type '{entity2}' is not defined in the provided node_types." + + for constraint in relevant_constraints: + # Check missing properties + missing_props = self._check_missing_properties(entity, constraint) + if missing_props: + raise SchemaDatabaseConflictError( + f"Database constraint {constraint.type} on {entity.label} " + f"requires properties {missing_props} that are not defined in user schema. " + f"Please add these properties to your {entity.label} definition or " + f"remove the constraint from the database." ) - - return self - - @model_validator(mode="after") - def validate_additional_parameters(self) -> Self: - if ( - self.additional_patterns is False - and self.additional_relationship_types is True - ): - raise ValueError( - "`additional_relationship_types` must be set to False when using `additional_patterns=False`" + + # Check property type conflicts + type_conflicts = self._check_property_type_compatibility(entity, constraint) + if type_conflicts: + for prop_name, user_types, db_types in type_conflicts: + raise SchemaDatabaseConflictError( + f"Property '{prop_name}' on {entity.label} has type {user_types} " + f"in user schema, but database constraint allows only {db_types}. " + f"Please update the property type or remove the database constraint." + ) + + def _apply_safe_enhancements( + self, + entities: list[GraphEntityType], + constraints: list[SchemaConstraint] + ) -> list[GraphEntityType]: + """ + Apply only safe enhancements that don't add new properties. + + Safe enhancements include setting required=True for properties that already + exist in the schema but are required by database existence constraints. + + Args: + entities: List of entities to enhance + constraints: List of database constraints + + Returns: + List of enhanced entities with safe modifications applied + + Note: + This method is used in validation mode to apply non-conflicting + enhancements while preserving the user's original schema structure. + """ + enhanced_entities = [copy.deepcopy(entity) for entity in entities] + + for entity in enhanced_entities: + relevant_constraints = [ + c for c in constraints + if (c.entity_type == entity.entity_type_name and + entity.label in c.label_or_type) + ] + + for constraint in relevant_constraints: + if constraint.type in ( + Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE, + ): + for prop_name in constraint.properties: + prop = entity.get_property_by_name(prop_name) + if prop and not prop.required: + prop.required = True + + return enhanced_entities + + def _enhance_schema_with_constraints( + self, + schema: GraphSchema, + constraints: list[SchemaConstraint], + **kwargs: Any + ) -> GraphSchema: + """ + Enhance schema by adding missing entities and properties in enhancement mode. 
+ + This method performs comprehensive schema enhancement including: + - Adding missing entity types required by constraints + - Adding missing properties to existing entities + - Updating property types to match constraints + - Setting required flags for existence constraints + + Args: + schema: The original schema to enhance + constraints: List of database constraints to apply + **kwargs: Configuration parameters including additional_*_types flags + + Returns: + GraphSchema: Enhanced schema with all constraint conflicts resolved. + If enhancement fails, returns the original schema with a warning logged. + + Note: + This method is used by SchemaFromTextExtractor and SchemaBuilder in + enhancement mode to automatically resolve constraint conflicts. + """ + # Create mutable copies of entities + enhanced_node_types = [copy.deepcopy(node) for node in schema.node_types] + enhanced_relationship_types = [copy.deepcopy(rel) for rel in schema.relationship_types] + + # Get configuration flags + additional_node_types = kwargs.get('additional_node_types', True) + additional_relationship_types = kwargs.get('additional_relationship_types', True) + + # Step 1: Add missing entity types required by constraints + enhanced_node_types = self._add_missing_entity_types( + enhanced_node_types, constraints, "NODE", additional_node_types + ) + enhanced_relationship_types = self._add_missing_entity_types( + enhanced_relationship_types, constraints, "RELATIONSHIP", additional_relationship_types + ) + + # Step 2: Enhance existing entities with constraint requirements + all_enhanced_entities = enhanced_node_types + enhanced_relationship_types + for entity in all_enhanced_entities: + self._enhance_entity_with_constraints(entity, constraints) + + # Step 3: Create enhanced schema + try: + enhanced_schema = GraphSchema.model_validate( + dict( + node_types=enhanced_node_types, + relationship_types=enhanced_relationship_types, + patterns=schema.patterns, + **kwargs, + ) ) - return self - - def node_type_from_label(self, label: str) -> Optional[NodeType]: - return self._node_type_index.get(label) - - def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]: - return self._relationship_type_index.get(label) + except ValidationError as e: + # If enhancement fails, log warning and return original schema + logging.warning(f"Failed to enhance schema with constraints: {e}") + return schema + + return enhanced_schema - def save( + def _add_missing_entity_types( self, - file_path: Union[str, Path], - overwrite: bool = False, - format: Optional[FileFormat] = None, - ) -> None: + existing_entities: list[GraphEntityType], + constraints: list[SchemaConstraint], + entity_type: str, + allow_additional: bool + ) -> list[GraphEntityType]: """ - Save the schema configuration to file. - + Add entity types that are required by database constraints but missing from schema. + + Creates minimal entity definitions for missing entity types that are referenced + by database constraints. Each created entity includes all properties required + by the relevant constraints. + Args: - file_path (str): The path where the schema configuration will be saved. - overwrite (bool): If set to True, existing file will be overwritten. Default to False. - format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from file_path extension. 
+ existing_entities: Current list of entities in the schema + constraints: Database constraints to analyze + entity_type: Type of entities to add - "NODE" or "RELATIONSHIP" + allow_additional: Whether to add missing types (respects additional_*_types flags) + + Returns: + Enhanced list of entities with missing types added. If allow_additional is False, + returns the original list unchanged. + + Note: + Created entities have additional_properties=True to allow flexibility and + include descriptive text indicating they were added due to constraints. """ - data = self.model_dump(mode="json") - file_handler = FileHandler() - file_handler.write(data, file_path, overwrite=overwrite, format=format) + if not allow_additional: + return existing_entities + + existing_labels = {entity.label for entity in existing_entities} + + # Find labels referenced by constraints but not in schema + required_labels = set() + for constraint in constraints: + if constraint.entity_type == entity_type: + required_labels.update(constraint.label_or_type) + + missing_labels = required_labels - existing_labels + + # Create minimal entity definitions for missing labels + enhanced_entities = list(existing_entities) + for label in missing_labels: + # Create a basic entity with properties required by constraints + required_properties = [] + for constraint in constraints: + if (constraint.entity_type == entity_type and + label in constraint.label_or_type): + for prop_name in constraint.properties: + # Add property if not already added + if not any(p.name == prop_name for p in required_properties): + prop_type = self._infer_property_type_from_constraint(constraint) + required_properties.append( + PropertyType( + name=prop_name, + type=prop_type, + required=constraint.type in ( + Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE + ), + description=f"Property required by database constraint" + ) + ) + + # Create the entity + if entity_type == "NODE": + new_entity = NodeType( + label=label, + description=f"Node type added to match database constraints", + properties=required_properties, + additional_properties=True # Allow flexibility + ) + else: # RELATIONSHIP + new_entity = RelationshipType( + label=label, + description=f"Relationship type added to match database constraints", + properties=required_properties, + additional_properties=True + ) + + enhanced_entities.append(new_entity) + logging.info(f"Added missing {entity_type.lower()} type '{label}' to match database constraints") + + return enhanced_entities - def store_as_json( - self, file_path: Union[str, Path], overwrite: bool = False + def _enhance_entity_with_constraints( + self, + entity: GraphEntityType, + constraints: list[SchemaConstraint] ) -> None: - warnings.warn( - "Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning - ) - return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON) - - def store_as_yaml( - self, file_path: Union[str, Path], overwrite: bool = False + """ + Enhance an entity by adding missing properties and updating existing ones. + + This method modifies the entity in-place to match database constraint requirements. + It adds missing properties (if additional_properties=True) and enhances existing + properties with constraint information. + + Args: + entity: The entity to enhance (modified in-place) + constraints: Database constraints to apply + + Note: + Only adds properties if entity.additional_properties is True. 
Always + enhances existing properties regardless of additional_properties setting. + """ + relevant_constraints = [ + c for c in constraints + if (c.entity_type == entity.entity_type_name and + entity.label in c.label_or_type) + ] + + if not relevant_constraints: + return + + for constraint in relevant_constraints: + for prop_name in constraint.properties: + existing_prop = entity.get_property_by_name(prop_name) + + if existing_prop: + # Enhance existing property + self._enhance_property_with_constraint(existing_prop, constraint) + else: + # Add missing property + if entity.additional_properties: + # Only add if additional properties are allowed + prop_type = self._infer_property_type_from_constraint(constraint) + new_prop = PropertyType( + name=prop_name, + type=prop_type, + required=constraint.type in ( + Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE + ), + description=f"Property added to match database constraint" + ) + entity.properties.append(new_prop) + logging.info(f"Added missing property '{prop_name}' to {entity.label}") + + def _enhance_property_with_constraint( + self, + prop: PropertyType, + constraint: SchemaConstraint ) -> None: - warnings.warn( - "Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning - ) - return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML) - - @classmethod - def from_file( - cls, file_path: Union[str, Path], format: Optional[FileFormat] = None - ) -> Self: """ - Load a schema configuration from a file (either JSON or YAML). - - The file format is automatically detected based on the file extension, - unless the format parameter is set. - + Enhance a property to match database constraint requirements. + + This method modifies the property in-place to align with database constraints. + It can set required=True for existence constraints and update property types + for type constraints. + Args: - file_path (Union[str, Path]): The path to the schema configuration file. - format (Optional[FileFormat]): The format of the schema configuration file (json or yaml). - + prop: The property to enhance (modified in-place) + constraint: The constraint to apply + + Note: + For type constraints, only updates the property type if the current type + is not compatible with the constraint. Uses the first allowed type from + the constraint. 
+ """ + # Set required=True for existence constraints + if constraint.type in ( + Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE + ): + if not prop.required: + prop.required = True + logging.info(f"Enhanced property '{prop.name}' to required=True due to database constraint") + + # Update property type for type constraints + if (constraint.type in ( + Neo4jConstraintTypeEnum.NODE_PROPERTY_TYPE, + Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_TYPE + ) and constraint.property_type): + current_types = prop.type if isinstance(prop.type, list) else [prop.type] + + # Check if current type is compatible with constraint + if not any(ct in constraint.property_type for ct in current_types): + # Use the first allowed type from constraint + new_type = constraint.property_type[0] + prop.type = new_type + logging.info(f"Enhanced property '{prop.name}' type to {new_type} due to database constraint") + + def _rebuild_schema_with_entities( + self, + original_schema: GraphSchema, + enhanced_entities: list[GraphEntityType], + **kwargs: Any + ) -> GraphSchema: + """ + Rebuild a GraphSchema with enhanced entities. + + Args: + original_schema: The original schema to preserve patterns and other settings + enhanced_entities: List of enhanced entities (both nodes and relationships) + **kwargs: Additional schema configuration parameters + Returns: - GraphSchema: The loaded schema configuration. + GraphSchema: New schema with enhanced entities + + Raises: + SchemaValidationError: If the enhanced schema fails validation """ - file_path = Path(file_path) - file_handler = FileHandler() - try: - data = file_handler.read(file_path, format=format) - except ValueError: - raise - + # Split entities back into nodes and relationships + enhanced_nodes = [e for e in enhanced_entities if e.entity_type_name == "NODE"] + enhanced_rels = [e for e in enhanced_entities if e.entity_type_name == "RELATIONSHIP"] + try: - return cls.model_validate(data) + return GraphSchema.model_validate( + dict( + node_types=enhanced_nodes, + relationship_types=enhanced_rels, + patterns=original_schema.patterns, + **kwargs, + ) + ) except ValidationError as e: - raise SchemaValidationError(str(e)) from e + raise SchemaValidationError("Error when applying constraints from database") from e -class SchemaBuilder(Component): +class SchemaBuilder(Component, ConstraintProcessor): """ - A builder class for constructing GraphSchema objects from given entities, - relations, and their interrelationships defined in a potential schema. + A builder class for constructing GraphSchema objects from manually defined entities, + relations, and their interrelationships. + + SchemaBuilder supports two modes of operation: + + - **Validation Mode** (default): Validates user schemas against database constraints + and raises SchemaDatabaseConflictError for conflicts. Use when you want explicit + control over your schema. + - **Enhancement Mode**: Automatically modifies schemas to resolve conflicts with + database constraints. Use during development or when you want automatic compatibility. + + Args: + driver: Neo4j driver instance for database access + neo4j_database: Optional Neo4j database name. If None, uses default database. + enhancement_mode: If True, enhances schema instead of raising errors. + If False (default), validates schema and raises errors for conflicts. Example: + .. code-block:: python - .. 
code-block:: python + from neo4j_graphrag.experimental.components.schema import ( + SchemaBuilder, + NodeType, + PropertyType, + RelationshipType, + ) - from neo4j_graphrag.experimental.components.schema import ( - SchemaBuilder, - NodeType, - PropertyType, - RelationshipType, - ) - from neo4j_graphrag.experimental.pipeline import Pipeline - - node_types = [ - NodeType( - label="PERSON", - description="An individual human being.", - properties=[ - PropertyType( - name="name", type="STRING", description="The name of the person" - ) - ], - ), - NodeType( - label="ORGANIZATION", - description="A structured group of people with a common purpose.", - properties=[ - PropertyType( - name="name", type="STRING", description="The name of the organization" - ) - ], - ), - ] - relationship_types = [ - RelationshipType( - label="EMPLOYED_BY", description="Indicates employment relationship." - ), - ] - patterns = [ - ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), - ] - pipe = Pipeline() - schema_builder = SchemaBuilder() - pipe.add_component(schema_builder, "schema_builder") - pipe_inputs = { - "schema": { - "node_types": node_types, - "relationship_types": relationship_types, - "patterns": patterns, - }, - ... - } - pipe.run(pipe_inputs) + # Validation mode (default) - raises errors for conflicts + schema_builder = SchemaBuilder(driver, enhancement_mode=False) + + # Enhancement mode - automatically fixes conflicts + schema_builder = SchemaBuilder(driver, enhancement_mode=True) + + node_types = [ + NodeType( + label="PERSON", + description="An individual human being.", + properties=[ + PropertyType( + name="name", type="STRING", description="The name of the person" + ) + ], + ), + NodeType( + label="ORGANIZATION", + description="A structured group of people with a common purpose.", + properties=[ + PropertyType( + name="name", type="STRING", description="The name of the organization" + ) + ], + ), + ] + relationship_types = [ + RelationshipType( + label="EMPLOYED_BY", description="Indicates employment relationship." + ), + ] + patterns = [ + ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), + ] + + # This will validate/enhance the schema against database constraints + schema = await schema_builder.run( + node_types=node_types, + relationship_types=relationship_types, + patterns=patterns, + ) """ + def __init__( + self, + driver: neo4j.Driver, + neo4j_database: Optional[str] = None, + enhancement_mode: bool = False + ) -> None: + """ + Initialize SchemaBuilder with constraint processing capabilities. + + Args: + driver: Neo4j driver instance for database access + neo4j_database: Optional Neo4j database name. If None, uses default database. + enhancement_mode: If True, enhances schema instead of raising errors. + If False (default), validates schema and raises errors for conflicts. 
+ """ + super().__init__(driver, neo4j_database) + self.enhancement_mode = enhancement_mode + + def _apply_all_constraints_from_db( + self, + constraints: list[SchemaConstraint], + entities: tuple[GraphEntityType, ...], + ) -> list[GraphEntityType]: + constrained_entity_types = [] + for entity_type in entities: + new_entity_type = copy.deepcopy(entity_type) + # find constraints related to this node type + for constraint in constraints: + if constraint.entity_type != entity_type.entity_type_name: + continue + if constraint.label_or_type[0] != entity_type.label: + continue + # now we can add the constraint to this node type + self._apply_constraint_from_db(new_entity_type, constraint) + constrained_entity_types.append(new_entity_type) + return constrained_entity_types + @staticmethod - def create_schema_model( - node_types: Sequence[NodeType], - relationship_types: Optional[Sequence[RelationshipType]] = None, + def _parse_property_type(property_type: str) -> list[Neo4jPropertyType]: + prop_types = [] + if not property_type: + return prop_types + for prop_str in property_type.split("|"): + p = prop_str.strip() + try: + prop = Neo4jPropertyType(p) + prop_types.append(prop) + except ValueError: + pass + return prop_types + + def _apply_constraint_from_db( + self, entity_type: GraphEntityType, constraint: SchemaConstraint, + ) -> None: + """Validate that user schema is compatible with database constraints. + + This method now focuses on validation and only applies safe enhancements. + + Args: + entity_type: The entity type to validate and potentially enhance. + constraint: The database constraint to validate against. + + Raises: + SchemaDatabaseConflictError: If user schema conflicts with DB constraint. + """ + # Step 1: Validate compatibility (raises errors for conflicts) + self._validate_constraint_compatibility(entity_type, constraint) + + # Step 2: Apply only non-conflicting enhancements + # (Only set required=True if user didn't explicitly set it to False) + if constraint.type in ( + Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE, + ): + for prop_name in constraint.properties: + prop = entity_type.get_property_by_name(prop_name) + if prop and not prop.required: + # This would have been caught by validation if it was a conflict + prop.required = True + + def _validate_constraint_compatibility( + self, + entity_type: GraphEntityType, + constraint: SchemaConstraint + ) -> None: + """Legacy method for backward compatibility. Uses shared logic.""" + missing_props = self._check_missing_properties(entity_type, constraint) + if missing_props: + raise SchemaDatabaseConflictError( + f"Database constraint {constraint.type} on {entity_type.label} " + f"requires properties {missing_props} that are not defined in user schema. " + f"Please add these properties to your {entity_type.label} definition or " + f"remove the constraint from the database." + ) + + # Check property type conflicts + if constraint.type in ( + Neo4jConstraintTypeEnum.NODE_PROPERTY_TYPE, + Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_TYPE + ): + type_conflicts = self._check_property_type_compatibility(entity_type, constraint) + if type_conflicts: + for prop_name, user_types, db_types in type_conflicts: + raise SchemaDatabaseConflictError( + f"Property '{prop_name}' on {entity_type.label} has type {user_types} " + f"in user schema, but database constraint allows only {db_types}. " + f"Please update the property type or remove the database constraint." 
+ ) + + def _create_schema_model( + self, + node_types: Sequence[EntityInputType], + relationship_types: Optional[Sequence[RelationInputType]] = None, patterns: Optional[Sequence[Tuple[str, str, str]]] = None, **kwargs: Any, ) -> GraphSchema: @@ -348,9 +960,12 @@ def create_schema_model( Returns: GraphSchema: A configured schema object. + + Raises: + SchemaDatabaseConflictError: If enhancement_mode=False and conflicts found. """ try: - return GraphSchema.model_validate( + schema = GraphSchema.model_validate( dict( node_types=node_types, relationship_types=relationship_types or (), @@ -361,26 +976,49 @@ def create_schema_model( except ValidationError as e: raise SchemaValidationError() from e + # Use shared constraint processing logic + mode = "enhance" if self.enhancement_mode else "validate" + return self._process_constraints_against_schema(schema, mode=mode, **kwargs) + @validate_call async def run( self, - node_types: Sequence[NodeType], - relationship_types: Optional[Sequence[RelationshipType]] = None, + node_types: Sequence[EntityInputType], + relationship_types: Optional[Sequence[RelationInputType]] = None, patterns: Optional[Sequence[Tuple[str, str, str]]] = None, **kwargs: Any, ) -> GraphSchema: """ - Asynchronously constructs and returns a GraphSchema object. + Asynchronously constructs and returns a GraphSchema object with constraint processing. + + This method creates a schema from the provided entities and processes it against + database constraints according to the configured mode (validation or enhancement). Args: - node_types (Sequence[NodeType]): Sequence of NodeType objects. - relationship_types (Sequence[RelationshipType]): Sequence of RelationshipType objects. - patterns (Optional[Sequence[Tuple[str, str, str]]]): Sequence of triplets: (source_entity_label, relation_label, target_entity_label). + node_types: Sequence of NodeType objects defining the node types in the schema + relationship_types: Optional sequence of RelationshipType objects defining + the relationship types in the schema + patterns: Optional sequence of triplets (source_entity_label, relation_label, + target_entity_label) defining allowed relationship patterns + **kwargs: Additional configuration parameters including: + - additional_node_types (bool): Whether to allow additional node types + - additional_relationship_types (bool): Whether to allow additional relationship types + - additional_properties (bool): Whether entities can have additional properties Returns: - GraphSchema: A configured schema object, constructed asynchronously. + GraphSchema: A configured schema object. In validation mode, returns the schema + with safe enhancements (like required=True). In enhancement mode, returns a + schema modified to resolve all constraint conflicts. + + Raises: + SchemaDatabaseConflictError: If enhancement_mode=False and conflicts are found + SchemaValidationError: If the provided entity definitions are invalid + + Note: + The schema is automatically processed against database constraints using the + shared constraint processing logic from ConstraintProcessor. """ - return self.create_schema_model( + return self._create_schema_model( node_types, relationship_types, patterns, @@ -388,18 +1026,72 @@ async def run( ) -class SchemaFromTextExtractor(Component): +class SchemaFromTextExtractor(Component, ConstraintProcessor): """ - A component for constructing GraphSchema objects from the output of an LLM after - automatic schema extraction from text. 
+ A component for automatically extracting GraphSchema objects from text using LLMs. + + This component uses a Large Language Model to analyze text and automatically extract + entity types, relationship types, and their properties. The extracted schema is then + automatically enhanced to match database constraints. + + SchemaFromTextExtractor always operates in enhancement mode - it never raises errors + for constraint conflicts. Instead, it automatically modifies the LLM-generated schema + to ensure database compatibility. + + Args: + driver: Neo4j driver instance for database access + llm: LLM instance implementing LLMInterface for schema extraction + prompt_template: Optional custom prompt template for schema extraction. + Defaults to SchemaExtractionTemplate. + llm_params: Optional dictionary of additional parameters to pass to the LLM + neo4j_database: Optional Neo4j database name. If None, uses default database. + + Example: + .. code-block:: python + + from neo4j_graphrag.experimental.components.schema import SchemaFromTextExtractor + from neo4j_graphrag.llm import OpenAILLM + + # Create the schema extractor + extractor = SchemaFromTextExtractor( + driver=neo4j_driver, + llm=OpenAILLM( + model_name="gpt-4o", + model_params={ + "max_tokens": 2000, + "response_format": {"type": "json_object"}, + }, + ) + ) + + # Extract schema from text - automatically enhanced for database compatibility + text = "John works at Acme Corp. His email is john@acme.com" + schema = await extractor.run(text) + + # The schema will include entities and properties extracted by the LLM, + # plus any additional properties/types required by database constraints """ def __init__( self, + driver: neo4j.Driver, llm: LLMInterface, prompt_template: Optional[PromptTemplate] = None, llm_params: Optional[Dict[str, Any]] = None, + neo4j_database: Optional[str] = None, ) -> None: + """ + Initialize the SchemaFromTextExtractor. + + Args: + driver: Neo4j driver instance for database access + llm: LLM instance implementing LLMInterface for schema extraction + prompt_template: Optional custom prompt template for schema extraction. + Defaults to SchemaExtractionTemplate. + llm_params: Optional dictionary of additional parameters to pass to the LLM + neo4j_database: Optional Neo4j database name. If None, uses default database. + """ + super().__init__(driver, neo4j_database) self._llm: LLMInterface = llm self._prompt_template: PromptTemplate = ( prompt_template or SchemaExtractionTemplate() @@ -409,14 +1101,34 @@ def __init__( @validate_call async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema: """ - Asynchronously extracts the schema from text and returns a GraphSchema object. + Asynchronously extracts schema from text and enhances it for database compatibility. + + This method uses the configured LLM to analyze the provided text and extract + entity types, relationship types, and their properties. The extracted schema + is then automatically enhanced to match database constraints. Args: - text (str): the text from which the schema will be inferred. - examples (str): examples to guide schema extraction. 
+ text: The text from which the schema will be inferred + examples: Optional examples to guide schema extraction (few-shot learning) + **kwargs: Additional parameters for schema configuration including: + - additional_node_types (bool): Whether to allow additional node types (default: True) + - additional_relationship_types (bool): Whether to allow additional relationship types (default: True) + Returns: - GraphSchema: A configured schema object, extracted automatically and - constructed asynchronously. + GraphSchema: A configured schema object extracted from text and enhanced + to match database constraints. The schema includes: + - Entity types and properties identified by the LLM + - Additional properties/types required by database constraints + - Proper type annotations and required flags based on constraints + + Raises: + LLMGenerationError: If the LLM fails to generate a response + SchemaExtractionError: If the LLM response cannot be parsed or is invalid + + Note: + This component always operates in enhancement mode and will automatically + resolve any conflicts between the LLM-generated schema and database constraints. + If enhancement fails, the original LLM-generated schema is returned with a warning. """ prompt: str = self._prompt_template.format(text=text, examples=examples) @@ -428,7 +1140,7 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema raise LLMGenerationError("Failed to generate schema from text") from e try: - extracted_schema: Dict[str, Any] = json.loads(content) + extracted_schema = json.loads(content) # handle dictionary if isinstance(extracted_schema, dict): @@ -464,10 +1176,16 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema "patterns" ) - return GraphSchema.model_validate( + # Create initial schema from LLM extraction + initial_schema = GraphSchema.model_validate( { "node_types": extracted_node_types, "relationship_types": extracted_relationship_types, "patterns": extracted_patterns, } ) + + # Enhance the schema to match database constraints using shared logic + enhanced_schema = self._process_constraints_against_schema(initial_schema, mode="enhance", **kwargs) + + return enhanced_schema diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 3c07d401c..b7168a814 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -14,12 +14,30 @@ # limitations under the License. 
from __future__ import annotations +import enum import uuid -from typing import Any, Dict, Optional - -from pydantic import BaseModel, Field, field_validator - +import warnings +from pathlib import Path +from typing import Any, Dict, Optional, Union, Tuple, Literal, Iterable +from typing_extensions import Self + +from pydantic import ( + BaseModel, + Field, + field_validator, + ValidationError, + model_validator, + ConfigDict, + PrivateAttr, +) + +from neo4j_graphrag.exceptions import SchemaValidationError from neo4j_graphrag.experimental.pipeline.component import DataModel +from neo4j_graphrag.experimental.pipeline.types.schema import ( + RelationInputType, + EntityInputType, +) +from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat class DocumentInfo(DataModel): @@ -186,3 +204,324 @@ def lexical_graph_relationship_types(self) -> tuple[str, ...]: class GraphResult(DataModel): graph: Neo4jGraph config: LexicalGraphConfig + + +class Neo4jPropertyType(str, enum.Enum): + # See https://neo4j.com/docs/cypher-manual/current/values-and-types/property-structural-constructed/#property-types + BOOLEAN = "BOOLEAN" + DATE = "DATE" + DURATION = "DURATION" + FLOAT = "FLOAT" + INTEGER = "INTEGER" + LIST = "LIST" + LOCAL_DATETIME = "LOCAL_DATETIME" + LOCAL_TIME = "LOCAL_TIME" + POINT = "POINT" + STRING = "STRING" + ZONED_DATETIME = "ZONED_DATETIME" + ZONED_DATE = "ZONED_DATE" + + +class PropertyType(BaseModel): + """ + Represents a property on a node or relationship in the graph. + """ + + name: str + type: Neo4jPropertyType | list[Neo4jPropertyType] + description: str = "" + required: bool = False + + +class Neo4jConstraintTypeEnum(str, enum.Enum): + # see: https://neo4j.com/docs/cypher-manual/current/constraints/ + NODE_KEY = "NODE_KEY" + UNIQUENESS = "UNIQUENESS" + NODE_PROPERTY_EXISTENCE = "NODE_PROPERTY_EXISTENCE" + NODE_PROPERTY_UNIQUENESS = "NODE_PROPERTY_UNIQUENESS" + NODE_PROPERTY_TYPE = "NODE_PROPERTY_TYPE" + RELATIONSHIP_KEY = "RELATIONSHIP_KEY" + RELATIONSHIP_UNIQUENESS = "RELATIONSHIP_UNIQUENESS" + RELATIONSHIP_PROPERTY_EXISTENCE = "RELATIONSHIP_PROPERTY_EXISTENCE" + RELATIONSHIP_PROPERTY_UNIQUENESS = "RELATIONSHIP_PROPERTY_UNIQUENESS" + RELATIONSHIP_PROPERTY_TYPE = "RELATIONSHIP_PROPERTY_TYPE" + + +class SchemaConstraint(BaseModel): + """Constraints that can be applied either on a node or relationship property.""" + + entity_type: Literal["NODE", "RELATIONSHIP"] + label_or_type: list[str] + type: Neo4jConstraintTypeEnum + properties: list[str] + property_type: Optional[list[Neo4jPropertyType]] = None + name: Optional[str] = None # do not force users to set a name manually + + @field_validator("label_or_type", mode="before") + @classmethod + def _validate_label_or_type(cls, v: Any) -> Iterable[Any]: + if isinstance(v, str) or not isinstance(v, Iterable): + return [v] + return v + + @field_validator("properties", mode="before") + @classmethod + def _validate_properties(cls, v: Any) -> Iterable[Any]: + if isinstance(v, str) or not isinstance(v, Iterable): + return [v] + return v + + +class GraphEntityType(BaseModel): + """Represents a possible entity in the graph (node or relationship). + + They have a label and a list of properties. + + For LLM-based applications, it is also useful to add a description. + + The additional_properties flag is used in schema-driven data validation. 
+ """ + + label: str + description: str = "" + properties: list[PropertyType] = [] + additional_properties: bool = True + + _entity_type_name: Literal["NODE", "RELATIONSHIP"] = PrivateAttr() + + @model_validator(mode="after") + def validate_additional_properties(self) -> Self: + if len(self.properties) == 0 and not self.additional_properties: + raise ValueError( + "Using `additional_properties=False` with no defined " + "properties will cause the model to be pruned during graph cleaning.", + ) + return self + + def get_property_by_name(self, name: str) -> PropertyType | None: + for prop in self.properties: + if prop.name == name: + return prop + return None + + @property + def entity_type_name(self) -> Literal["NODE", "RELATIONSHIP"]: + """Get the entity type name.""" + return self._entity_type_name + + @staticmethod + def unique_constraint_name() -> tuple[Neo4jConstraintTypeEnum, ...]: + raise NotImplementedError() + + +class NodeType(GraphEntityType): + """Represents a possible node in the graph.""" + + _entity_type_name: Literal["NODE", "RELATIONSHIP"] = PrivateAttr(default="NODE") + + @model_validator(mode="before") + @classmethod + def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType: + if isinstance(data, str): + return {"label": data} + return data + + @staticmethod + def unique_constraint_name() -> tuple[Neo4jConstraintTypeEnum, ...]: + return ( + Neo4jConstraintTypeEnum.NODE_KEY, + Neo4jConstraintTypeEnum.UNIQUENESS, + ) + + +class RelationshipType(GraphEntityType): + """Represents a possible relationship between two nodes in the graph.""" + + _entity_type_name: Literal["NODE", "RELATIONSHIP"] = PrivateAttr(default="RELATIONSHIP") + + @model_validator(mode="before") + @classmethod + def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType: + if isinstance(data, str): + return {"label": data} + return data + + @staticmethod + def unique_constraint_name() -> tuple[Neo4jConstraintTypeEnum, ...]: + return ( + Neo4jConstraintTypeEnum.RELATIONSHIP_KEY, + Neo4jConstraintTypeEnum.RELATIONSHIP_UNIQUENESS, + ) + + +class GraphSchema(DataModel): + """This model represents the expected + node and relationship types in the graph. + + It is used both for guiding the LLM in the entity and relation + extraction component, and for cleaning the extracted graph in a + post-processing step. + + .. warning:: + + This model is immutable. + """ + + node_types: Tuple[NodeType, ...] + relationship_types: Tuple[RelationshipType, ...] = tuple() + patterns: Tuple[Tuple[str, str, str], ...] = tuple() + constraints: Tuple[SchemaConstraint, ...] = tuple() + + additional_node_types: bool = True + additional_relationship_types: bool = True + additional_patterns: bool = True + + _node_type_index: dict[str, NodeType] = PrivateAttr() + _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() + + model_config = ConfigDict( + frozen=True, + ) + + @model_validator(mode="after") + def validate_patterns_against_node_and_rel_types(self) -> Self: + self._node_type_index = {node.label: node for node in self.node_types} + self._relationship_type_index = ( + {r.label: r for r in self.relationship_types} + if self.relationship_types + else {} + ) + + relationship_types = self.relationship_types + patterns = self.patterns + + if patterns: + if not relationship_types: + raise SchemaValidationError( + "Relationship types must also be provided when using patterns." 
+                    )
+                for entity1, relation, entity2 in patterns:
+                    if entity1 not in self._node_type_index:
+                        raise SchemaValidationError(
+                            f"Node type '{entity1}' is not defined in the provided node_types."
+                        )
+                    if relation not in self._relationship_type_index:
+                        raise SchemaValidationError(
+                            f"Relationship type '{relation}' is not defined in the provided relationship_types."
+                        )
+                    if entity2 not in self._node_type_index:
+                        raise SchemaValidationError(
+                            f"Node type '{entity2}' is not defined in the provided node_types."
+                        )
+
+        return self
+
+    @model_validator(mode="after")
+    def validate_additional_parameters(self) -> Self:
+        if (
+            self.additional_patterns is False
+            and self.additional_relationship_types is True
+        ):
+            raise ValueError(
+                "`additional_relationship_types` must be set to False when using `additional_patterns=False`"
+            )
+        return self
+
+    @model_validator(mode="after")
+    def validate_constraint_on_properties(self) -> Self:
+        """Check that properties in constraints are listed in the property list."""
+        for c in self.constraints:
+            for label in c.label_or_type:
+                entity: GraphEntityType | None = None
+                if c.entity_type == "NODE":
+                    entity = self.node_type_from_label(label)
+                else:
+                    entity = self.relationship_type_from_label(label)
+                if not entity:
+                    raise ValueError(f"Entity type '{label}' is not defined.")
+                allowed_prop_names = [p.name for p in entity.properties]
+                for prop_name in c.properties:
+                    if prop_name not in allowed_prop_names:
+                        raise ValueError(
+                            f"Property '{prop_name}' has a constraint '{c}' but is not in the property list for entity {entity}."
+                        )
+        return self
+
+    def node_type_from_label(self, label: str) -> Optional[NodeType]:
+        return self._node_type_index.get(label)
+
+    def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]:
+        return self._relationship_type_index.get(label)
+
+    def unique_properties_for_entity(self, entity: GraphEntityType) -> list[list[str]]:
+        result = []
+        for c in self.constraints:
+            if c.entity_type != entity.entity_type_name:
+                continue
+            if entity.label not in c.label_or_type:
+                continue
+            if c.type in entity.unique_constraint_name():
+                result.append(c.properties)
+        return result
+
+    def save(
+        self,
+        file_path: Union[str, Path],
+        overwrite: bool = False,
+        format: Optional[FileFormat] = None,
+    ) -> None:
+        """
+        Save the schema configuration to file.
+
+        Args:
+            file_path (str): The path where the schema configuration will be saved.
+            overwrite (bool): If set to True, existing file will be overwritten. Defaults to False.
+            format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from file_path extension.
+        """
+        data = self.model_dump(mode="json")
+        file_handler = FileHandler()
+        file_handler.write(data, file_path, overwrite=overwrite, format=format)
+
+    def store_as_json(
+        self, file_path: Union[str, Path], overwrite: bool = False
+    ) -> None:
+        warnings.warn(
+            "Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning
+        )
+        return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON)
+
+    def store_as_yaml(
+        self, file_path: Union[str, Path], overwrite: bool = False
+    ) -> None:
+        warnings.warn(
+            "Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning
+        )
+        return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML)
+
+    @classmethod
+    def from_file(
+        cls, file_path: Union[str, Path], format: Optional[FileFormat] = None
+    ) -> Self:
+        """
+        Load a schema configuration from a file (either JSON or YAML).
+ + The file format is automatically detected based on the file extension, + unless the format parameter is set. + + Args: + file_path (Union[str, Path]): The path to the schema configuration file. + format (Optional[FileFormat]): The format of the schema configuration file (json or yaml). + + Returns: + GraphSchema: The loaded schema configuration. + """ + file_path = Path(file_path) + file_handler = FileHandler() + try: + data = file_handler.read(file_path, format=format) + except ValueError: + raise + + try: + return cls.model_validate(data) + except ValidationError as e: + raise SchemaValidationError(str(e)) from e diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 15bbd53df..0066bc0f9 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -43,9 +43,9 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - GraphSchema, SchemaFromTextExtractor, ) +from neo4j_graphrag.experimental.components.types import GraphSchema from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -177,9 +177,10 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: Get the appropriate schema component based on configuration. Return SchemaFromTextExtractor for automatic extraction or SchemaBuilder for manual schema. """ + driver = self.get_default_neo4j_driver() if not self.has_user_provided_schema(): - return SchemaFromTextExtractor(llm=self.get_default_llm()) - return SchemaBuilder() + return SchemaFromTextExtractor(driver=driver, llm=self.get_default_llm()) + return SchemaBuilder(driver=driver) def _process_schema_with_precedence(self) -> dict[str, Any]: """ diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 891f57e04..67a2e1241 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -42,7 +42,7 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.experimental.components.schema import GraphSchema +from neo4j_graphrag.experimental.components.types import GraphSchema logger = logging.getLogger(__name__) diff --git a/src/neo4j_graphrag/schema.py b/src/neo4j_graphrag/schema.py index 40292067f..6baeebaca 100644 --- a/src/neo4j_graphrag/schema.py +++ b/src/neo4j_graphrag/schema.py @@ -180,6 +180,21 @@ def query_database( return json_data +def get_constraints( + driver: neo4j.Driver, + database: Optional[str] = None, + timeout: Optional[float] = None, + sanitize: bool = False, +) -> list[dict[str, Any]]: + return query_database( + driver=driver, + query="SHOW CONSTRAINTS", + database=database, + timeout=timeout, + sanitize=sanitize, + ) + + def get_schema( driver: neo4j.Driver, is_enhanced: bool = False, @@ -328,12 +343,8 @@ def get_structured_schema( # Get constraints and indexes try: - constraint = query_database( - driver=driver, - query="SHOW CONSTRAINTS", - database=database, - timeout=timeout, - sanitize=sanitize, + constraint = get_constraints( + driver=driver, database=database, timeout=timeout, 
sanitize=sanitize ) index = query_database( driver=driver, diff --git a/tests/e2e/experimental/test_graph_pruning_component_e2e.py b/tests/e2e/experimental/test_graph_pruning_component_e2e.py index 333e74163..6673e24a9 100644 --- a/tests/e2e/experimental/test_graph_pruning_component_e2e.py +++ b/tests/e2e/experimental/test_graph_pruning_component_e2e.py @@ -17,8 +17,8 @@ import pytest from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning -from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( + GraphSchema, Neo4jGraph, Neo4jNode, Neo4jRelationship, diff --git a/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py b/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py index 89f5ae62c..f63ffd661 100644 --- a/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py +++ b/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py @@ -33,9 +33,11 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, - PropertyType, RelationshipType, + PropertyType, Neo4jPropertyType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -60,8 +62,8 @@ def embedder() -> Embedder: @pytest.fixture -def schema_builder() -> SchemaBuilder: - return SchemaBuilder() +def schema_builder(driver: neo4j.Driver) -> SchemaBuilder: + return SchemaBuilder(driver) @pytest.fixture @@ -191,27 +193,27 @@ async def test_pipeline_builder_happy_path( NodeType( label="Person", properties=[ - PropertyType(name="name", type="STRING"), - PropertyType(name="place_of_birth", type="STRING"), - PropertyType(name="date_of_birth", type="DATE"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), + PropertyType(name="place_of_birth", type=Neo4jPropertyType.STRING), + PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE), ], ), NodeType( label="Organization", properties=[ - PropertyType(name="name", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), ], ), NodeType( label="Potion", properties=[ - PropertyType(name="name", type="STRING"), + PropertyType(name="name", type=Neo4jPropertyType.STRING), ], ), NodeType( label="Location", properties=[ - PropertyType(name="address", type="STRING"), + PropertyType(name="address", type=Neo4jPropertyType.STRING), ], ), ], diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index 04730f2e4..936f3e64e 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -23,17 +23,16 @@ GraphPruningResult, PruningStats, ) -from neo4j_graphrag.experimental.components.schema import ( +from neo4j_graphrag.experimental.components.types import ( NodeType, PropertyType, RelationshipType, GraphSchema, -) -from neo4j_graphrag.experimental.components.types import ( Neo4jNode, Neo4jRelationship, Neo4jGraph, LexicalGraphConfig, + Neo4jPropertyType, ) @@ -111,8 +110,8 @@ def node_type_required_name() -> NodeType: return NodeType( label="Person", properties=[ - PropertyType(name="name", type="STRING", required=True), - PropertyType(name="age", type="INTEGER"), + PropertyType(name="name", type=Neo4jPropertyType.STRING, required=True), + PropertyType(name="age", type=Neo4jPropertyType.INTEGER), ], ) diff --git a/tests/unit/experimental/components/test_schema.py 
b/tests/unit/experimental/components/test_schema.py index f6c53016e..1fbfc9f1a 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -15,20 +15,26 @@ from __future__ import annotations import json -from typing import Tuple +from typing import Tuple, Any from unittest.mock import AsyncMock, patch +import neo4j import pytest from pydantic import ValidationError -from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError +from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError, SchemaDatabaseConflictError from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, + SchemaFromTextExtractor, +) +from neo4j_graphrag.experimental.components.types import ( NodeType, PropertyType, RelationshipType, - SchemaFromTextExtractor, GraphSchema, + Neo4jPropertyType, + SchemaConstraint, + Neo4jConstraintTypeEnum, ) import os import tempfile @@ -113,8 +119,8 @@ def valid_node_types() -> tuple[NodeType, ...]: label="PERSON", description="An individual human being.", properties=[ - PropertyType(name="birth date", type="ZONED_DATETIME"), - PropertyType(name="name", type="STRING", required=True), + PropertyType(name="birth date", type=Neo4jPropertyType.ZONED_DATE), + PropertyType(name="name", type=Neo4jPropertyType.STRING, required=True), ], additional_properties=False, ), @@ -133,8 +139,12 @@ def valid_relationship_types() -> tuple[RelationshipType, ...]: label="EMPLOYED_BY", description="Indicates employment relationship.", properties=[ - PropertyType(name="start_time", type="LOCAL_DATETIME", required=True), - PropertyType(name="end_time", type="LOCAL_DATETIME"), + PropertyType( + name="start_time", + type=Neo4jPropertyType.LOCAL_DATETIME, + required=True, + ), + PropertyType(name="end_time", type=Neo4jPropertyType.LOCAL_DATETIME), ], additional_properties=False, ), @@ -170,8 +180,8 @@ def patterns_with_invalid_relation() -> tuple[tuple[str, str, str], ...]: @pytest.fixture -def schema_builder() -> SchemaBuilder: - return SchemaBuilder() +def schema_builder(driver: neo4j.Driver) -> SchemaBuilder: + return SchemaBuilder(driver) @pytest.fixture @@ -181,7 +191,7 @@ def graph_schema( valid_relationship_types: Tuple[RelationshipType, ...], valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> GraphSchema: - return schema_builder.create_schema_model( + return schema_builder._create_schema_model( list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) @@ -192,7 +202,7 @@ def test_create_schema_model_valid_data( valid_relationship_types: Tuple[RelationshipType, ...], valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: - schema = schema_builder.create_schema_model( + schema = schema_builder._create_schema_model( list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) @@ -213,15 +223,18 @@ async def test_run_method( ) -> None: with patch.object( schema_builder, - "create_schema_model", + "_create_schema_model", return_value=GraphSchema( node_types=valid_node_types, relationship_types=valid_relationship_types, patterns=valid_patterns, ), ): + # Call with strings instead of NodeType objects schema = await schema_builder.run( - list(valid_node_types), list(valid_relationship_types), list(valid_patterns) + node_types=["PERSON", "ORGANIZATION", "AGE"], + relationship_types=["EMPLOYED_BY", "ORGANIZED_BY", "ATTENDED_BY"], + patterns=valid_patterns ) assert schema.node_types == valid_node_types @@ -239,7 +252,7 @@ def 
test_create_schema_model_invalid_entity( patterns_with_invalid_entity: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: - schema_builder.create_schema_model( + schema_builder._create_schema_model( list(valid_node_types), list(valid_relationship_types), list(patterns_with_invalid_entity), @@ -256,7 +269,7 @@ def test_create_schema_model_invalid_relation( patterns_with_invalid_relation: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: - schema_builder.create_schema_model( + schema_builder._create_schema_model( list(valid_node_types), list(valid_relationship_types), list(patterns_with_invalid_relation), @@ -271,7 +284,7 @@ def test_create_schema_model_no_potential_schema( valid_node_types: Tuple[NodeType, ...], valid_relationship_types: Tuple[RelationshipType, ...], ) -> None: - schema_instance = schema_builder.create_schema_model( + schema_instance = schema_builder._create_schema_model( list(valid_node_types), list(valid_relationship_types) ) assert schema_instance.node_types == valid_node_types @@ -283,7 +296,7 @@ def test_create_schema_model_no_relations_or_potential_schema( schema_builder: SchemaBuilder, valid_node_types: Tuple[NodeType, ...], ) -> None: - schema_instance = schema_builder.create_schema_model(list(valid_node_types)) + schema_instance = schema_builder._create_schema_model(list(valid_node_types)) assert len(schema_instance.node_types) == 3 person = schema_instance.node_type_from_label("PERSON") @@ -310,7 +323,7 @@ def test_create_schema_model_missing_relations( valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: - schema_builder.create_schema_model( + schema_builder._create_schema_model( node_types=valid_node_types, patterns=valid_patterns ) assert "Relationship types must also be provided when using patterns." 
in str( @@ -373,8 +386,10 @@ def invalid_schema_json() -> str: @pytest.fixture -def schema_from_text(mock_llm: AsyncMock) -> SchemaFromTextExtractor: - return SchemaFromTextExtractor(llm=mock_llm) +def schema_from_text( + driver: neo4j.Driver, mock_llm: AsyncMock +) -> SchemaFromTextExtractor: + return SchemaFromTextExtractor(driver, llm=mock_llm) @pytest.mark.asyncio @@ -434,18 +449,20 @@ async def test_schema_from_text_custom_template( # create SchemaFromTextExtractor with the custom template schema_from_text = SchemaFromTextExtractor( - llm=mock_llm, prompt_template=custom_template + driver=None, llm=mock_llm, prompt_template=custom_template ) # configure mock LLM to return valid JSON and capture the prompt that was sent to it mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json) - # run the schema extraction - await schema_from_text.run(text="Sample text") + # Mock constraint retrieval to avoid database access + with patch.object(schema_from_text, '_get_constraints_from_db', return_value=[]): + # run the schema extraction + await schema_from_text.run(text="Sample text") - # verify the custom prompt was passed to the LLM - prompt_sent_to_llm = mock_llm.ainvoke.call_args[0][0] - assert "This is a custom prompt with text" in prompt_sent_to_llm + # verify the custom prompt was passed to the LLM + prompt_sent_to_llm = mock_llm.ainvoke.call_args[0][0] + assert "This is a custom prompt with text" in prompt_sent_to_llm @pytest.mark.asyncio @@ -456,19 +473,21 @@ async def test_schema_from_text_llm_params( llm_params = {"temperature": 0.1, "max_tokens": 500} # create SchemaFromTextExtractor with custom LLM parameters - schema_from_text = SchemaFromTextExtractor(llm=mock_llm, llm_params=llm_params) + schema_from_text = SchemaFromTextExtractor(driver=None, llm=mock_llm, llm_params=llm_params) # configure the mock LLM to return a valid schema JSON mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json) - # run the schema extraction - await schema_from_text.run(text="Sample text") + # Mock constraint retrieval to avoid database access + with patch.object(schema_from_text, '_get_constraints_from_db', return_value=[]): + # run the schema extraction + await schema_from_text.run(text="Sample text") - # verify the LLM was called with the custom parameters - mock_llm.ainvoke.assert_called_once() - call_kwargs = mock_llm.ainvoke.call_args[1] - assert call_kwargs["temperature"] == 0.1 - assert call_kwargs["max_tokens"] == 500 + # verify the LLM was called with the custom parameters + mock_llm.ainvoke.assert_called_once() + call_kwargs = mock_llm.ainvoke.call_args[1] + assert call_kwargs["temperature"] == 0.1 + assert call_kwargs["max_tokens"] == 500 @pytest.mark.asyncio @@ -607,3 +626,643 @@ async def test_schema_from_text_run_valid_json_array( assert schema.patterns is not None assert len(schema.patterns) == 1 assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +# ==================== CONFLICT DETECTION TESTS ==================== + +@pytest.fixture +def mock_constraints_missing_property() -> list[SchemaConstraint]: + """Mock constraints that reference properties not in user schema.""" + return [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["missing_property"], + ) + ] + + +@pytest.fixture +def mock_constraints_type_conflict() -> list[SchemaConstraint]: + """Mock constraints with type conflicts.""" + return [ + SchemaConstraint( + entity_type="NODE", + 
label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_TYPE, + properties=["name"], + property_type=[Neo4jPropertyType.INTEGER], # Conflicts with STRING + ) + ] + + +@pytest.fixture +def mock_constraints_required_conflict() -> list[SchemaConstraint]: + """Mock constraints requiring properties marked as optional by user.""" + return [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["optional_prop"], + ) + ] + + +@pytest.fixture +def mock_constraints_missing_entity() -> list[SchemaConstraint]: + """Mock constraints on entity types not in user schema.""" + return [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["UNKNOWN_LABEL"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["some_property"], + ) + ] + + +@pytest.fixture +def person_node_with_optional_prop() -> NodeType: + """Person node with an optional property that conflicts with DB requirements.""" + return NodeType( + label="PERSON", + properties=[ + PropertyType(name="name", type=Neo4jPropertyType.STRING, required=True), + PropertyType(name="optional_prop", type=Neo4jPropertyType.STRING, required=False), + ], + ) + + +@pytest.fixture +def person_node_additional_props_false() -> NodeType: + """Person node with additional_properties=False.""" + return NodeType( + label="PERSON", + properties=[ + PropertyType(name="name", type=Neo4jPropertyType.STRING, required=True), + ], + additional_properties=False, + ) + + +def test_missing_property_conflict( + schema_builder: SchemaBuilder, mock_constraints_missing_property: list[SchemaConstraint] +) -> None: + """Test that missing properties in user schema raise appropriate error.""" + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints_missing_property): + with pytest.raises(SchemaDatabaseConflictError, match="requires properties \\['missing_property'\\]"): + schema_builder._create_schema_model([ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type=Neo4jPropertyType.STRING) + ]) + ]) + + +def test_property_type_conflict( + schema_builder: SchemaBuilder, mock_constraints_type_conflict: list[SchemaConstraint] +) -> None: + """Test that property type conflicts raise appropriate error.""" + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints_type_conflict): + with pytest.raises(SchemaDatabaseConflictError, match="has type .* but database constraint allows only"): + schema_builder._create_schema_model([ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type=Neo4jPropertyType.STRING) # Conflicts with INTEGER + ]) + ]) + + +def test_required_property_conflict( + schema_builder: SchemaBuilder, + mock_constraints_required_conflict: list[SchemaConstraint], + person_node_with_optional_prop: NodeType +) -> None: + """Test that optional properties conflicting with DB existence constraints are enhanced, not errors.""" + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints_required_conflict): + # This should not raise an exception - we enhance instead of error + result = schema_builder._create_schema_model([person_node_with_optional_prop]) + + # Property should be enhanced to required=True + optional_prop = None + for prop in result.node_types[0].properties: + if prop.name == "optional_prop": + optional_prop = prop + break + + assert optional_prop is not None + assert optional_prop.required is True # Should be enhanced 
+ + +def test_missing_entity_type_conflict( + schema_builder: SchemaBuilder, mock_constraints_missing_entity: list[SchemaConstraint] +) -> None: + """Test that missing entity types raise error when additional types are disabled.""" + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints_missing_entity): + with pytest.raises(SchemaDatabaseConflictError, match="Database has constraints on node labels"): + schema_builder._create_schema_model( + [NodeType(label="PERSON")], + additional_node_types=False + ) + + +def test_additional_properties_conflict( + schema_builder: SchemaBuilder, + person_node_additional_props_false: NodeType +) -> None: + """Test that additional_properties=False conflicts with DB-required properties.""" + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["db_required_prop"], + ) + ] + + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints): + with pytest.raises(SchemaDatabaseConflictError, match="has additional_properties=False but database.*require"): + schema_builder._create_schema_model([person_node_additional_props_false]) + + +def test_no_conflict_with_compatible_schema( + schema_builder: SchemaBuilder +) -> None: + """Test that compatible schema and constraints work without errors.""" + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["name"], + ) + ] + + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints): + # This should not raise any exceptions + result = schema_builder._create_schema_model([ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type=Neo4jPropertyType.STRING, required=True) + ]) + ]) + + assert len(result.node_types) == 1 + assert result.node_types[0].label == "PERSON" + # Property should remain required=True + assert result.node_types[0].properties[0].required is True + + +def test_enhancement_sets_required_property( + schema_builder: SchemaBuilder +) -> None: + """Test that existence constraints properly set required=True on properties.""" + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["name"], + ) + ] + + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints): + result = schema_builder._create_schema_model([ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type=Neo4jPropertyType.STRING, required=False) + ]) + ]) + + # Property should be enhanced to required=True + assert result.node_types[0].properties[0].required is True + + +def test_compatible_property_types( + schema_builder: SchemaBuilder +) -> None: + """Test that compatible property types don't raise conflicts.""" + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_TYPE, + properties=["name"], + property_type=[Neo4jPropertyType.STRING, Neo4jPropertyType.INTEGER], # Union type + ) + ] + + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints): + # User specifies STRING, DB allows STRING|INTEGER - should be compatible + result = schema_builder._create_schema_model([ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", 
type=Neo4jPropertyType.STRING) + ]) + ]) + + assert len(result.node_types) == 1 + assert result.node_types[0].properties[0].type == Neo4jPropertyType.STRING + + +def test_missing_entity_allowed_with_additional_types( + schema_builder: SchemaBuilder, mock_constraints_missing_entity: list[SchemaConstraint] +) -> None: + """Test that missing entity types are allowed when additional_*_types=True.""" + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints_missing_entity): + # This should not raise an exception because additional_node_types=True by default + result = schema_builder._create_schema_model([ + NodeType(label="PERSON") + ]) + + assert len(result.node_types) == 1 + assert result.node_types[0].label == "PERSON" + + +def test_relationship_constraint_conflicts( + schema_builder: SchemaBuilder +) -> None: + """Test conflict detection for relationship constraints.""" + mock_constraints = [ + SchemaConstraint( + entity_type="RELATIONSHIP", + label_or_type=["KNOWS"], + type=Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE, + properties=["missing_rel_prop"], + ) + ] + + with patch.object(schema_builder, '_get_constraints_from_db', return_value=mock_constraints): + with pytest.raises(SchemaDatabaseConflictError, match="requires properties \\['missing_rel_prop'\\]"): + schema_builder._create_schema_model( + [NodeType(label="PERSON")], + [RelationshipType(label="KNOWS", properties=[ + PropertyType(name="since", type=Neo4jPropertyType.DATE) + ])] + ) + + +# ==================== SCHEMA FROM TEXT EXTRACTOR ENHANCEMENT TESTS ==================== + +@pytest.fixture +def schema_from_text_extractor(driver: neo4j.Driver) -> SchemaFromTextExtractor: + """Fixture providing a SchemaFromTextExtractor instance.""" + from neo4j_graphrag.llm.base import LLMInterface + from neo4j_graphrag.llm.types import LLMResponse + + class MockLLM(LLMInterface): + def __init__(self, response_content: str): + super().__init__(model_name="mock-model") + self.response_content = response_content + + async def ainvoke(self, input_: str, **kwargs: Any) -> LLMResponse: + return LLMResponse(content=self.response_content) + + def invoke(self, input_: str, **kwargs: Any) -> LLMResponse: + return LLMResponse(content=self.response_content) + + # Mock LLM that returns a basic schema + mock_llm = MockLLM('{"node_types": [{"label": "Person", "properties": [{"name": "name", "type": "STRING"}]}], "relationship_types": [], "patterns": []}') + + return SchemaFromTextExtractor( + driver=driver, + llm=mock_llm + ) + + +def test_schema_enhancement_adds_missing_properties( + schema_from_text_extractor: SchemaFromTextExtractor, + mock_constraints_missing_property: list[SchemaConstraint] +) -> None: + """Test that enhancement adds missing properties required by constraints.""" + # Create a basic schema missing the required property + initial_schema = GraphSchema( + node_types=[ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type="STRING") + ]) + ], + relationship_types=[], + patterns=[] + ) + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=mock_constraints_missing_property): + enhanced_schema = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + + # Check that the missing property was added + person_node = enhanced_schema.node_type_from_label("PERSON") + missing_prop = person_node.get_property_by_name("missing_property") + + assert missing_prop is not None + assert missing_prop.required == True + 
assert "constraint" in missing_prop.description.lower() + + +def test_schema_enhancement_adds_missing_entity_types( + schema_from_text_extractor: SchemaFromTextExtractor +) -> None: + """Test that enhancement adds missing entity types required by constraints.""" + # Mock constraints requiring an entity type not in the schema + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["ORGANIZATION"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["name"], + ) + ] + + # Schema without ORGANIZATION + initial_schema = GraphSchema( + node_types=[ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type="STRING") + ]) + ], + relationship_types=[], + patterns=[] + ) + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=mock_constraints): + enhanced_schema = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + + # Check that ORGANIZATION was added + org_node = enhanced_schema.node_type_from_label("ORGANIZATION") + assert org_node is not None + assert "constraint" in org_node.description.lower() + + # Check that required property was added + name_prop = org_node.get_property_by_name("name") + assert name_prop is not None + assert name_prop.required == True + + +def test_schema_enhancement_updates_property_types( + schema_from_text_extractor: SchemaFromTextExtractor +) -> None: + """Test that enhancement updates property types to match constraints.""" + # Mock type constraint + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_TYPE, + properties=["age"], + property_type=[Neo4jPropertyType.INTEGER] + ) + ] + + # Schema with wrong type + initial_schema = GraphSchema( + node_types=[ + NodeType(label="PERSON", properties=[ + PropertyType(name="age", type="STRING") # Wrong type + ]) + ], + relationship_types=[], + patterns=[] + ) + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=mock_constraints): + enhanced_schema = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + + # Check that property type was updated + person_node = enhanced_schema.node_type_from_label("PERSON") + age_prop = person_node.get_property_by_name("age") + assert age_prop.type == Neo4jPropertyType.INTEGER + + +def test_schema_enhancement_sets_required_properties( + schema_from_text_extractor: SchemaFromTextExtractor +) -> None: + """Test that enhancement sets required=True for existence constraints.""" + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["email"], + ) + ] + + # Schema with optional property + initial_schema = GraphSchema( + node_types=[ + NodeType(label="PERSON", properties=[ + PropertyType(name="email", type="STRING", required=False) + ]) + ], + relationship_types=[], + patterns=[] + ) + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=mock_constraints): + enhanced_schema = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + + # Check that property was made required + person_node = enhanced_schema.node_type_from_label("PERSON") + email_prop = person_node.get_property_by_name("email") + assert email_prop.required == True + + +def test_schema_enhancement_handles_relationship_constraints( + schema_from_text_extractor: 
SchemaFromTextExtractor +) -> None: + """Test that enhancement works for relationship constraints.""" + mock_constraints = [ + SchemaConstraint( + entity_type="RELATIONSHIP", + label_or_type=["KNOWS"], + type=Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE, + properties=["since"], + ) + ] + + # Schema without the relationship property + initial_schema = GraphSchema( + node_types=[], + relationship_types=[ + RelationshipType(label="KNOWS", properties=[]) + ], + patterns=[] + ) + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=mock_constraints): + enhanced_schema = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + + # Check that property was added to relationship + knows_rel = enhanced_schema.relationship_type_from_label("KNOWS") + since_prop = knows_rel.get_property_by_name("since") + assert since_prop is not None + assert since_prop.required == True + + +def test_schema_enhancement_respects_additional_properties_false( + schema_from_text_extractor: SchemaFromTextExtractor +) -> None: + """Test that enhancement respects additional_properties=False.""" + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["missing_prop"], + ) + ] + + # Schema with additional_properties=False + initial_schema = GraphSchema( + node_types=[ + NodeType( + label="PERSON", + properties=[PropertyType(name="name", type="STRING")], + additional_properties=False + ) + ], + relationship_types=[], + patterns=[] + ) + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=mock_constraints): + enhanced_schema = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + + # Check that property was NOT added due to additional_properties=False + person_node = enhanced_schema.node_type_from_label("PERSON") + missing_prop = person_node.get_property_by_name("missing_prop") + assert missing_prop is None + + +def test_schema_enhancement_handles_no_constraints( + schema_from_text_extractor: SchemaFromTextExtractor +) -> None: + """Test that enhancement returns original schema when no constraints exist.""" + initial_schema = GraphSchema( + node_types=[ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type="STRING") + ]) + ], + relationship_types=[], + patterns=[] + ) + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=[]): + enhanced_schema = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + + # Should return the same schema + assert enhanced_schema.model_dump() == initial_schema.model_dump() + + +@pytest.mark.asyncio +async def test_schema_from_text_extractor_run_with_enhancement( + driver: neo4j.Driver +) -> None: + """Test that the run method applies enhancement to LLM-generated schema.""" + from neo4j_graphrag.llm.base import LLMInterface + from neo4j_graphrag.llm.types import LLMResponse + + class MockLLM(LLMInterface): + def __init__(self): + super().__init__(model_name="mock-model") + + async def ainvoke(self, input_: str, **kwargs: Any) -> LLMResponse: + # Return a schema missing properties that will be required by constraints + return LLMResponse(content='{"node_types": [{"label": "PERSON", "properties": [{"name": "name", "type": "STRING"}]}], "relationship_types": [], "patterns": []}') + + def invoke(self, input_: str, **kwargs: Any) -> 
LLMResponse: + return LLMResponse(content='{"node_types": [{"label": "PERSON", "properties": [{"name": "name", "type": "STRING"}]}], "relationship_types": [], "patterns": []}') + + # Mock constraints that require additional properties + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["email"], + ) + ] + + extractor = SchemaFromTextExtractor( + driver=driver, + llm=MockLLM() + ) + + with patch.object(extractor, '_get_constraints_from_db', return_value=mock_constraints): + enhanced_schema = await extractor.run("Some text about persons") + + # Check that the schema was enhanced with the missing property + person_node = enhanced_schema.node_type_from_label("PERSON") + email_prop = person_node.get_property_by_name("email") + assert email_prop is not None + assert email_prop.required == True + assert "constraint" in email_prop.description.lower() + + +def test_schema_enhancement_graceful_failure( + schema_from_text_extractor: SchemaFromTextExtractor +) -> None: + """Test that enhancement fails gracefully and returns original schema.""" + initial_schema = GraphSchema( + node_types=[ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type="STRING") + ]) + ], + relationship_types=[], + patterns=[] + ) + + # Mock constraints that will cause validation error + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["problematic_prop"], + ) + ] + + with patch.object(schema_from_text_extractor, '_get_constraints_from_db', return_value=mock_constraints): + with patch('neo4j_graphrag.experimental.components.schema.GraphSchema.model_validate', side_effect=ValidationError.from_exception_data("GraphSchema", [])): + # Should return original schema on validation failure + result = schema_from_text_extractor._process_constraints_against_schema(initial_schema, mode="enhance") + assert result.model_dump() == initial_schema.model_dump() + + +def test_schema_builder_enhancement_mode_flag(driver: neo4j.Driver) -> None: + """Test that SchemaBuilder enhancement_mode flag switches between validation and enhancement.""" + # Test validation mode (default) + validation_builder = SchemaBuilder(driver, enhancement_mode=False) + + mock_constraints = [ + SchemaConstraint( + entity_type="NODE", + label_or_type=["PERSON"], + type=Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE, + properties=["missing_property"], + ) + ] + + with patch.object(validation_builder, '_get_constraints_from_db', return_value=mock_constraints): + # Should raise error in validation mode + with pytest.raises(SchemaDatabaseConflictError, match="requires properties \\['missing_property'\\]"): + validation_builder._create_schema_model([ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type="STRING") + ]) + ]) + + # Test enhancement mode + enhancement_builder = SchemaBuilder(driver, enhancement_mode=True) + + with patch.object(enhancement_builder, '_get_constraints_from_db', return_value=mock_constraints): + # Should enhance schema instead of raising error + result = enhancement_builder._create_schema_model([ + NodeType(label="PERSON", properties=[ + PropertyType(name="name", type="STRING") + ]) + ]) + + # Check that missing property was added + person_node = result.node_type_from_label("PERSON") + missing_prop = person_node.get_property_by_name("missing_property") + assert missing_prop is not None + assert 
diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
index 766469048..35904bb83 100644
--- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
+++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
@@ -27,6 +27,8 @@ from neo4j_graphrag.experimental.components.schema import (
     SchemaBuilder,
     SchemaFromTextExtractor,
+)
+from neo4j_graphrag.experimental.components.types import (
     GraphSchema,
 )
 from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (