diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2fc5bcd9..378ec3c84 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,8 +12,17 @@
 
 ### Changed
 
+#### Strict mode
+
 - Strict mode in `SimpleKGPipeline`: now properties and relationships are pruned only if they are defined in the input schema.
 
+#### Schema definition
+
+- The `SchemaEntity` model has been renamed `NodeType`.
+- The `SchemaRelation` model has been renamed `RelationshipType`.
+- The `SchemaProperty` model has been renamed `PropertyType`.
+- `SchemaConfig` has been removed in favor of `GraphSchema` (used in the `SchemaBuilder` and `EntityRelationExtractor` classes). `entities`, `relations` and `potential_schema` fields have also been renamed `node_types`, `relationship_types` and `patterns` respectively.
+
 ## 1.7.0
 
diff --git a/README.md b/README.md
index 60677256c..85726c0aa 100644
--- a/README.md
+++ b/README.md
@@ -102,9 +102,9 @@ NEO4J_PASSWORD = "password"
 driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
 
 # List the entities and relations the LLM should look for in the text
-entities = ["Person", "House", "Planet"]
-relations = ["PARENT_OF", "HEIR_OF", "RULES"]
-potential_schema = [
+node_types = ["Person", "House", "Planet"]
+relationship_types = ["PARENT_OF", "HEIR_OF", "RULES"]
+patterns = [
     ("Person", "PARENT_OF", "Person"),
     ("Person", "HEIR_OF", "House"),
     ("House", "RULES", "Planet"),
@@ -128,8 +128,11 @@ kg_builder = SimpleKGPipeline(
     llm=llm,
     driver=driver,
     embedder=embedder,
-    entities=entities,
-    relations=relations,
+    schema={
+        "node_types": node_types,
+        "relationship_types": relationship_types,
+        "patterns": patterns,
+    },
     on_error="IGNORE",
     from_pdf=False,
 )
@@ -365,7 +368,7 @@ When you're finished with your changes, create a pull request (PR) using the fol
 
 ## 🧪 Tests
 
-To be able to run all tests, all extra packages needs to be installed.
+To be able to run all tests, all extra packages need to be installed.
This is achieved by: ```bash diff --git a/docs/source/types.rst b/docs/source/types.rst index 73afcffb7..267e310d3 100644 --- a/docs/source/types.rst +++ b/docs/source/types.rst @@ -75,25 +75,25 @@ KGWriterModel .. autoclass:: neo4j_graphrag.experimental.components.kg_writer.KGWriterModel -SchemaProperty -============== +PropertyType +============ -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaProperty +.. autoclass:: neo4j_graphrag.experimental.components.schema.PropertyType -SchemaEntity -============ +NodeType +======== -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaEntity +.. autoclass:: neo4j_graphrag.experimental.components.schema.NodeType -SchemaRelation -============== +RelationshipType +================ -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaRelation +.. autoclass:: neo4j_graphrag.experimental.components.schema.RelationshipType -SchemaConfig -============ +GraphSchema +=========== -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaConfig +.. autoclass:: neo4j_graphrag.experimental.components.schema.GraphSchema LexicalGraphConfig =================== diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 3d02b47af..d7455a6f8 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -21,7 +21,7 @@ A Knowledge Graph (KG) construction pipeline requires a few components (some of - **Data loader**: extract text from files (PDFs, ...). - **Text splitter**: split the text into smaller pieces of text (chunks), manageable by the LLM context window (token limit). - **Chunk embedder** (optional): compute the chunk embeddings. -- **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG. Schema can be provided manually or extracted automatically using LLMs. 
+- **Schema builder**: provide a schema to ground the LLM extracted node and relationship types and obtain an easily navigable KG. Schema can be provided manually or extracted automatically using LLMs. - **Lexical graph builder**: build the lexical graph (Document, Chunk and their relationships) (optional). - **Entity and relation extractor**: extract relevant entities and relations from the text. - **Knowledge Graph writer**: save the identified entities and relations. @@ -73,18 +73,18 @@ Customizing the SimpleKGPipeline Graph Schema ------------ -It is possible to guide the LLM by supplying a list of entities, relationships, -and instructions on how to connect them. However, note that the extracted graph -may not fully adhere to these guidelines unless schema enforcement is enabled -(see :ref:`Schema Enforcement Behaviour`). Entities and relationships can be represented +It is possible to guide the LLM by supplying a list of node and relationship types, +and instructions on how to connect them (patterns). However, note that the extracted graph +may not fully adhere to these guidelines unless schema enforcement is enabled +(see :ref:`Schema Enforcement Behaviour`). Node and relationship types can be represented as either simple strings (for their labels) or dictionaries. If using a dictionary, it must include a label key and can optionally include description and properties keys, as shown below: .. code:: python - ENTITIES = [ - # entities can be defined with a simple label... + NODE_TYPES = [ + # node types can be defined with a simple label... "Person", # ... 
or with a dict if more details are needed, # such as a description: @@ -93,7 +93,7 @@ as shown below: {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, ] # same thing for relationships: - RELATIONS = [ + RELATIONSHIP_TYPES = [ "PARENT_OF", { "label": "HEIR_OF", @@ -102,13 +102,13 @@ as shown below: {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]}, ] -The `potential_schema` is defined by a list of triplet in the format: +The `patterns` are defined by a list of triplet in the format: `(source_node_label, relationship_label, target_node_label)`. For instance: .. code:: python - POTENTIAL_SCHEMA = [ + PATTERNS = [ ("Person", "PARENT_OF", "Person"), ("Person", "HEIR_OF", "House"), ("House", "RULES", "Planet"), @@ -122,15 +122,15 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated kg_builder = SimpleKGPipeline( # ... schema={ - "entities": ENTITIES, - "relations": RELATIONS, - "potential_schema": POTENTIAL_SCHEMA + "node_types": NODE_TYPES, + "relationship_types": RELATIONSHIP_TYPES, + "patterns": PATTERNS }, # ... ) .. note:: - By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (See the :ref:`Automatic Schema Extraction with SchemaFromTextExtractor`). + By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (See the :ref:`Automatic Schema Extraction`). Extra configurations -------------------- @@ -419,9 +419,8 @@ within the configuration file. "neo4j_database": "myDb", "on_error": "IGNORE", "prompt_template": "...", - "schema": { - "entities": [ + "node_types": [ "Person", { "label": "House", @@ -438,7 +437,7 @@ within the configuration file. ] } ], - "relations": [ + "relationship_types": [ "PARENT_OF", { "label": "HEIR_OF", @@ -451,7 +450,7 @@ within the configuration file. 
] } ], - "potential_schema": [ + "patterns": [ ["Person", "PARENT_OF", "Person"], ["Person", "HEIR_OF", "House"], ["House", "RULES", "Planet"] @@ -473,7 +472,7 @@ or in YAML: on_error: IGNORE prompt_template: ... schema: - entities: + node_types: - Person - label: House description: Family the person belongs to @@ -486,7 +485,7 @@ or in YAML: type: STRING - name: weather type: STRING - relations: + relationship_types: - PARENT_OF - label: HEIR_OF description: Used for inheritor relationship between father and sons @@ -494,7 +493,7 @@ or in YAML: properties: - name: fromYear type: INTEGER - potential_schema: + patterns: - ["Person", "PARENT_OF", "Person"] - ["Person", "HEIR_OF", "House"] - ["House", "RULES", "Planet"] @@ -747,12 +746,12 @@ Optionally, the document and chunk node labels can be configured using a `Lexica Schema Builder ============== -The schema is used to try and ground the LLM to a list of possible entities and relations of interest. +The schema is used to try and ground the LLM to a list of possible node and relationship types of interest. So far, schema must be manually created by specifying: -- **Entities** the LLM should look for in the text, including their properties (name and type). -- **Relations** of interest between these entities, including the relation properties (name and type). -- **Triplets** to define the start (source) and end (target) entity types for each relation. +- **Node types** the LLM should look for in the text, including their properties (name and type). +- **Relationship types** of interest between these node types, including the relationship properties (name and type). +- **Patterns** (triplets) to define the start (source) and end (target) entity types for each relationship. 
 Here is a code block illustrating these concepts:
 
@@ -760,49 +759,49 @@ Here is a code block illustrating these concepts:
 
     from neo4j_graphrag.experimental.components.schema import (
         SchemaBuilder,
-        SchemaEntity,
-        SchemaProperty,
-        SchemaRelation,
+        NodeType,
+        PropertyType,
+        RelationshipType,
     )
 
     schema_builder = SchemaBuilder()
 
     await schema_builder.run(
-        entities=[
-            SchemaEntity(
+        node_types=[
+            NodeType(
                 label="Person",
                 properties=[
-                    SchemaProperty(name="name", type="STRING"),
-                    SchemaProperty(name="place_of_birth", type="STRING"),
-                    SchemaProperty(name="date_of_birth", type="DATE"),
+                    PropertyType(name="name", type="STRING"),
+                    PropertyType(name="place_of_birth", type="STRING"),
+                    PropertyType(name="date_of_birth", type="DATE"),
                 ],
             ),
-            SchemaEntity(
+            NodeType(
                 label="Organization",
                 properties=[
-                    SchemaProperty(name="name", type="STRING"),
-                    SchemaProperty(name="country", type="STRING"),
+                    PropertyType(name="name", type="STRING"),
+                    PropertyType(name="country", type="STRING"),
                 ],
             ),
         ],
-        relations=[
-            SchemaRelation(
+        relationship_types=[
+            RelationshipType(
                 label="WORKED_ON",
             ),
-            SchemaRelation(
+            RelationshipType(
                 label="WORKED_FOR",
             ),
         ],
-        possible_schema=[
+        patterns=[
             ("Person", "WORKED_ON", "Field"),
             ("Person", "WORKED_FOR", "Organization"),
         ],
     )
 
-After validation, this schema is saved in a `SchemaConfig` object, whose dict representation is passed
+After validation, this schema is saved in a `GraphSchema` object, whose dict representation is passed
 to the LLM.
 
-Automatic Schema Extraction 
+Automatic Schema Extraction
 ---------------------------
 
 Instead of manually defining the schema, you can use the `SchemaFromTextExtractor` component to automatically extract a schema from your text using an LLM:
@@ -826,19 +825,19 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto
 
     # Extract the schema from the text
    extracted_schema = await schema_extractor.run(text="Some text")
 
-The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `SchemaConfig` object that can be used in the same way as a manually defined schema.
+The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema. You can also save and reload the extracted schema: .. code:: python # Save the schema to JSON or YAML files - schema_config.store_as_json("my_schema.json") - schema_config.store_as_yaml("my_schema.yaml") - + extracted_schema.store_as_json("my_schema.json") + extracted_schema.store_as_yaml("my_schema.yaml") + # Later, reload the schema from file - from neo4j_graphrag.experimental.components.schema import SchemaConfig - restored_schema = SchemaConfig.from_file("my_schema.json") # or my_schema.yaml + from neo4j_graphrag.experimental.components.schema import GraphSchema + restored_schema = GraphSchema.from_file("my_schema.json") # or my_schema.yaml Entity and Relation Extractor @@ -993,7 +992,6 @@ If more customization is needed, it is possible to subclass the `EntityRelationE from pydantic import validate_call from neo4j_graphrag.experimental.components.entity_relation_extractor import EntityRelationExtractor - from neo4j_graphrag.experimental.components.schema import SchemaConfig from neo4j_graphrag.experimental.components.types import ( Neo4jGraph, Neo4jNode, diff --git a/examples/build_graph/simple_kg_builder_from_pdf.py b/examples/build_graph/simple_kg_builder_from_pdf.py index f7ad683da..2cfc85134 100644 --- a/examples/build_graph/simple_kg_builder_from_pdf.py +++ b/examples/build_graph/simple_kg_builder_from_pdf.py @@ -27,11 +27,11 @@ file_path = root_dir / "data" / "Harry Potter and the Chamber of Secrets Summary.pdf" -# Instantiate Entity and Relation objects. This defines the +# Instantiate NodeType and RelationshipType objects. This defines the # entities and relations the LLM will be looking for in the text. 
-ENTITIES = ["Person", "Organization", "Location"] -RELATIONS = ["SITUATED_AT", "INTERACTS", "LED_BY"] -POTENTIAL_SCHEMA = [ +NODE_TYPES = ["Person", "Organization", "Location"] +RELATIONSHIP_TYPES = ["SITUATED_AT", "INTERACTS", "LED_BY"] +PATTERNS = [ ("Person", "SITUATED_AT", "Location"), ("Person", "INTERACTS", "Person"), ("Organization", "LED_BY", "Person"), @@ -47,9 +47,11 @@ async def define_and_run_pipeline( llm=llm, driver=neo4j_driver, embedder=OpenAIEmbeddings(), - entities=ENTITIES, - relations=RELATIONS, - potential_schema=POTENTIAL_SCHEMA, + schema={ + "node_types": NODE_TYPES, + "relationship_types": RELATIONSHIP_TYPES, + "patterns": PATTERNS, + }, neo4j_database=DATABASE, ) return await kg_builder.run_async(file_path=str(file_path)) diff --git a/examples/build_graph/simple_kg_builder_from_text.py b/examples/build_graph/simple_kg_builder_from_text.py index 79b8c8791..548cbd9eb 100644 --- a/examples/build_graph/simple_kg_builder_from_text.py +++ b/examples/build_graph/simple_kg_builder_from_text.py @@ -37,7 +37,7 @@ # Instantiate Entity and Relation objects. This defines the # entities and relations the LLM will be looking for in the text. -ENTITIES: list[EntityInputType] = [ +NODE_TYPES: list[EntityInputType] = [ # entities can be defined with a simple label... "Person", # ... 
or with a dict if more details are needed, @@ -47,7 +47,7 @@ {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, ] # same thing for relationships: -RELATIONS: list[RelationInputType] = [ +RELATIONSHIP_TYPES: list[RelationInputType] = [ "PARENT_OF", { "label": "HEIR_OF", @@ -55,7 +55,7 @@ }, {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]}, ] -POTENTIAL_SCHEMA = [ +PATTERNS = [ ("Person", "PARENT_OF", "Person"), ("Person", "HEIR_OF", "House"), ("House", "RULES", "Planet"), @@ -71,9 +71,11 @@ async def define_and_run_pipeline( llm=llm, driver=neo4j_driver, embedder=OpenAIEmbeddings(), - entities=ENTITIES, - relations=RELATIONS, - potential_schema=POTENTIAL_SCHEMA, + schema={ + "node_types": NODE_TYPES, + "relationship_types": RELATIONSHIP_TYPES, + "patterns": PATTERNS, + }, from_pdf=False, neo4j_database=DATABASE, ) diff --git a/examples/customize/build_graph/components/schema_builders/schema.py b/examples/customize/build_graph/components/schema_builders/schema.py index 6333fdd97..6ca408dee 100644 --- a/examples/customize/build_graph/components/schema_builders/schema.py +++ b/examples/customize/build_graph/components/schema_builders/schema.py @@ -14,43 +14,44 @@ # limitations under the License. 
from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) async def main() -> None: schema_builder = SchemaBuilder() - await schema_builder.run( - entities=[ - SchemaEntity( + result = await schema_builder.run( + node_types=[ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), ], - relations=[ - SchemaRelation( + relationship_types=[ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - potential_schema=[ + patterns=[ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], ) + print(result) diff --git a/examples/customize/build_graph/components/schema_builders/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py index 4396f3fb6..a36ef90ec 100644 --- a/examples/customize/build_graph/components/schema_builders/schema_from_text.py +++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py @@ -14,7 +14,7 @@ from neo4j_graphrag.experimental.components.schema import ( SchemaFromTextExtractor, - SchemaConfig, + GraphSchema, ) from neo4j_graphrag.llm import OpenAILLM @@ -27,25 +27,25 @@ # Sample text to extract schema from - it's about a company and its employees TEXT = """ -Acme Corporation was founded in 1985 by John Smith in New York City. 
-The company specializes in manufacturing high-quality widgets and gadgets +Acme Corporation was founded in 1985 by John Smith in New York City. +The company specializes in manufacturing high-quality widgets and gadgets for the consumer electronics industry. -Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to -Engineering Director in 2015. She oversees a team of 12 engineers working on -next-generation products. Sarah holds a PhD in Electrical Engineering from MIT +Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to +Engineering Director in 2015. She oversees a team of 12 engineers working on +next-generation products. Sarah holds a PhD in Electrical Engineering from MIT and has filed 5 patents during her time at Acme. -The company expanded to international markets in 2012, opening offices in London, -Tokyo, and Berlin. Each office is managed by a regional director who reports +The company expanded to international markets in 2012, opening offices in London, +Tokyo, and Berlin. Each office is managed by a regional director who reports directly to the CEO, Michael Brown, who took over leadership in 2008. -Acme's most successful product, the SuperWidget X1, was launched in 2018 and -has sold over 2 million units worldwide. The product was developed by a team led +Acme's most successful product, the SuperWidget X1, was launched in 2018 and +has sold over 2 million units worldwide. The product was developed by a team led by Robert Chen, who joined the company in 2016 after working at TechGiant for 8 years. -The company currently employs 250 people across its 4 locations and had a revenue -of $75 million in the last fiscal year. Acme is planning to go public in 2024 +The company currently employs 250 people across its 4 locations and had a revenue +of $75 million in the last fiscal year. Acme is planning to go public in 2024 with an estimated valuation of $500 million. 
""" @@ -92,14 +92,14 @@ async def extract_and_save_schema() -> None: inferred_schema.store_as_yaml(YAML_FILE_PATH) print("\nExtracted Schema Summary:") - print(f"Entities: {list(inferred_schema.entities.keys())}") + print(f"Node types: {list(inferred_schema.node_types)}") print( - f"Relations: {list(inferred_schema.relations.keys() if inferred_schema.relations else [])}" + f"Relationship types: {list(inferred_schema.relationship_types if inferred_schema.relationship_types else [])}" ) - if inferred_schema.potential_schema: - print("\nPotential Schema:") - for entity1, relation, entity2 in inferred_schema.potential_schema: + if inferred_schema.patterns: + print("\nPatterns:") + for entity1, relation, entity2 in inferred_schema.patterns: print(f" {entity1} --[{relation}]--> {entity2}") finally: @@ -119,11 +119,11 @@ async def main() -> None: # load schema from files print("\nLoading schemas from saved files:") - schema_from_json = SchemaConfig.from_file(JSON_FILE_PATH) - schema_from_yaml = SchemaConfig.from_file(YAML_FILE_PATH) + schema_from_json = GraphSchema.from_file(JSON_FILE_PATH) + schema_from_yaml = GraphSchema.from_file(YAML_FILE_PATH) - print(f"Entities in JSON schema: {list(schema_from_json.entities.keys())}") - print(f"Entities in YAML schema: {list(schema_from_yaml.entities.keys())}") + print(f"Node types in JSON schema: {list(schema_from_json.node_types)}") + print(f"Node types in YAML schema: {list(schema_from_yaml.node_types)}") if __name__ == "__main__": diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index ab11206da..ea727fe3c 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -25,8 +25,8 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - 
SchemaRelation, + NodeType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -45,35 +45,35 @@ async def define_and_run_pipeline( from neo4j_graphrag.experimental.pipeline import Pipeline # Instantiate Entity and Relation objects - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( + node_types = [ + NodeType(label="PERSON", description="An individual human being."), + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", ), - SchemaEntity(label="LOCATION", description="A location or place."), - SchemaEntity( + NodeType(label="LOCATION", description="A location or place."), + NodeType( label="HORCRUX", description="A magical item in the Harry Potter universe.", ), ] - relations = [ - SchemaRelation( + relationship_types = [ + RelationshipType( label="SITUATED_AT", description="Indicates the location of a person." ), - SchemaRelation( + RelationshipType( label="LED_BY", description="Indicates the leader of an organization.", ), - SchemaRelation( + RelationshipType( label="OWNS", description="Indicates the ownership of an item such as a Horcrux.", ), - SchemaRelation( + RelationshipType( label="INTERACTS", description="The interaction between two people." 
), ] - potential_schema = [ + patterns = [ ("PERSON", "SITUATED_AT", "LOCATION"), ("PERSON", "INTERACTS", "PERSON"), ("PERSON", "OWNS", "HORCRUX"), @@ -114,12 +114,12 @@ async def define_and_run_pipeline( pipe_inputs = { "pdf_loader": { - "filepath": "examples/pipeline/Harry Potter and the Death Hallows Summary.pdf" + "filepath": "examples/data/Harry Potter and the Death Hallows Summary.pdf" }, "schema": { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, }, } return await pipe.run(pipe_inputs) diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index 907a02825..3a9e30911 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -25,9 +25,9 @@ from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -95,38 +95,38 @@ async def define_and_run_pipeline( the University of Bern in Switzerland and the University of Oxford.""" }, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - 
SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Field", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], diff --git a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py index d6f5e9ae8..eda2b4219 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py +++ b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py @@ -28,9 +28,9 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -92,35 +92,35 @@ async def define_and_run_pipeline( pipe_inputs = { "loader": {}, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + 
PropertyType(name="country", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_FOR", ), - SchemaRelation( + RelationshipType( label="FRIEND", ), - SchemaRelation( + RelationshipType( label="ENEMY", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_FOR", "Organization"), ("Person", "FRIEND", "Person"), ("Person", "ENEMY", "Person"), @@ -129,8 +129,8 @@ async def define_and_run_pipeline( } # run the pipeline for each documents for document in [ - "examples/pipeline/Harry Potter and the Chamber of Secrets Summary.pdf", - "examples/pipeline/Harry Potter and the Death Hallows Summary.pdf", + "examples/data/Harry Potter and the Chamber of Secrets Summary.pdf", + "examples/data/Harry Potter and the Death Hallows Summary.pdf", ]: pipe_inputs["loader"]["filepath"] = document await pipe.run(pipe_inputs) diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py index 6867d9068..daaab51a5 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py @@ -16,9 +16,9 @@ from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -118,38 +118,38 @@ async def define_and_run_pipeline( }, }, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - 
SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Field", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py index 0fd354db4..b5b6b5273 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py @@ -18,9 +18,9 @@ from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -138,38 +138,38 @@ async def read_chunk_and_perform_entity_extraction( "lexical_graph_config": lexical_graph_config, }, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - 
SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Field", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], diff --git a/examples/data/extracted_schema.json b/examples/data/extracted_schema.json index 0ec66639b..ce36153d6 100644 --- a/examples/data/extracted_schema.json +++ b/examples/data/extracted_schema.json @@ -1,7 +1,7 @@ { - "entities": { - "Company": { - "label": "Company", + "node_types": [ + { + "label": "Person", "description": "", "properties": [ { @@ -10,24 +10,19 @@ "description": "" }, { - "name": "foundedYear", - "type": "INTEGER", - "description": "" - }, - { - "name": "revenue", - "type": "FLOAT", + "name": "position", + "type": "STRING", "description": "" }, { - "name": "valuation", - "type": "FLOAT", + "name": "startYear", + "type": "INTEGER", "description": "" } ] }, - "Person": { - "label": "Person", + { + "label": "Company", "description": "", "properties": [ { @@ -36,18 +31,23 @@ "description": "" }, { - "name": "position", - "type": "STRING", + "name": "foundedYear", + "type": "INTEGER", "description": "" }, { - "name": "yearJoined", - "type": "INTEGER", + "name": "revenue", + "type": "FLOAT", + 
"description": "" + }, + { + "name": "valuation", + "type": "FLOAT", "description": "" } ] }, - "Product": { + { "label": "Product", "description": "", "properties": [ @@ -68,7 +68,7 @@ } ] }, - "Office": { + { "label": "Office", "description": "", "properties": [ @@ -79,40 +79,30 @@ } ] } - }, - "relations": { - "FOUNDED_BY": { - "label": "FOUNDED_BY", - "description": "", - "properties": [] - }, - "WORKS_FOR": { + ], + "relationship_types": [ + { "label": "WORKS_FOR", "description": "", "properties": [] }, - "MANAGES": { + { "label": "MANAGES", "description": "", "properties": [] }, - "DEVELOPED_BY": { + { "label": "DEVELOPED_BY", "description": "", "properties": [] }, - "LOCATED_IN": { + { "label": "LOCATED_IN", "description": "", "properties": [] } - }, - "potential_schema": [ - [ - "Company", - "FOUNDED_BY", - "Person" - ], + ], + "patterns": [ [ "Person", "WORKS_FOR", diff --git a/examples/data/extracted_schema.yaml b/examples/data/extracted_schema.yaml index f2500799f..efdd8e733 100644 --- a/examples/data/extracted_schema.yaml +++ b/examples/data/extracted_schema.yaml @@ -1,78 +1,63 @@ -entities: - Company: - label: Company +node_types: +- label: Person + description: '' + properties: + - name: name + type: STRING description: '' - properties: - - name: name - type: STRING - description: '' - - name: foundedYear - type: INTEGER - description: '' - - name: revenue - type: FLOAT - description: '' - - name: valuation - type: FLOAT - description: '' - Person: - label: Person + - name: position + type: STRING description: '' - properties: - - name: name - type: STRING - description: '' - - name: position - type: STRING - description: '' - - name: yearJoined - type: INTEGER - description: '' - Product: - label: Product + - name: startYear + type: INTEGER description: '' - properties: - - name: name - type: STRING - description: '' - - name: launchYear - type: INTEGER - description: '' - - name: unitsSold - type: INTEGER - description: '' - Office: - label: Office 
+- label: Company + description: '' + properties: + - name: name + type: STRING description: '' - properties: - - name: location - type: STRING - description: '' -relations: - FOUNDED_BY: - label: FOUNDED_BY + - name: foundedYear + type: INTEGER description: '' - properties: [] - WORKS_FOR: - label: WORKS_FOR + - name: revenue + type: FLOAT description: '' - properties: [] - MANAGES: - label: MANAGES + - name: valuation + type: FLOAT description: '' - properties: [] - DEVELOPED_BY: - label: DEVELOPED_BY +- label: Product + description: '' + properties: + - name: name + type: STRING description: '' - properties: [] - LOCATED_IN: - label: LOCATED_IN + - name: launchYear + type: INTEGER description: '' - properties: [] -potential_schema: -- - Company - - FOUNDED_BY - - Person + - name: unitsSold + type: INTEGER + description: '' +- label: Office + description: '' + properties: + - name: location + type: STRING + description: '' +relationship_types: +- label: WORKS_FOR + description: '' + properties: [] +- label: MANAGES + description: '' + properties: [] +- label: DEVELOPED_BY + description: '' + properties: [] +- label: LOCATED_IN + description: '' + properties: [] +patterns: - - Person - WORKS_FOR - Company diff --git a/examples/kg_builder.py b/examples/kg_builder.py index 650473e41..c98f0c069 100644 --- a/examples/kg_builder.py +++ b/examples/kg_builder.py @@ -32,8 +32,8 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -50,35 +50,35 @@ async def define_and_run_pipeline( from neo4j_graphrag.experimental.pipeline import Pipeline # Instantiate Entity and Relation objects - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( + node_types = [ + 
NodeType(label="PERSON", description="An individual human being."), + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", ), - SchemaEntity(label="LOCATION", description="A location or place."), - SchemaEntity( + NodeType(label="LOCATION", description="A location or place."), + NodeType( label="HORCRUX", description="A magical item in the Harry Potter universe.", ), ] - relations = [ - SchemaRelation( + relationship_types = [ + RelationshipType( label="SITUATED_AT", description="Indicates the location of a person." ), - SchemaRelation( + RelationshipType( label="LED_BY", description="Indicates the leader of an organization.", ), - SchemaRelation( + RelationshipType( label="OWNS", description="Indicates the ownership of an item such as a Horcrux.", ), - SchemaRelation( + RelationshipType( label="INTERACTS", description="The interaction between two people." ), ] - potential_schema = [ + patterns = [ ("PERSON", "SITUATED_AT", "LOCATION"), ("PERSON", "INTERACTS", "PERSON"), ("PERSON", "OWNS", "HORCRUX"), @@ -121,9 +121,9 @@ async def define_and_run_pipeline( "filepath": "examples/pipeline/Harry Potter and the Death Hallows Summary.pdf" }, "schema": { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, }, } return await pipe.run(pipe_inputs) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index d041d78e4..1d29fefe2 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -25,7 +25,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder -from neo4j_graphrag.experimental.components.schema 
import SchemaConfig +from neo4j_graphrag.experimental.components.schema import GraphSchema, PropertyType from neo4j_graphrag.experimental.components.types import ( DocumentInfo, LexicalGraphConfig, @@ -209,7 +209,7 @@ def __init__( self.prompt_template = template async def extract_for_chunk( - self, schema: SchemaConfig, examples: str, chunk: TextChunk + self, schema: GraphSchema, examples: str, chunk: TextChunk ) -> Neo4jGraph: """Run entity extraction for a given text chunk.""" prompt = self.prompt_template.format( @@ -275,7 +275,7 @@ async def run_for_chunk( self, sem: asyncio.Semaphore, chunk: TextChunk, - schema: SchemaConfig, + schema: GraphSchema, examples: str, lexical_graph_builder: Optional[LexicalGraphBuilder] = None, ) -> Neo4jGraph: @@ -296,7 +296,7 @@ async def run( chunks: TextChunks, document_info: Optional[DocumentInfo] = None, lexical_graph_config: Optional[LexicalGraphConfig] = None, - schema: Union[SchemaConfig, None] = None, + schema: Union[GraphSchema, None] = None, examples: str = "", **kwargs: Any, ) -> Neo4jGraph: @@ -311,7 +311,7 @@ async def run( chunks (TextChunks): List of text chunks to extract entities and relations from. document_info (Optional[DocumentInfo], optional): Document the chunks are coming from. Used in the lexical graph creation step. lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph. - schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. + schema (GraphSchema | None): Definition of the schema to guide the LLM in its extraction. examples (str): Examples for few-shot learning in the prompt. 
""" lexical_graph_builder = None @@ -325,7 +325,9 @@ async def run( lexical_graph = lexical_graph_result.graph elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) - schema = schema or SchemaConfig(entities={}, relations={}, potential_schema=[]) + schema = schema or GraphSchema( + node_types=(), relationship_types=(), patterns=() + ) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) tasks = [ @@ -344,14 +346,14 @@ async def run( return graph def validate_chunk( - self, chunk_graph: Neo4jGraph, schema: SchemaConfig + self, chunk_graph: Neo4jGraph, schema: GraphSchema ) -> Neo4jGraph: """ Perform validation after entity and relation extraction: - Enforce schema if schema enforcement mode is on and schema is provided """ if self.enforce_schema != SchemaEnforcementMode.NONE: - if not schema or not schema.entities: # schema is not provided + if not schema or not schema.node_types: # schema is not provided logger.warning( "Schema enforcement is ON but the guiding schema is not provided." ) @@ -363,7 +365,7 @@ def validate_chunk( def _clean_graph( self, graph: Neo4jGraph, - schema: SchemaConfig, + schema: GraphSchema, ) -> Neo4jGraph: """ Verify that the graph conforms to the provided schema. @@ -385,7 +387,7 @@ def _clean_graph( return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels) def _enforce_nodes( - self, extracted_nodes: List[Neo4jNode], schema: SchemaConfig + self, extracted_nodes: List[Neo4jNode], schema: GraphSchema ) -> List[Neo4jNode]: """ Filter extracted nodes to be conformant to the schema. 
@@ -400,10 +402,10 @@ def _enforce_nodes( valid_nodes = [] for node in extracted_nodes: - schema_entity = schema.entities.get(node.label) + schema_entity = schema.node_type_from_label(node.label) if not schema_entity: continue - allowed_props = schema_entity.get("properties") + allowed_props = schema_entity.properties or [] if allowed_props: filtered_props = self._enforce_properties( node.properties, allowed_props @@ -426,7 +428,7 @@ def _enforce_relationships( self, extracted_relationships: List[Neo4jRelationship], filtered_nodes: List[Neo4jNode], - schema: SchemaConfig, + schema: GraphSchema, ) -> List[Neo4jRelationship]: """ Filter extracted nodes to be conformant to the schema. @@ -439,42 +441,47 @@ def _enforce_relationships( if self.enforce_schema != SchemaEnforcementMode.STRICT: return extracted_relationships - if schema.relations is None: + if schema.relationship_types is None: return extracted_relationships valid_rels = [] valid_nodes = {node.id: node.label for node in filtered_nodes} - potential_schema = schema.potential_schema + patterns = schema.patterns for rel in extracted_relationships: - schema_relation = schema.relations.get(rel.type) + schema_relation = schema.relationship_type_from_label(rel.type) if not schema_relation: + logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") continue if ( rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes ): + logger.debug( + f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" + ) continue start_label = valid_nodes[rel.start_node_id] end_label = valid_nodes[rel.end_node_id] tuple_valid = True - if potential_schema: - tuple_valid = (start_label, rel.type, end_label) in potential_schema + if patterns: + tuple_valid = (start_label, rel.type, end_label) in patterns reverse_tuple_valid = ( end_label, rel.type, start_label, - ) in potential_schema + ) in patterns if not tuple_valid and not reverse_tuple_valid: + logger.debug(f"PRUNING:: {rel} 
not in the potential schema") continue - allowed_props = schema_relation.get("properties") + allowed_props = schema_relation.properties or [] if allowed_props: filtered_props = self._enforce_properties(rel.properties, allowed_props) else: @@ -493,13 +500,13 @@ def _enforce_relationships( return valid_rels def _enforce_properties( - self, properties: Dict[str, Any], valid_properties: List[Dict[str, Any]] + self, properties: Dict[str, Any], valid_properties: List[PropertyType] ) -> Dict[str, Any]: """ Filter properties. Keep only those that exist in schema (i.e., valid properties). """ - valid_prop_names = {prop["name"] for prop in valid_properties} + valid_prop_names = {prop.name for prop in valid_properties} return { key: value for key, value in properties.items() if key in valid_prop_names } diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 2e0641c87..1136c6db3 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -17,10 +17,17 @@ import json import yaml import logging -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence from pathlib import Path -from pydantic import BaseModel, ValidationError, model_validator, validate_call +from pydantic import ( + BaseModel, + PrivateAttr, + ValidationError, + model_validator, + validate_call, + ConfigDict, +) from typing_extensions import Self from neo4j_graphrag.exceptions import ( @@ -37,7 +44,7 @@ from neo4j_graphrag.llm import LLMInterface -class SchemaProperty(BaseModel): +class PropertyType(BaseModel): """ Represents a property on a node or relationship in the graph. 
""" @@ -60,78 +67,97 @@ class SchemaProperty(BaseModel): ] description: str = "" + model_config = ConfigDict( + frozen=True, + ) + -class SchemaEntity(BaseModel): +class NodeType(BaseModel): """ Represents a possible node in the graph. """ label: str description: str = "" - properties: List[SchemaProperty] = [] + properties: list[PropertyType] = [] @classmethod def from_text_or_dict(cls, input: EntityInputType) -> Self: - if isinstance(input, SchemaEntity): + if isinstance(input, NodeType): return input if isinstance(input, str): return cls(label=input) return cls.model_validate(input) -class SchemaRelation(BaseModel): +class RelationshipType(BaseModel): """ Represents a possible relationship between nodes in the graph. """ label: str description: str = "" - properties: List[SchemaProperty] = [] + properties: list[PropertyType] = [] @classmethod def from_text_or_dict(cls, input: RelationInputType) -> Self: - if isinstance(input, SchemaRelation): + if isinstance(input, RelationshipType): return input if isinstance(input, str): return cls(label=input) return cls.model_validate(input) -class SchemaConfig(DataModel): - """ - Represents possible relationships between entities and relations in the graph. - """ +class GraphSchema(DataModel): + node_types: Tuple[NodeType, ...] 
+ relationship_types: Optional[Tuple[RelationshipType, ...]] = None + patterns: Optional[Tuple[Tuple[str, str, str], ...]] = None - entities: Dict[str, Dict[str, Any]] - relations: Optional[Dict[str, Dict[str, Any]]] - potential_schema: Optional[List[Tuple[str, str, str]]] + _node_type_index: dict[str, NodeType] = PrivateAttr() + _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() - @model_validator(mode="before") - def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]: - entities = data.get("entities", {}).keys() - relations = (data.get("relations") or {}).keys() - potential_schema = data.get("potential_schema", []) + model_config = ConfigDict( + frozen=True, + ) + + @model_validator(mode="after") + def check_schema(self) -> Self: + self._node_type_index = {node.label: node for node in self.node_types} + self._relationship_type_index = ( + {r.label: r for r in self.relationship_types} + if self.relationship_types + else {} + ) - if potential_schema: - if not relations: + relationship_types = self.relationship_types or tuple() + patterns = self.patterns or tuple() + + if patterns: + if not relationship_types: raise SchemaValidationError( "Relations must also be provided when using a potential schema." ) - for entity1, relation, entity2 in potential_schema: - if entity1 not in entities: + for entity1, relation, entity2 in patterns: + if entity1 not in self._node_type_index: raise SchemaValidationError( f"Entity '{entity1}' is not defined in the provided entities." ) - if relation not in relations: + if relation not in self._relationship_type_index: raise SchemaValidationError( f"Relation '{relation}' is not defined in the provided relations." ) - if entity2 not in entities: + if entity2 not in self._node_type_index: raise SchemaValidationError( f"Entity '{entity2}' is not defined in the provided entities." 
) - return data + return self + + def node_type_from_label(self, label: str) -> Optional[NodeType]: + return self._node_type_index.get(label) + + def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]: + return self._relationship_type_index.get(label) def store_as_json(self, file_path: str) -> None: """ @@ -152,8 +178,12 @@ def store_as_yaml(self, file_path: str) -> None: """ # create a copy of the data and convert tuples to lists for YAML compatibility data = self.model_dump() - if data.get("potential_schema"): - data["potential_schema"] = [list(item) for item in data["potential_schema"]] + if data.get("node_types"): + data["node_types"] = list(data["node_types"]) + if data.get("relationship_types"): + data["relationship_types"] = list(data["relationship_types"]) + if data.get("patterns"): + data["patterns"] = [list(item) for item in data["patterns"]] with open(file_path, "w") as f: yaml.dump(data, f, default_flow_style=False, sort_keys=False) @@ -169,7 +199,7 @@ def from_file(cls, file_path: Union[str, Path]) -> Self: file_path (Union[str, Path]): The path to the schema configuration file. Returns: - SchemaConfig: The loaded schema configuration. + GraphSchema: The loaded schema configuration. """ file_path = Path(file_path) @@ -194,7 +224,7 @@ def from_json(cls, file_path: Union[str, Path]) -> Self: file_path (Union[str, Path]): The path to the JSON schema configuration file. Returns: - SchemaConfig: The loaded schema configuration. + GraphSchema: The loaded schema configuration. """ with open(file_path, "r") as f: try: @@ -214,7 +244,7 @@ def from_yaml(cls, file_path: Union[str, Path]) -> Self: file_path (Union[str, Path]): The path to the YAML schema configuration file. Returns: - SchemaConfig: The loaded schema configuration. + GraphSchema: The loaded schema configuration. 
""" with open(file_path, "r") as f: try: @@ -228,7 +258,7 @@ def from_yaml(cls, file_path: Union[str, Path]) -> Self: class SchemaBuilder(Component): """ - A builder class for constructing SchemaConfig objects from given entities, + A builder class for constructing GraphSchema objects from given entities, relations, and their interrelationships defined in a potential schema. Example: @@ -237,38 +267,38 @@ class SchemaBuilder(Component): from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.pipeline import Pipeline - entities = [ - SchemaEntity( + node_types = [ + NodeType( label="PERSON", description="An individual human being.", properties=[ - SchemaProperty( + PropertyType( name="name", type="STRING", description="The name of the person" ) ], ), - SchemaEntity( + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", properties=[ - SchemaProperty( + PropertyType( name="name", type="STRING", description="The name of the organization" ) ], ), ] - relations = [ - SchemaRelation( + relationship_types = [ + RelationshipType( label="EMPLOYED_BY", description="Indicates employment relationship." ), ] - potential_schema = [ + patterns = [ ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ] pipe = Pipeline() @@ -276,9 +306,9 @@ class SchemaBuilder(Component): pipe.add_component(schema_builder, "schema_builder") pipe_inputs = { "schema": { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, }, ... 
} @@ -287,34 +317,29 @@ class SchemaBuilder(Component): @staticmethod def create_schema_model( - entities: List[SchemaEntity], - relations: Optional[List[SchemaRelation]] = None, - potential_schema: Optional[List[Tuple[str, str, str]]] = None, - ) -> SchemaConfig: + node_types: Sequence[NodeType], + relationship_types: Optional[Sequence[RelationshipType]] = None, + patterns: Optional[Sequence[Tuple[str, str, str]]] = None, + ) -> GraphSchema: """ - Creates a SchemaConfig object from Lists of Entity and Relation objects + Creates a GraphSchema object from Lists of Entity and Relation objects and a Dictionary defining potential relationships. Args: - entities (List[SchemaEntity]): List of Entity objects. - relations (List[SchemaRelation]): List of Relation objects. - potential_schema (Dict[str, List[str]]): Dictionary mapping entity names to Lists of relation names. + node_types (Sequence[NodeType]): List or tuple of NodeType objects. + relationship_types (Optional[Sequence[RelationshipType]]): List or tuple of RelationshipType objects. + patterns (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label). Returns: - SchemaConfig: A configured schema object. + GraphSchema: A configured schema object. 
""" - entity_dict = {entity.label: entity.model_dump() for entity in entities} - relation_dict = ( - {relation.label: relation.model_dump() for relation in relations} - if relations - else {} - ) - try: - return SchemaConfig( - entities=entity_dict, - relations=relation_dict, - potential_schema=potential_schema, + return GraphSchema.model_validate( + dict( + node_types=node_types, + relationship_types=relationship_types, + patterns=patterns, + ) ) except (ValidationError, SchemaValidationError) as e: raise SchemaValidationError(e) @@ -322,27 +347,27 @@ def create_schema_model( @validate_call async def run( self, - entities: List[SchemaEntity], - relations: Optional[List[SchemaRelation]] = None, - potential_schema: Optional[List[Tuple[str, str, str]]] = None, - ) -> SchemaConfig: + node_types: Sequence[NodeType], + relationship_types: Optional[Sequence[RelationshipType]] = None, + patterns: Optional[Sequence[Tuple[str, str, str]]] = None, + ) -> GraphSchema: """ - Asynchronously constructs and returns a SchemaConfig object. + Asynchronously constructs and returns a GraphSchema object. Args: - entities (List[SchemaEntity]): List of Entity objects. - relations (List[SchemaRelation]): List of Relation objects. - potential_schema (Dict[str, List[str]]): Dictionary mapping entity names to Lists of relation names. + node_types (Sequence[NodeType]): Sequence of NodeType objects. + relationship_types (Sequence[RelationshipType]): Sequence of RelationshipType objects. + patterns (Optional[Sequence[Tuple[str, str, str]]]): Sequence of triplets: (source_entity_label, relation_label, target_entity_label). Returns: - SchemaConfig: A configured schema object, constructed asynchronously. + GraphSchema: A configured schema object, constructed asynchronously. 
""" - return self.create_schema_model(entities, relations, potential_schema) + return self.create_schema_model(node_types, relationship_types, patterns) class SchemaFromTextExtractor(Component): """ - A component for constructing SchemaConfig objects from the output of an LLM after + A component for constructing GraphSchema objects from the output of an LLM after automatic schema extraction from text. """ @@ -359,15 +384,15 @@ def __init__( self._llm_params: dict[str, Any] = llm_params or {} @validate_call - async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfig: + async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema: """ - Asynchronously extracts the schema from text and returns a SchemaConfig object. + Asynchronously extracts the schema from text and returns a GraphSchema object. Args: text (str): the text from which the schema will be inferred. examples (str): examples to guide schema extraction. Returns: - SchemaConfig: A configured schema object, extracted automatically and + GraphSchema: A configured schema object, extracted automatically and constructed asynchronously. 
""" prompt: str = self._prompt_template.format(text=text, examples=examples) @@ -406,60 +431,20 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi except json.JSONDecodeError as exc: raise SchemaExtractionError("LLM response is not valid JSON.") from exc - extracted_entities: List[Dict[str, Any]] = ( - extracted_schema.get("entities") or [] - ) - extracted_relations: Optional[List[Dict[str, Any]]] = extracted_schema.get( - "relations" + extracted_node_types: List[Dict[str, Any]] = ( + extracted_schema.get("node_types") or [] ) - potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( - "potential_schema" + extracted_relationship_types: Optional[List[Dict[str, Any]]] = ( + extracted_schema.get("relationship_types") ) - - try: - entities: List[SchemaEntity] = [ - SchemaEntity(**e) for e in extracted_entities - ] - relations: Optional[List[SchemaRelation]] = ( - [SchemaRelation(**r) for r in extracted_relations] - if extracted_relations - else None - ) - except ValidationError as exc: - raise SchemaValidationError( - f"Invalid schema format return from LLM: {exc}" - ) from exc - - return SchemaBuilder.create_schema_model( - entities=entities, - relations=relations, - potential_schema=potential_schema, + extracted_patterns: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( + "patterns" ) - -def normalize_schema_dict(schema_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Normalize a user-provided schema dictionary to the canonical format expected by the pipeline. - - - Converts 'entities' and 'relations' from lists (of strings, dicts, or model objects) to dicts keyed by label. - - Ensures required keys ('entities', 'relations', 'potential_schema') are present. - - Does not mutate the input; returns a new dict. - - Args: - schema_dict (dict): The user-provided schema dictionary, possibly with lists or missing keys. 
- - Returns: - dict: A normalized schema dictionary with the correct structure for pipeline and Pydantic validation. - """ - norm_schema_dict = dict(schema_dict) - for key, cls in [("entities", SchemaEntity), ("relations", SchemaRelation)]: - if key in norm_schema_dict and isinstance(norm_schema_dict[key], list): - norm_schema_dict[key] = { - cls.from_text_or_dict(e).label: cls.from_text_or_dict(e).model_dump() # type: ignore[attr-defined] - for e in norm_schema_dict[key] + return GraphSchema.model_validate( + { + "node_types": extracted_node_types, + "relationship_types": extracted_relationship_types, + "patterns": extracted_patterns, } - if "relations" not in norm_schema_dict: - norm_schema_dict["relations"] = {} - if "potential_schema" not in norm_schema_dict: - norm_schema_dict["potential_schema"] = None - return norm_schema_dict + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 0df0f61e6..d8ea45fe4 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -19,10 +19,7 @@ Optional, Sequence, Union, - List, Tuple, - Dict, - cast, ) import logging import warnings @@ -44,11 +41,10 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaConfig, - SchemaEntity, - SchemaRelation, + GraphSchema, + NodeType, + RelationshipType, SchemaFromTextExtractor, - normalize_schema_dict, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -93,7 +89,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] 
= None - schema_: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = Field( + schema_: Optional[Union[GraphSchema, dict[str, list[Any]]]] = Field( default=None, alias="schema" ) enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE @@ -109,14 +105,6 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): model_config = ConfigDict(arbitrary_types_allowed=True) - @model_validator(mode="before") - def normalize_schema_field(cls, data: Dict[str, Any]) -> Dict[str, Any]: - # Normalize the 'schema' field if it is a dict - schema = data.get("schema") - if isinstance(schema, dict): - data["schema"] = normalize_schema_dict(schema) - return data - @model_validator(mode="after") def handle_schema_precedence(self) -> Self: """Handle schema precedence and warnings""" @@ -199,64 +187,58 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: def _process_schema_with_precedence( self, ) -> Tuple[ - List[SchemaEntity], List[SchemaRelation], Optional[List[Tuple[str, str, str]]] + Tuple[NodeType, ...], + Tuple[RelationshipType, ...], + Optional[Tuple[Tuple[str, str, str], ...]], ]: """ Process schema inputs according to precedence rules: - 1. If schema is provided as SchemaConfig object, use it + 1. If schema is provided as GraphSchema object, use it 2. If schema is provided as dictionary, extract from it 3. 
Otherwise, use individual schema components Returns: - Tuple of (entities, relations, potential_schema) + Tuple of (node_types, relationship_types, patterns) """ if self.schema_ is not None: # schema takes precedence over individual components - if isinstance(self.schema_, SchemaConfig): - # extract components from SchemaConfig - entity_dicts = list(self.schema_.entities.values()) - # convert dict values to SchemaEntity objects - entities = [SchemaEntity.model_validate(e) for e in entity_dicts] + if isinstance(self.schema_, GraphSchema): + # extract components from GraphSchema + node_types = self.schema_.node_types # handle case where relations could be None - if self.schema_.relations is not None: - relation_dicts = list(self.schema_.relations.values()) - relations = [ - SchemaRelation.model_validate(r) for r in relation_dicts - ] + if self.schema_.relationship_types is not None: + relationship_types = self.schema_.relationship_types else: - relations = [] + relationship_types = () - potential_schema = self.schema_.potential_schema + patterns = self.schema_.patterns else: - entities = [ - SchemaEntity.from_text_or_dict(e) - for e in cast( - Dict[str, Any], self.schema_.get("entities", {}) - ).values() - ] - relations = [ - SchemaRelation.from_text_or_dict(r) - for r in cast( - Dict[str, Any], self.schema_.get("relations", {}) - ).values() - ] - potential_schema = self.schema_.get("potential_schema") + node_types = tuple( + NodeType.from_text_or_dict(e) + for e in self.schema_.get("node_types", ()) + ) + relationship_types = tuple( + RelationshipType.from_text_or_dict(r) + for r in self.schema_.get("relationship_types", ()) + ) + ps = self.schema_.get("patterns") + patterns = tuple(ps) if ps else None else: # use individual components - entities = ( - [SchemaEntity.from_text_or_dict(e) for e in self.entities] + node_types = tuple( + [NodeType.from_text_or_dict(e) for e in self.entities] if self.entities else [] ) - relations = ( - 
[SchemaRelation.from_text_or_dict(r) for r in self.relations] + relationship_types = tuple( + [RelationshipType.from_text_or_dict(r) for r in self.relations] if self.relations else [] ) - potential_schema = self.potential_schema + patterns = tuple(self.potential_schema) if self.potential_schema else None - return entities, relations, potential_schema + return node_types, relationship_types, patterns def _get_run_params_for_schema(self) -> dict[str, Any]: if not self.has_user_provided_schema(): @@ -264,14 +246,14 @@ def _get_run_params_for_schema(self) -> dict[str, Any]: return {} else: # process schema components according to precedence rules - entities, relations, potential_schema = ( + node_types, relationship_types, patterns = ( self._process_schema_with_precedence() ) return { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, } def _get_extractor(self) -> EntityRelationExtractor: diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index c586a7fad..d46ddc046 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -43,7 +43,7 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.experimental.components.schema import SchemaConfig +from neo4j_graphrag.experimental.components.schema import GraphSchema logger = logging.getLogger(__name__) @@ -57,18 +57,18 @@ class SimpleKGPipeline: llm (LLMInterface): An instance of an LLM to use for entity and relation extraction. driver (neo4j.Driver): A Neo4j driver instance for database connection. embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks. 
- schema (Optional[Union[SchemaConfig, dict[str, list]]]): A schema configuration defining entities, + schema (Optional[Union[GraphSchema, dict[str, list]]]): A schema configuration defining entities, relations, and potential schema relationships. This is the recommended way to provide schema information. - entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): DEPRECATED. A list of either: + entities (Optional[List[Union[str, dict[str, str], NodeType]]]): DEPRECATED. A list of either: - str: entity labels - - dict: following the SchemaEntity schema, ie with label, description and properties keys + - dict: following the NodeType schema, ie with label, description and properties keys - relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): DEPRECATED. A list of either: + relations (Optional[List[Union[str, dict[str, str], RelationshipType]]]): DEPRECATED. A list of either: - str: relation label - - dict: following the SchemaRelation schema, ie with label, description and properties keys + - dict: following the RelationshipType schema, ie with label, description and properties keys potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships. enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "None" or "STRICT". 
@@ -92,7 +92,7 @@ def __init__( entities: Optional[Sequence[EntityInputType]] = None, relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, - schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None, + schema: Optional[Union[GraphSchema, dict[str, list[Any]]]] = None, enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, diff --git a/src/neo4j_graphrag/experimental/pipeline/types/schema.py b/src/neo4j_graphrag/experimental/pipeline/types/schema.py index 626c99841..3bc8a7446 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types/schema.py +++ b/src/neo4j_graphrag/experimental/pipeline/types/schema.py @@ -19,7 +19,7 @@ EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] -"""Types derived from the SchemaEntity and SchemaRelation types, +"""Types derived from the NodeType and RelationshipType types, so the possible types for dict values are: - str (for label and description) - list[dict[str, str]] (for properties) diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 96bcaf8de..24de870fb 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -204,7 +204,7 @@ def format( class SchemaExtractionTemplate(PromptTemplate): DEFAULT_TEMPLATE = """ -You are a top-tier algorithm designed for extracting a labeled property graph schema in +You are a top-tier algorithm designed for extracting a labeled property graph schema in structured formats. Generate a generalized graph schema based on the input text. Identify key entity types, @@ -219,12 +219,12 @@ class SchemaExtractionTemplate(PromptTemplate): 6. Do not create entity types that aren't clearly mentioned in the text. 7. Keep your schema minimal and focused on clearly identifiable patterns in the text. 
-Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, +Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME. Return a valid JSON object that follows this precise structure: {{ - "entities": [ + "node_types": [ {{ "label": "Person", "properties": [ @@ -236,13 +236,13 @@ class SchemaExtractionTemplate(PromptTemplate): }}, ... ], - "relations": [ + "relationship_types": [ {{ "label": "WORKS_FOR" }}, ... ], - "potential_schema": [ + "patterns": [ ["Person", "WORKS_FOR", "Company"], ... ] diff --git a/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py b/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py index bf74470cf..89f5ae62c 100644 --- a/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py +++ b/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py @@ -33,9 +33,9 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -187,49 +187,49 @@ async def test_pipeline_builder_happy_path( pipe_inputs = { "splitter": {"text": harry_potter_text}, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Potion", properties=[ - SchemaProperty(name="name", type="STRING"), + 
PropertyType(name="name", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Location", properties=[ - SchemaProperty(name="address", type="STRING"), + PropertyType(name="address", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="KNOWS", ), - SchemaRelation( + RelationshipType( label="PART_OF", ), - SchemaRelation( + RelationshipType( label="LED_BY", ), - SchemaRelation( + RelationshipType( label="DRINKS", ), ], - "potential_schema": [ + "patterns": [ ("Person", "KNOWS", "Person"), ("Person", "DRINKS", "Potion"), ("Person", "PART_OF", "Organization"), @@ -356,9 +356,9 @@ async def test_pipeline_builder_failing_chunk_raise( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } with pytest.raises(LLMGenerationError): @@ -434,9 +434,9 @@ async def test_pipeline_builder_failing_chunk_do_not_raise( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } kg_builder_pipeline.get_node_by_name( @@ -575,9 +575,9 @@ async def test_pipeline_builder_two_documents( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } pipe_inputs_2 = { @@ -585,9 +585,9 @@ async def test_pipeline_builder_two_documents( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } await kg_builder_pipeline.run(pipe_inputs_1) diff --git a/tests/e2e/experimental/test_simplekgpipeline_e2e.py 
b/tests/e2e/experimental/test_simplekgpipeline_e2e.py index 3bd72dd41..76c532659 100644 --- a/tests/e2e/experimental/test_simplekgpipeline_e2e.py +++ b/tests/e2e/experimental/test_simplekgpipeline_e2e.py @@ -305,7 +305,7 @@ async def test_pipeline_builder_with_automatic_schema_extraction( # first call - schema extraction response LLMResponse( content="""{ - "entities": [ + "node_types": [ { "label": "Person", "description": "A character in the story", @@ -322,14 +322,14 @@ async def test_pipeline_builder_with_automatic_schema_extraction( ] } ], - "relations": [ + "relationship_types": [ { "label": "LOCATED_AT", "description": "Indicates where a person is located", "properties": [] } ], - "potential_schema": [ + "patterns": [ ["Person", "LOCATED_AT", "Location"] ] }""" diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 70c115fea..64ac0d42d 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -25,7 +25,7 @@ balance_curly_braces, fix_invalid_json, ) -from neo4j_graphrag.experimental.components.schema import SchemaConfig +from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( DocumentInfo, Neo4jGraph, @@ -243,15 +243,17 @@ async def test_extractor_no_schema_enforcement() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.NONE ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relationship_types": [], + "patterns": [], }, - relations={}, - potential_schema=[], ) chunks = TextChunks(chunks=[TextChunk(text="some text", 
index=0)]) @@ -297,15 +299,17 @@ async def test_extractor_schema_enforcement_invalid_nodes() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relationship_types": [], + "patterns": [], }, - relations={}, - potential_schema=[], ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -330,18 +334,20 @@ async def test_extraction_schema_enforcement_invalid_node_properties() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "INTEGER"}, - ], - } + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "STRING"}, + ], + } + ], + "relationship_types": [], + "patterns": [], }, - relations={}, - potential_schema=[], ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -366,10 +372,15 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> No llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={"Person": {"label": "Person"}}, relations={}, potential_schema=[] + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + } + ], + } ) - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) result: Neo4jGraph = await extractor.run(chunks, schema=schema) @@ -392,15 +403,20 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> N llm=llm, create_lexical_graph=False, 
enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations={"LIKES": {"label": "LIKES"}}, - potential_schema=[], + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "STRING"}, + ], + } + ], + "relationship_types": [{"label": "LIKES"}], + "patterns": [], + } ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -428,19 +444,21 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - "City": { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - }, - relations={"LIVES_IN": {"label": "LIVES_IN"}}, - potential_schema=[("Person", "LIVES_IN", "City")], + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + }, + { + "label": "City", + "properties": [{"name": "name", "type": "STRING"}], + }, + ], + "relationship_types": [{"label": "LIVES_IN"}], + "patterns": [("Person", "LIVES_IN", "City")], + } ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -465,20 +483,22 @@ async def test_extractor_schema_enforcement_invalid_relation_properties() -> Non llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations={ - "LIKES": { - "label": "LIKES", - "properties": [{"name": "strength", "type": "STRING"}], - } - }, - potential_schema=[], + schema = GraphSchema.model_validate( + { + 
"node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relationship_types": [ + { + "label": "LIKES", + "properties": [{"name": "strength", "type": "STRING"}], + } + ], + "patterns": [], + } ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -506,15 +526,17 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes() - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations={"LIKES": {"label": "LIKES"}}, - potential_schema=[("Person", "LIKES", "Person")], + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relationship_types": [{"label": "LIKES"}], + "patterns": [("Person", "LIKES", "Person")], + } ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -539,19 +561,21 @@ async def test_extractor_schema_enforcement_inverted_relation_direction() -> Non llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - "City": { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - }, - relations={"LIVES_IN": {"label": "LIVES_IN"}}, - potential_schema=[("Person", "LIVES_IN", "City")], + schema = GraphSchema.model_validate( + { + "node_types": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + }, + { + "label": "City", + "properties": [{"name": "name", "type": "STRING"}], + }, + ], + "relationship_types": [{"label": "LIVES_IN"}], + "patterns": [("Person", "LIVES_IN", "City")], + } ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -579,15 +603,17 @@ async def 
test_extractor_schema_enforcement_none_relationships_in_schema() -> No llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations=None, - potential_schema=None, + schema = GraphSchema.model_validate( + dict( + node_types=[ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + relationship_types=None, + patterns=None, + ) ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -614,15 +640,17 @@ async def test_extractor_schema_enforcement_empty_relationships_in_schema() -> N llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations={}, - potential_schema=None, + schema = GraphSchema.model_validate( + dict( + node_types=[ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + relationship_types=[], + patterns=None, + ) ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index be7bbd958..e8fc670c2 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -15,19 +15,20 @@ from __future__ import annotations import json +from typing import Tuple from unittest.mock import AsyncMock import pytest + from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, SchemaFromTextExtractor, - SchemaConfig, + GraphSchema, ) -from pydantic import ValidationError import os 
import tempfile import yaml @@ -37,66 +38,64 @@ @pytest.fixture -def valid_entities() -> list[SchemaEntity]: - return [ - SchemaEntity( +def valid_node_types() -> tuple[NodeType, ...]: + return ( + NodeType( label="PERSON", description="An individual human being.", properties=[ - SchemaProperty(name="birth date", type="ZONED_DATETIME"), - SchemaProperty(name="name", type="STRING"), + PropertyType(name="birth date", type="ZONED_DATETIME"), + PropertyType(name="name", type="STRING"), ], ), - SchemaEntity( + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", ), - SchemaEntity(label="AGE", description="Age of a person in years."), - ] + NodeType(label="AGE", description="Age of a person in years."), + ) @pytest.fixture -def valid_relations() -> list[SchemaRelation]: - return [ - SchemaRelation( +def valid_relationship_types() -> tuple[RelationshipType, ...]: + return ( + RelationshipType( label="EMPLOYED_BY", description="Indicates employment relationship.", properties=[ - SchemaProperty(name="start_time", type="LOCAL_DATETIME"), - SchemaProperty(name="end_time", type="LOCAL_DATETIME"), + PropertyType(name="start_time", type="LOCAL_DATETIME"), + PropertyType(name="end_time", type="LOCAL_DATETIME"), ], ), - SchemaRelation( + RelationshipType( label="ORGANIZED_BY", description="Indicates organization responsible for an event.", ), - SchemaRelation( + RelationshipType( label="ATTENDED_BY", description="Indicates attendance at an event." 
), - ] + ) @pytest.fixture -def potential_schema() -> list[tuple[str, str, str]]: - return [ +def valid_patterns() -> tuple[tuple[str, str, str], ...]: + return ( ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ("ORGANIZATION", "ATTENDED_BY", "PERSON"), - ] + ) @pytest.fixture -def potential_schema_with_invalid_entity() -> list[tuple[str, str, str]]: - return [ +def patterns_with_invalid_entity() -> tuple[tuple[str, str, str], ...]: + return ( ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ("NON_EXISTENT_ENTITY", "ATTENDED_BY", "PERSON"), - ] + ) @pytest.fixture -def potential_schema_with_invalid_relation() -> list[tuple[str, str, str]]: - return [ - ("PERSON", "NON_EXISTENT_RELATION", "ORGANIZATION"), - ] +def patterns_with_invalid_relation() -> tuple[tuple[str, str, str], ...]: + return (("PERSON", "NON_EXISTENT_RELATION", "ORGANIZATION"),) @pytest.fixture @@ -105,196 +104,59 @@ def schema_builder() -> SchemaBuilder: @pytest.fixture -def schema_config( +def graph_schema( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema: list[tuple[str, str, str]], -) -> SchemaConfig: + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], +) -> GraphSchema: return schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) def test_create_schema_model_valid_data( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema: list[tuple[str, str, str]], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: schema_instance = schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema - ) - - assert ( - 
schema_instance.entities["PERSON"]["description"] - == "An individual human being." - ) - assert schema_instance.entities["PERSON"]["properties"] == [ - {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, - {"description": "", "name": "name", "type": "STRING"}, - ] - assert ( - schema_instance.entities["ORGANIZATION"]["description"] - == "A structured group of people with a common purpose." + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) - assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." - - assert schema_instance.relations - assert ( - schema_instance.relations["EMPLOYED_BY"]["description"] - == "Indicates employment relationship." - ) - assert ( - schema_instance.relations["ORGANIZED_BY"]["description"] - == "Indicates organization responsible for an event." - ) - assert ( - schema_instance.relations["ATTENDED_BY"]["description"] - == "Indicates attendance at an event." - ) - assert schema_instance.relations["EMPLOYED_BY"]["properties"] == [ - {"description": "", "name": "start_time", "type": "LOCAL_DATETIME"}, - {"description": "", "name": "end_time", "type": "LOCAL_DATETIME"}, - ] - - assert schema_instance.potential_schema - assert schema_instance.potential_schema == potential_schema - -def test_create_schema_model_missing_description( - schema_builder: SchemaBuilder, potential_schema: list[tuple[str, str, str]] -) -> None: - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity(label="ORGANIZATION", description=""), - SchemaEntity(label="AGE", description=""), - ] - relations = [ - SchemaRelation( - label="EMPLOYED_BY", description="Indicates employment relationship." 
- ), - SchemaRelation(label="ORGANIZED_BY", description=""), - SchemaRelation(label="ATTENDED_BY", description=""), - ] - - schema_instance = schema_builder.create_schema_model( - entities, relations, potential_schema - ) - - assert schema_instance.entities["ORGANIZATION"]["description"] == "" - assert schema_instance.entities["AGE"]["description"] == "" - assert schema_instance.relations - assert schema_instance.relations["ORGANIZED_BY"]["description"] == "" - assert schema_instance.relations["ATTENDED_BY"]["description"] == "" - - -def test_create_schema_model_empty_lists(schema_builder: SchemaBuilder) -> None: - schema_instance = schema_builder.create_schema_model([], [], []) - - assert schema_instance.entities == {} - assert schema_instance.relations == {} - assert schema_instance.potential_schema == [] - - -def test_create_schema_model_invalid_data_types( - schema_builder: SchemaBuilder, potential_schema: list[tuple[str, str, str]] -) -> None: - with pytest.raises(ValidationError): - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( - label="ORGANIZATION", - description="A structured group of people with a common purpose.", - ), - ] - relations = [ - SchemaRelation( - label="EMPLOYED_BY", description="Indicates employment relationship." 
- ), - SchemaRelation( - label=456, # type: ignore - description="Indicates organization responsible for an event.", - ), - ] - - schema_builder.create_schema_model(entities, relations, potential_schema) - - -def test_create_schema_model_invalid_properties_types( - schema_builder: SchemaBuilder, - potential_schema: list[tuple[str, str, str]], -) -> None: - with pytest.raises(ValidationError): - entities = [ - SchemaEntity( - label="PERSON", - description="An individual human being.", - properties=[42, 1337], # type: ignore - ), - SchemaEntity( - label="ORGANIZATION", - description="A structured group of people with a common purpose.", - ), - ] - relations = [ - SchemaRelation( - label="EMPLOYED_BY", - description="Indicates employment relationship.", - properties=[42, 1337], # type: ignore - ), - SchemaRelation( - label="ORGANIZED_BY", - description="Indicates organization responsible for an event.", - ), - ] - - schema_builder.create_schema_model(entities, relations, potential_schema) + assert schema_instance.node_types == valid_node_types + assert schema_instance.relationship_types == valid_relationship_types + assert schema_instance.patterns == valid_patterns @pytest.mark.asyncio async def test_run_method( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema: list[tuple[str, str, str]], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: - schema = await schema_builder.run(valid_entities, valid_relations, potential_schema) - - assert schema.entities["PERSON"]["description"] == "An individual human being." - assert ( - schema.entities["ORGANIZATION"]["description"] - == "A structured group of people with a common purpose." 
+ schema = await schema_builder.run( + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) - assert schema.entities["AGE"]["description"] == "Age of a person in years." - assert schema.relations - assert ( - schema.relations["EMPLOYED_BY"]["description"] - == "Indicates employment relationship." - ) - assert ( - schema.relations["ORGANIZED_BY"]["description"] - == "Indicates organization responsible for an event." - ) - assert ( - schema.relations["ATTENDED_BY"]["description"] - == "Indicates attendance at an event." - ) - - assert schema.potential_schema - assert schema.potential_schema == potential_schema + assert schema.node_types == valid_node_types + assert schema.relationship_types == valid_relationship_types + assert schema.patterns == valid_patterns def test_create_schema_model_invalid_entity( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema_with_invalid_entity: list[tuple[str, str, str]], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + patterns_with_invalid_entity: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema_with_invalid_entity + list(valid_node_types), + list(valid_relationship_types), + list(patterns_with_invalid_entity), ) assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str( exc_info.value @@ -303,141 +165,64 @@ def test_create_schema_model_invalid_entity( def test_create_schema_model_invalid_relation( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema_with_invalid_relation: list[tuple[str, str, str]], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + patterns_with_invalid_relation: Tuple[Tuple[str, str, str], ...], ) 
-> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema_with_invalid_relation + list(valid_node_types), + list(valid_relationship_types), + list(patterns_with_invalid_relation), ) assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( exc_info.value ), "Should fail due to non-existent relation" -def test_create_schema_model_missing_properties( - schema_builder: SchemaBuilder, potential_schema: list[tuple[str, str, str]] -) -> None: - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( - label="ORGANIZATION", - description="A structured group of people with a common purpose.", - ), - SchemaEntity(label="AGE", description="Age of a person in years."), - ] - - relations = [ - SchemaRelation( - label="EMPLOYED_BY", description="Indicates employment relationship." - ), - SchemaRelation( - label="ORGANIZED_BY", - description="Indicates organization responsible for an event.", - ), - SchemaRelation( - label="ATTENDED_BY", description="Indicates attendance at an event." 
- ), - ] - - schema_instance = schema_builder.create_schema_model( - entities, relations, potential_schema - ) - - assert ( - schema_instance.entities["PERSON"]["properties"] == [] - ), "Expected empty properties for PERSON" - assert ( - schema_instance.entities["ORGANIZATION"]["properties"] == [] - ), "Expected empty properties for ORGANIZATION" - assert ( - schema_instance.entities["AGE"]["properties"] == [] - ), "Expected empty properties for AGE" - - assert schema_instance.relations - assert ( - schema_instance.relations["EMPLOYED_BY"]["properties"] == [] - ), "Expected empty properties for EMPLOYED_BY" - assert ( - schema_instance.relations["ORGANIZED_BY"]["properties"] == [] - ), "Expected empty properties for ORGANIZED_BY" - assert ( - schema_instance.relations["ATTENDED_BY"]["properties"] == [] - ), "Expected empty properties for ATTENDED_BY" - - def test_create_schema_model_no_potential_schema( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], ) -> None: schema_instance = schema_builder.create_schema_model( - valid_entities, valid_relations + list(valid_node_types), list(valid_relationship_types) ) - - assert ( - schema_instance.entities["PERSON"]["description"] - == "An individual human being." - ) - assert schema_instance.entities["PERSON"]["properties"] == [ - {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, - {"description": "", "name": "name", "type": "STRING"}, - ] - assert ( - schema_instance.entities["ORGANIZATION"]["description"] - == "A structured group of people with a common purpose." - ) - assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." - - assert schema_instance.relations - assert ( - schema_instance.relations["EMPLOYED_BY"]["description"] - == "Indicates employment relationship." 
- ) - assert ( - schema_instance.relations["ORGANIZED_BY"]["description"] - == "Indicates organization responsible for an event." - ) - assert ( - schema_instance.relations["ATTENDED_BY"]["description"] - == "Indicates attendance at an event." - ) - assert schema_instance.relations["EMPLOYED_BY"]["properties"] == [ - {"description": "", "name": "start_time", "type": "LOCAL_DATETIME"}, - {"description": "", "name": "end_time", "type": "LOCAL_DATETIME"}, - ] + assert schema_instance.node_types == valid_node_types + assert schema_instance.relationship_types == valid_relationship_types + assert schema_instance.patterns is None def test_create_schema_model_no_relations_or_potential_schema( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], + valid_node_types: Tuple[NodeType, ...], ) -> None: - schema_instance = schema_builder.create_schema_model(valid_entities) + schema_instance = schema_builder.create_schema_model(list(valid_node_types)) - assert ( - schema_instance.entities["PERSON"]["description"] - == "An individual human being." - ) - assert schema_instance.entities["PERSON"]["properties"] == [ - {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, - {"description": "", "name": "name", "type": "STRING"}, - ] - assert ( - schema_instance.entities["ORGANIZATION"]["description"] - == "A structured group of people with a common purpose." - ) - assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." + assert len(schema_instance.node_types) == 3 + person = schema_instance.node_type_from_label("PERSON") + + assert person is not None + assert person.description == "An individual human being." + assert len(person.properties) == 2 + + org = schema_instance.node_type_from_label("ORGANIZATION") + assert org is not None + assert org.description == "A structured group of people with a common purpose." 
+ + age = schema_instance.node_type_from_label("AGE") + assert age is not None + assert age.description == "Age of a person in years." def test_create_schema_model_missing_relations( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - potential_schema: list[tuple[str, str, str]], + valid_node_types: Tuple[NodeType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - entities=valid_entities, potential_schema=potential_schema + node_types=valid_node_types, patterns=valid_patterns ) assert "Relations must also be provided when using a potential schema." in str( exc_info.value @@ -455,7 +240,7 @@ def mock_llm() -> AsyncMock: def valid_schema_json() -> str: return """ { - "entities": [ + "node_types": [ { "label": "Person", "properties": [ @@ -469,7 +254,7 @@ def valid_schema_json() -> str: ] } ], - "relations": [ + "relationship_types": [ { "label": "WORKS_FOR", "properties": [ @@ -477,7 +262,7 @@ def valid_schema_json() -> str: ] } ], - "potential_schema": [ + "patterns": [ ["Person", "WORKS_FOR", "Organization"] ] } @@ -488,7 +273,7 @@ def valid_schema_json() -> str: def invalid_schema_json() -> str: return """ { - "entities": [ + "node_types": [ { "label": "Person", }, @@ -522,16 +307,16 @@ async def test_schema_from_text_run_valid_response( assert "Sample text for extraction" in prompt_arg # verify the schema was correctly extracted - assert len(schema_config.entities) == 2 - assert "Person" in schema_config.entities - assert "Organization" in schema_config.entities + assert len(schema_config.node_types) == 2 + assert schema_config.node_type_from_label("Person") is not None + assert schema_config.node_type_from_label("Organization") is not None - assert schema_config.relations is not None - assert "WORKS_FOR" in schema_config.relations + assert schema_config.relationship_types is not None + assert 
schema_config.relationship_type_from_label("WORKS_FOR") is not None - assert schema_config.potential_schema is not None - assert len(schema_config.potential_schema) == 1 - assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") + assert schema_config.patterns is not None + assert len(schema_config.patterns) == 1 + assert schema_config.patterns[0] == ("Person", "WORKS_FOR", "Organization") @pytest.mark.asyncio @@ -598,13 +383,13 @@ async def test_schema_from_text_llm_params( @pytest.mark.asyncio -async def test_schema_config_store_as_json(schema_config: SchemaConfig) -> None: +async def test_schema_config_store_as_json(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file path json_path = os.path.join(temp_dir, "schema.json") # store the schema config - schema_config.store_as_json(json_path) + graph_schema.store_as_json(json_path) # verify the file exists and has content assert os.path.exists(json_path) @@ -613,24 +398,18 @@ async def test_schema_config_store_as_json(schema_config: SchemaConfig) -> None: # verify the content is valid JSON and contains expected data with open(json_path, "r") as f: data = json.load(f) - assert "entities" in data - assert "PERSON" in data["entities"] - assert "properties" in data["entities"]["PERSON"] - assert "description" in data["entities"]["PERSON"] - assert ( - data["entities"]["PERSON"]["description"] - == "An individual human being." 
- ) + assert "node_types" in data + assert len(data["node_types"]) == 3 @pytest.mark.asyncio -async def test_schema_config_store_as_yaml(schema_config: SchemaConfig) -> None: +async def test_schema_config_store_as_yaml(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # Create file path yaml_path = os.path.join(temp_dir, "schema.yaml") # Store the schema config - schema_config.store_as_yaml(yaml_path) + graph_schema.store_as_yaml(yaml_path) # Verify the file exists and has content assert os.path.exists(yaml_path) @@ -639,18 +418,12 @@ async def test_schema_config_store_as_yaml(schema_config: SchemaConfig) -> None: # Verify the content is valid YAML and contains expected data with open(yaml_path, "r") as f: data = yaml.safe_load(f) - assert "entities" in data - assert "PERSON" in data["entities"] - assert "properties" in data["entities"]["PERSON"] - assert "description" in data["entities"]["PERSON"] - assert ( - data["entities"]["PERSON"]["description"] - == "An individual human being." 
- ) + assert "node_types" in data + assert len(data["node_types"]) == 3 @pytest.mark.asyncio -async def test_schema_config_from_file(schema_config: SchemaConfig) -> None: +async def test_schema_config_from_file(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file paths with different extensions json_path = os.path.join(temp_dir, "schema.json") @@ -658,31 +431,31 @@ async def test_schema_config_from_file(schema_config: SchemaConfig) -> None: yml_path = os.path.join(temp_dir, "schema.yml") # store the schema config in the different formats - schema_config.store_as_json(json_path) - schema_config.store_as_yaml(yaml_path) - schema_config.store_as_yaml(yml_path) + graph_schema.store_as_json(json_path) + graph_schema.store_as_yaml(yaml_path) + graph_schema.store_as_yaml(yml_path) # load using from_file which should detect the format based on extension - json_schema = SchemaConfig.from_file(json_path) - yaml_schema = SchemaConfig.from_file(yaml_path) - yml_schema = SchemaConfig.from_file(yml_path) + json_schema = GraphSchema.from_file(json_path) + yaml_schema = GraphSchema.from_file(yaml_path) + yml_schema = GraphSchema.from_file(yml_path) # simple verification that the objects were loaded correctly - assert isinstance(json_schema, SchemaConfig) - assert isinstance(yaml_schema, SchemaConfig) - assert isinstance(yml_schema, SchemaConfig) + assert isinstance(json_schema, GraphSchema) + assert isinstance(yaml_schema, GraphSchema) + assert isinstance(yml_schema, GraphSchema) # verify basic structure is intact - assert "entities" in json_schema.model_dump() - assert "entities" in yaml_schema.model_dump() - assert "entities" in yml_schema.model_dump() + assert "node_types" in json_schema.model_dump() + assert "node_types" in yaml_schema.model_dump() + assert "node_types" in yml_schema.model_dump() # verify an unsupported extension raises the correct error txt_path = os.path.join(temp_dir, "schema.txt") - 
schema_config.store_as_json(txt_path) # Store as JSON but with .txt extension + graph_schema.store_as_json(txt_path) # Store as JSON but with .txt extension with pytest.raises(ValueError, match="Unsupported file format"): - SchemaConfig.from_file(txt_path) + GraphSchema.from_file(txt_path) @pytest.fixture @@ -690,7 +463,7 @@ def valid_schema_json_array() -> str: return """ [ { - "entities": [ + "node_types": [ { "label": "Person", "properties": [ @@ -704,7 +477,7 @@ def valid_schema_json_array() -> str: ] } ], - "relations": [ + "relationship_types": [ { "label": "WORKS_FOR", "properties": [ @@ -712,7 +485,7 @@ def valid_schema_json_array() -> str: ] } ], - "potential_schema": [ + "patterns": [ ["Person", "WORKS_FOR", "Organization"] ] } @@ -730,16 +503,16 @@ async def test_schema_from_text_run_valid_json_array( mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json_array) # run the schema extraction - schema_config = await schema_from_text.run(text="Sample text for extraction") + schema = await schema_from_text.run(text="Sample text for extraction") # verify the schema was correctly extracted from the array - assert len(schema_config.entities) == 2 - assert "Person" in schema_config.entities - assert "Organization" in schema_config.entities + assert len(schema.node_types) == 2 + assert schema.node_type_from_label("Person") is not None + assert schema.node_type_from_label("Organization") is not None - assert schema_config.relations is not None - assert "WORKS_FOR" in schema_config.relations + assert schema.relationship_types is not None + assert schema.relationship_type_from_label("WORKS_FOR") is not None - assert schema_config.potential_schema is not None - assert len(schema_config.potential_schema) == 1 - assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") diff --git 
a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 8aa318cd3..8bcdd1f37 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -26,9 +26,10 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, SchemaFromTextExtractor, + GraphSchema, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -38,6 +39,10 @@ SimpleKGPipelineConfig, ) from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.experimental.pipeline.types.schema import ( + EntityInputType, + RelationInputType, +) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm import LLMInterface @@ -142,11 +147,9 @@ def test_simple_kg_pipeline_config_schema_run_params() -> None: potential_schema=[("Person", "KNOWS", "Person")], ) assert config._get_run_params_for_schema() == { - "entities": [SchemaEntity(label="Person")], - "relations": [SchemaRelation(label="KNOWS")], - "potential_schema": [ - ("Person", "KNOWS", "Person"), - ], + "node_types": (NodeType(label="Person"),), + "relationship_types": (RelationshipType(label="KNOWS"),), + "patterns": (("Person", "KNOWS", "Person"),), } @@ -316,3 +319,170 @@ def test_simple_kg_pipeline_config_run_params_both_file_and_text() -> None: "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." 
in str(excinfo) ) + + +def test_simple_kg_pipeline_config_process_schema_with_precedence_legacy() -> None: + entities: list[EntityInputType] = [ + "Person", + { + "label": "Organization", + "description": "A group of persons", + "properties": [ + { + "name": "name", + "type": "STRING", + } + ], + }, + ] + relations: list[RelationInputType] = [ + "WORKS_FOR", + { + "label": "CREATED", + "description": "A person created an organization", + "properties": [ + { + "name": "date", + "description": "The date the organization was created", + "type": "DATE", + }, + {"name": "isActive", "type": "BOOLEAN"}, + ], + }, + ] + potential_schema = [ + ("Person", "WORKS_FOR", "Organization"), + ("Person", "CREATED", "Organization"), + ] + config = SimpleKGPipelineConfig( + entities=entities, + relations=relations, + potential_schema=potential_schema, + ) + node_types, relationship_types, patterns = config._process_schema_with_precedence() + assert len(node_types) == 2 + assert node_types[0].label == "Person" + assert len(node_types[0].properties) == 0 + assert node_types[1].label == "Organization" + assert len(node_types[1].properties) == 1 + assert len(relationship_types) == 2 + assert relationship_types[0].label == "WORKS_FOR" + assert len(relationship_types[0].properties) == 0 + assert relationship_types[1].label == "CREATED" + assert len(relationship_types[1].properties) == 2 + assert patterns is not None + assert len(patterns) == 2 + + +def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_dict() -> None: + entities = [ + "Person", + { + "label": "Organization", + "description": "A group of persons", + "properties": [ + { + "name": "name", + "type": "STRING", + } + ], + }, + ] + relations = [ + "WORKS_FOR", + { + "label": "CREATED", + "description": "A person created an organization", + "properties": [ + { + "name": "date", + "description": "The date the organization was created", + "type": "DATE", + }, + {"name": "isActive", "type": "BOOLEAN"}, + ], + }, + ] 
+ potential_schema = [ + ("Person", "WORKS_FOR", "Organization"), + ("Person", "CREATED", "Organization"), + ] + config = SimpleKGPipelineConfig( + schema={ + "node_types": entities, + "relationship_types": relations, + "patterns": potential_schema, + } + ) + node_types, relationship_types, patterns = config._process_schema_with_precedence() + assert len(node_types) == 2 + assert node_types[0].label == "Person" + assert len(node_types[0].properties) == 0 + assert node_types[1].label == "Organization" + assert len(node_types[1].properties) == 1 + assert len(relationship_types) == 2 + assert relationship_types[0].label == "WORKS_FOR" + assert len(relationship_types[0].properties) == 0 + assert relationship_types[1].label == "CREATED" + assert len(relationship_types[1].properties) == 2 + assert patterns is not None + assert len(patterns) == 2 + + +def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_object() -> ( + None +): + entities = [ + {"label": "Person"}, + { + "label": "Organization", + "description": "A group of persons", + "properties": [ + { + "name": "name", + "type": "STRING", + } + ], + }, + ] + relations = [ + {"label": "WORKS_FOR"}, + { + "label": "CREATED", + "description": "A person created an organization", + "properties": [ + { + "name": "date", + "description": "The date the organization was created", + "type": "DATE", + }, + {"name": "isActive", "type": "BOOLEAN"}, + ], + }, + ] + potential_schema = [ + ("Person", "WORKS_FOR", "Organization"), + ("Person", "CREATED", "Organization"), + ] + config = SimpleKGPipelineConfig( + schema=GraphSchema.model_validate( + { + "node_types": entities, + "relationship_types": relations, + "patterns": potential_schema, + } + ) + ) + node_types, relationship_types, patterns = config._process_schema_with_precedence() + assert len(node_types) == 2 + assert node_types[0].label == "Person" + assert len(node_types[0].properties) == 0 + assert node_types[1].label == "Organization" + assert 
len(node_types[1].properties) == 1 + assert len(relationship_types) == 2 + assert relationship_types[0].label == "WORKS_FOR" + assert len(relationship_types[0].properties) == 0 + assert relationship_types[1].label == "CREATED" + assert len(relationship_types[1].properties) == 2 + assert patterns is not None + assert len(patterns) == 2 diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index a6d3d4c42..b4ece857a 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -19,8 +19,8 @@ import pytest from neo4j_graphrag.embeddings import Embedder from neo4j_graphrag.experimental.components.schema import ( - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, ) from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError @@ -116,14 +116,10 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: from_pdf=True, ) - # assert kg_builder.entities == entities - # assert kg_builder.relations == relations - # assert kg_builder.potential_schema == potential_schema - file_path = "path/to/test.pdf" - internal_entities = [SchemaEntity(label=label) for label in entities] - internal_relations = [SchemaRelation(label=label) for label in relations] + internal_node_types = [NodeType(label=label) for label in entities] + internal_relationship_types = [RelationshipType(label=label) for label in relations] with patch.object( kg_builder.runner.pipeline, @@ -132,9 +128,11 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: ) as mock_run: await kg_builder.run_async(file_path=file_path) pipe_inputs = mock_run.call_args[1]["data"] - assert pipe_inputs["schema"]["entities"] == internal_entities - assert pipe_inputs["schema"]["relations"] == internal_relations - assert 
pipe_inputs["schema"]["potential_schema"] == potential_schema + assert pipe_inputs["schema"]["node_types"] == tuple(internal_node_types) + assert pipe_inputs["schema"]["relationship_types"] == tuple( + internal_relationship_types + ) + assert pipe_inputs["schema"]["patterns"] == tuple(potential_schema) def test_simple_kg_pipeline_on_error_invalid_value() -> None: