From 3800dbc469c27978fedc7e30a479609f48247c7d Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 13 May 2025 15:30:41 +0200 Subject: [PATCH 01/16] WIP: get rid of SchemaConfig --- .../components/entity_relation_extractor.py | 33 ++-- .../experimental/components/schema.py | 107 ++++++------- .../test_entity_relation_extractor.py | 147 +++++++++--------- .../experimental/components/test_schema.py | 28 +++- 4 files changed, 159 insertions(+), 156 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index d041d78e4..838d02605 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -25,7 +25,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder -from neo4j_graphrag.experimental.components.schema import SchemaConfig +from neo4j_graphrag.experimental.components.schema import GraphSchema, SchemaProperty from neo4j_graphrag.experimental.components.types import ( DocumentInfo, LexicalGraphConfig, @@ -209,7 +209,7 @@ def __init__( self.prompt_template = template async def extract_for_chunk( - self, schema: SchemaConfig, examples: str, chunk: TextChunk + self, schema: GraphSchema, examples: str, chunk: TextChunk ) -> Neo4jGraph: """Run entity extraction for a given text chunk.""" prompt = self.prompt_template.format( @@ -275,7 +275,7 @@ async def run_for_chunk( self, sem: asyncio.Semaphore, chunk: TextChunk, - schema: SchemaConfig, + schema: GraphSchema, examples: str, lexical_graph_builder: Optional[LexicalGraphBuilder] = None, ) -> Neo4jGraph: @@ -296,7 +296,7 @@ async def run( chunks: TextChunks, document_info: Optional[DocumentInfo] = None, lexical_graph_config: Optional[LexicalGraphConfig] = None, - schema: Union[SchemaConfig, None] = None, + schema: Union[GraphSchema, None] = None, examples: str = "", **kwargs: Any, ) -> Neo4jGraph: @@ -311,7 +311,7 @@ async def run( chunks (TextChunks): List of text chunks to extract entities and relations from. document_info (Optional[DocumentInfo], optional): Document the chunks are coming from. Used in the lexical graph creation step. lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph. - schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. + schema (GraphSchema | None): Definition of the schema to guide the LLM in its extraction. examples (str): Examples for few-shot learning in the prompt. """ lexical_graph_builder = None @@ -325,7 +325,7 @@ async def run( lexical_graph = lexical_graph_result.graph elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) - schema = schema or SchemaConfig(entities={}, relations={}, potential_schema=[]) + schema = schema or GraphSchema(entities=[], relations=[], potential_schema=[]) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) tasks = [ @@ -344,7 +344,7 @@ async def run( return graph def validate_chunk( - self, chunk_graph: Neo4jGraph, schema: SchemaConfig + self, chunk_graph: Neo4jGraph, schema: GraphSchema ) -> Neo4jGraph: """ Perform validation after entity and relation extraction: @@ -363,7 +363,7 @@ def validate_chunk( def _clean_graph( self, graph: Neo4jGraph, - schema: SchemaConfig, + schema: GraphSchema, ) -> Neo4jGraph: """ Verify that the graph conforms to the provided schema. @@ -385,7 +385,7 @@ def _clean_graph( return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels) def _enforce_nodes( - self, extracted_nodes: List[Neo4jNode], schema: SchemaConfig + self, extracted_nodes: List[Neo4jNode], schema: GraphSchema ) -> List[Neo4jNode]: """ Filter extracted nodes to be conformant to the schema. @@ -400,10 +400,10 @@ def _enforce_nodes( valid_nodes = [] for node in extracted_nodes: - schema_entity = schema.entities.get(node.label) + schema_entity = schema.entity_from_label(node.label) if not schema_entity: continue - allowed_props = schema_entity.get("properties") + allowed_props = schema_entity.properties or [] if allowed_props: filtered_props = self._enforce_properties( node.properties, allowed_props @@ -426,7 +426,7 @@ def _enforce_relationships( self, extracted_relationships: List[Neo4jRelationship], filtered_nodes: List[Neo4jNode], - schema: SchemaConfig, + schema: GraphSchema, ) -> List[Neo4jRelationship]: """ Filter extracted nodes to be conformant to the schema. @@ -451,12 +451,14 @@ def _enforce_relationships( for rel in extracted_relationships: schema_relation = schema.relations.get(rel.type) if not schema_relation: + logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") continue if ( rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes ): + logger.debug(f"PRUNING:: {rel} as one of {rel.start_node_id} and {rel.end_node_id} is not in the graph") continue start_label = valid_nodes[rel.start_node_id] @@ -472,9 +474,10 @@ def _enforce_relationships( ) in potential_schema if not tuple_valid and not reverse_tuple_valid: + logger.debug(f"PRUNING:: {rel} not in the potential schema") continue - allowed_props = schema_relation.get("properties") + allowed_props = schema_relation.properties or [] if allowed_props: filtered_props = self._enforce_properties(rel.properties, allowed_props) else: @@ -493,13 +496,13 @@ def _enforce_relationships( return valid_rels def _enforce_properties( - self, properties: Dict[str, Any], valid_properties: List[Dict[str, Any]] + self, properties: Dict[str, Any], valid_properties: List[SchemaProperty] ) -> Dict[str, Any]: """ Filter properties. Keep only those that exist in schema (i.e., valid properties). """ - valid_prop_names = {prop["name"] for prop in valid_properties} + valid_prop_names = {prop.name for prop in valid_properties} return { key: value for key, value in properties.items() if key in valid_prop_names } diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 2e0641c87..8c159d87f 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union from pathlib import Path +import neo4j from pydantic import BaseModel, ValidationError, model_validator, validate_call from typing_extensions import Self @@ -97,20 +98,22 @@ def from_text_or_dict(cls, input: RelationInputType) -> Self: return cls.model_validate(input) -class SchemaConfig(DataModel): - """ - Represents possible relationships between entities and relations in the graph. - """ +class GraphSchema(DataModel): + entities: list[SchemaEntity] + relations: Optional[list[SchemaRelation]] = None + potential_schema: Optional[List[Tuple[str, str, str]]] = None + # indexes: list[something] = None - entities: Dict[str, Dict[str, Any]] - relations: Optional[Dict[str, Dict[str, Any]]] - potential_schema: Optional[List[Tuple[str, str, str]]] + _entity_index: dict[str, SchemaEntity] + _relation_index: dict[str, SchemaRelation] - @model_validator(mode="before") - def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]: - entities = data.get("entities", {}).keys() - relations = (data.get("relations") or {}).keys() - potential_schema = data.get("potential_schema", []) + @model_validator(mode="after") + def check_schema(self) -> Self: + self._entity_index = {e.label: e for e in self.entities} + self._relation_index = {r.label: r for r in self.relations} if self.relations else {} + + relations = self.relations or [] + potential_schema = self.potential_schema or [] if potential_schema: if not relations: @@ -118,20 +121,27 @@ def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]: "Relations must also be provided when using a potential schema." ) for entity1, relation, entity2 in potential_schema: - if entity1 not in entities: + if entity1 not in self._entity_index: raise SchemaValidationError( f"Entity '{entity1}' is not defined in the provided entities." ) - if relation not in relations: + if relation not in self._relation_index: raise SchemaValidationError( f"Relation '{relation}' is not defined in the provided relations." ) - if entity2 not in entities: + if entity2 not in self._entity_index: raise SchemaValidationError( f"Entity '{entity2}' is not defined in the provided entities." ) - return data + return self + + def entity_from_label(self, label: str) -> Optional[SchemaEntity]: + return self._entity_index.get(label) + + def relation_from_label(self, label: str) -> Optional[SchemaRelation]: + return self._relation_index.get(label) + def store_as_json(self, file_path: str) -> None: """ @@ -225,10 +235,14 @@ def from_yaml(cls, file_path: Union[str, Path]) -> Self: except ValidationError as e: raise SchemaValidationError(f"Schema validation failed: {e}") +class BaseSchemaBuilder(Component): + async def run(self, **kwargs: Any) -> GraphSchema: + raise NotImplementedError() + -class SchemaBuilder(Component): +class SchemaBuilder(BaseSchemaBuilder): """ - A builder class for constructing SchemaConfig objects from given entities, + A builder class for constructing GraphSchema objects from given entities, relations, and their interrelationships defined in a potential schema. Example: @@ -290,7 +304,7 @@ def create_schema_model( entities: List[SchemaEntity], relations: Optional[List[SchemaRelation]] = None, potential_schema: Optional[List[Tuple[str, str, str]]] = None, - ) -> SchemaConfig: + ) -> GraphSchema: """ Creates a SchemaConfig object from Lists of Entity and Relation objects and a Dictionary defining potential relationships. @@ -303,19 +317,12 @@ def create_schema_model( Returns: SchemaConfig: A configured schema object. """ - entity_dict = {entity.label: entity.model_dump() for entity in entities} - relation_dict = ( - {relation.label: relation.model_dump() for relation in relations} - if relations - else {} - ) - try: - return SchemaConfig( - entities=entity_dict, - relations=relation_dict, + return GraphSchema.model_validate(dict( + entities=entities, + relations=relations, potential_schema=potential_schema, - ) + )) except (ValidationError, SchemaValidationError) as e: raise SchemaValidationError(e) @@ -325,9 +332,9 @@ async def run( entities: List[SchemaEntity], relations: Optional[List[SchemaRelation]] = None, potential_schema: Optional[List[Tuple[str, str, str]]] = None, - ) -> SchemaConfig: + ) -> GraphSchema: """ - Asynchronously constructs and returns a SchemaConfig object. + Asynchronously constructs and returns a GraphSchema object. Args: entities (List[SchemaEntity]): List of Entity objects. @@ -335,7 +342,7 @@ async def run( potential_schema (Dict[str, List[str]]): Dictionary mapping entity names to Lists of relation names. Returns: - SchemaConfig: A configured schema object, constructed asynchronously. + GraphSchema: A configured schema object, constructed asynchronously. """ return self.create_schema_model(entities, relations, potential_schema) @@ -359,7 +366,7 @@ def __init__( self._llm_params: dict[str, Any] = llm_params or {} @validate_call - async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfig: + async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema: """ Asynchronously extracts the schema from text and returns a SchemaConfig object. @@ -430,36 +437,16 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi f"Invalid schema format return from LLM: {exc}" ) from exc - return SchemaBuilder.create_schema_model( + return GraphSchema( entities=entities, relations=relations, potential_schema=potential_schema, ) -def normalize_schema_dict(schema_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Normalize a user-provided schema dictionary to the canonical format expected by the pipeline. +class SchemaFromGraphBuilder(BaseSchemaBuilder): + def __init__(self, driver: neo4j.Driver) -> None: + self.driver = driver - - Converts 'entities' and 'relations' from lists (of strings, dicts, or model objects) to dicts keyed by label. - - Ensures required keys ('entities', 'relations', 'potential_schema') are present. - - Does not mutate the input; returns a new dict. - - Args: - schema_dict (dict): The user-provided schema dictionary, possibly with lists or missing keys. - - Returns: - dict: A normalized schema dictionary with the correct structure for pipeline and Pydantic validation. - """ - norm_schema_dict = dict(schema_dict) - for key, cls in [("entities", SchemaEntity), ("relations", SchemaRelation)]: - if key in norm_schema_dict and isinstance(norm_schema_dict[key], list): - norm_schema_dict[key] = { - cls.from_text_or_dict(e).label: cls.from_text_or_dict(e).model_dump() # type: ignore[attr-defined] - for e in norm_schema_dict[key] - } - if "relations" not in norm_schema_dict: - norm_schema_dict["relations"] = {} - if "potential_schema" not in norm_schema_dict: - norm_schema_dict["potential_schema"] = None - return norm_schema_dict + async def run(self, **kwargs: Any) -> GraphSchema: + pass diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 70c115fea..2ed794a4c 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -25,7 +25,7 @@ balance_curly_braces, fix_invalid_json, ) -from neo4j_graphrag.experimental.components.schema import SchemaConfig +from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( DocumentInfo, Neo4jGraph, @@ -243,15 +243,14 @@ async def test_extractor_no_schema_enforcement() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.NONE ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } + schema = GraphSchema.model_validate({ + "entities": [{ + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + }], + "relations": [], + "potential_schema": [] }, - relations={}, - potential_schema=[], ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -297,15 +296,14 @@ async def test_extractor_schema_enforcement_invalid_nodes() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } + schema = GraphSchema.model_validate({ + "entities": [{ + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + }], + "relations": [], + "potential_schema": [] }, - relations={}, - potential_schema=[], ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -330,18 +328,17 @@ async def test_extraction_schema_enforcement_invalid_node_properties() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "INTEGER"}, - ], - } + schema = GraphSchema.model_validate({ + "entities": [{ + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "STRING"} + ], + }], + "relations": [], + "potential_schema": [] }, - relations={}, - potential_schema=[], ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -366,10 +363,11 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> No llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={"Person": {"label": "Person"}}, relations={}, potential_schema=[] - ) - + schema = GraphSchema.model_validate({ + "entities": [{ + "label": "Person", + }], + }) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) result: Neo4jGraph = await extractor.run(chunks, schema=schema) @@ -392,16 +390,17 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> N llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations={"LIKES": {"label": "LIKES"}}, - potential_schema=[], - ) + schema = GraphSchema.model_validate({ + "entities": [{ + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "STRING"} + ], + }], + "relations": [{"label": "LIKES"}], + "potential_schema": [] + }) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -428,20 +427,20 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { + schema = GraphSchema.model_validate({ + "entities":[ + { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], }, - "City": { + { "label": "City", "properties": [{"name": "name", "type": "STRING"}], }, - }, - relations={"LIVES_IN": {"label": "LIVES_IN"}}, - potential_schema=[("Person", "LIVES_IN", "City")], - ) + ], + "relations":[{"label": "LIVES_IN"}], + "potential_schema":[("Person", "LIVES_IN", "City")], + }) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -465,21 +464,21 @@ async def test_extractor_schema_enforcement_invalid_relation_properties() -> Non llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { + schema = GraphSchema.model_validate({ + "entities": [ + { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], } - }, - relations={ - "LIKES": { + ], + "relations":[ + { "label": "LIKES", "properties": [{"name": "strength", "type": "STRING"}], } - }, - potential_schema=[], - ) + ], + "potential_schema":[], + }) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -506,16 +505,16 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes() - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { + schema = GraphSchema.model_validate({ + "entities": [ + { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], } - }, - relations={"LIKES": {"label": "LIKES"}}, - potential_schema=[("Person", "LIKES", "Person")], - ) + ], + "relations": [{"label": "LIKES"}], + "potential_schema": [("Person", "LIKES", "Person")], + }) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -539,20 +538,20 @@ async def test_extractor_schema_enforcement_inverted_relation_direction() -> Non llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { + schema = GraphSchema.model_validate({ + "entities": [ + { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], }, - "City": { + { "label": "City", "properties": [{"name": "name", "type": "STRING"}], }, - }, - relations={"LIVES_IN": {"label": "LIVES_IN"}}, - potential_schema=[("Person", "LIVES_IN", "City")], - ) + ], + "relations": [{"label": "LIVES_IN"}], + "potential_schema": [("Person", "LIVES_IN", "City")], + }) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index be7bbd958..7400696d3 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -415,19 +415,33 @@ def test_create_schema_model_no_relations_or_potential_schema( ) -> None: schema_instance = schema_builder.create_schema_model(valid_entities) + assert len(schema_instance.entities) == 2 + person = schema_instance.entity_from_label("PERSON") + assert ( - schema_instance.entities["PERSON"]["description"] - == "An individual human being." + person.description == "An individual human being." ) - assert schema_instance.entities["PERSON"]["properties"] == [ - {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, - {"description": "", "name": "name", "type": "STRING"}, + assert person.properties == [ + SchemaProperty( + name="name", + type="STRING", + description="", + ), + SchemaProperty( + name="birth date", + type="ZONED_DATETIME", + description="", + ) ] + + org = schema_instance.entity_from_label("ORGANIZATION") assert ( - schema_instance.entities["ORGANIZATION"]["description"] + org.description == "A structured group of people with a common purpose." ) - assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." + + age = schema_instance.entity_from_label("AGE") + assert age.description == "Age of a person in years." def test_create_schema_model_missing_relations( From 54bcecca61cc6d53ff5f9768f7310b510f07d528 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 10:11:54 +0200 Subject: [PATCH 02/16] Replace SchemaConfig WIP --- docs/source/types.rst | 6 +- docs/source/user_guide_kg_builder.rst | 21 +- .../schema_builders/schema_from_text.py | 36 +-- .../components/entity_relation_extractor.py | 8 +- .../experimental/components/schema.py | 90 ++++---- .../template_pipeline/simple_kg_builder.py | 22 +- .../experimental/pipeline/kg_builder.py | 6 +- .../test_entity_relation_extractor.py | 207 ++++++++++-------- .../experimental/components/test_schema.py | 11 +- 9 files changed, 213 insertions(+), 194 deletions(-) diff --git a/docs/source/types.rst b/docs/source/types.rst index 73afcffb7..8ea5d05df 100644 --- a/docs/source/types.rst +++ b/docs/source/types.rst @@ -90,10 +90,10 @@ SchemaRelation .. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaRelation -SchemaConfig -============ +GraphSchema +=========== -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaConfig +.. autoclass:: neo4j_graphrag.experimental.components.schema.GraphSchema LexicalGraphConfig =================== diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 3d02b47af..bb0733170 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -75,7 +75,7 @@ Graph Schema It is possible to guide the LLM by supplying a list of entities, relationships, and instructions on how to connect them. However, note that the extracted graph -may not fully adhere to these guidelines unless schema enforcement is enabled +may not fully adhere to these guidelines unless schema enforcement is enabled (see :ref:`Schema Enforcement Behaviour`). Entities and relationships can be represented as either simple strings (for their labels) or dictionaries. If using a dictionary, it must include a label key and can optionally include description and properties keys, @@ -419,7 +419,7 @@ within the configuration file. "neo4j_database": "myDb", "on_error": "IGNORE", "prompt_template": "...", - + "schema": { "entities": [ "Person", @@ -799,10 +799,10 @@ Here is a code block illustrating these concepts: ], ) -After validation, this schema is saved in a `SchemaConfig` object, whose dict representation is passed +After validation, this schema is saved in a `GraphSchema` object, whose dict representation is passed to the LLM. -Automatic Schema Extraction +Automatic Schema Extraction --------------------------- Instead of manually defining the schema, you can use the `SchemaFromTextExtractor` component to automatically extract a schema from your text using an LLM: @@ -826,19 +826,19 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto # Extract the schema from the text extracted_schema = await schema_extractor.run(text="Some text") -The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `SchemaConfig` object that can be used in the same way as a manually defined schema. +The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema. You can also save and reload the extracted schema: .. code:: python # Save the schema to JSON or YAML files - schema_config.store_as_json("my_schema.json") - schema_config.store_as_yaml("my_schema.yaml") - + extracted_schema.store_as_json("my_schema.json") + extracted_schema.store_as_yaml("my_schema.yaml") + # Later, reload the schema from file - from neo4j_graphrag.experimental.components.schema import SchemaConfig - restored_schema = SchemaConfig.from_file("my_schema.json") # or my_schema.yaml + from neo4j_graphrag.experimental.components.schema import GraphSchema + restored_schema = GraphSchema.from_file("my_schema.json") # or my_schema.yaml Entity and Relation Extractor @@ -993,7 +993,6 @@ If more customization is needed, it is possible to subclass the `EntityRelationE from pydantic import validate_call from neo4j_graphrag.experimental.components.entity_relation_extractor import EntityRelationExtractor - from neo4j_graphrag.experimental.components.schema import SchemaConfig from neo4j_graphrag.experimental.components.types import ( Neo4jGraph, Neo4jNode, diff --git a/examples/customize/build_graph/components/schema_builders/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py index 4396f3fb6..7840cbd17 100644 --- a/examples/customize/build_graph/components/schema_builders/schema_from_text.py +++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py @@ -14,7 +14,7 @@ from neo4j_graphrag.experimental.components.schema import ( SchemaFromTextExtractor, - SchemaConfig, + GraphSchema, ) from neo4j_graphrag.llm import OpenAILLM @@ -27,25 +27,25 @@ # Sample text to extract schema from - it's about a company and its employees TEXT = """ -Acme Corporation was founded in 1985 by John Smith in New York City. -The company specializes in manufacturing high-quality widgets and gadgets +Acme Corporation was founded in 1985 by John Smith in New York City. +The company specializes in manufacturing high-quality widgets and gadgets for the consumer electronics industry. -Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to -Engineering Director in 2015. She oversees a team of 12 engineers working on -next-generation products. Sarah holds a PhD in Electrical Engineering from MIT +Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to +Engineering Director in 2015. She oversees a team of 12 engineers working on +next-generation products. Sarah holds a PhD in Electrical Engineering from MIT and has filed 5 patents during her time at Acme. -The company expanded to international markets in 2012, opening offices in London, -Tokyo, and Berlin. Each office is managed by a regional director who reports +The company expanded to international markets in 2012, opening offices in London, +Tokyo, and Berlin. Each office is managed by a regional director who reports directly to the CEO, Michael Brown, who took over leadership in 2008. -Acme's most successful product, the SuperWidget X1, was launched in 2018 and -has sold over 2 million units worldwide. The product was developed by a team led +Acme's most successful product, the SuperWidget X1, was launched in 2018 and +has sold over 2 million units worldwide. The product was developed by a team led by Robert Chen, who joined the company in 2016 after working at TechGiant for 8 years. -The company currently employs 250 people across its 4 locations and had a revenue -of $75 million in the last fiscal year. Acme is planning to go public in 2024 +The company currently employs 250 people across its 4 locations and had a revenue +of $75 million in the last fiscal year. Acme is planning to go public in 2024 with an estimated valuation of $500 million. """ @@ -92,9 +92,9 @@ async def extract_and_save_schema() -> None: inferred_schema.store_as_yaml(YAML_FILE_PATH) print("\nExtracted Schema Summary:") - print(f"Entities: {list(inferred_schema.entities.keys())}") + print(f"Entities: {list(inferred_schema.entities)}") print( - f"Relations: {list(inferred_schema.relations.keys() if inferred_schema.relations else [])}" + f"Relations: {list(inferred_schema.relations if inferred_schema.relations else [])}" ) if inferred_schema.potential_schema: @@ -119,11 +119,11 @@ async def main() -> None: # load schema from files print("\nLoading schemas from saved files:") - schema_from_json = SchemaConfig.from_file(JSON_FILE_PATH) - schema_from_yaml = SchemaConfig.from_file(YAML_FILE_PATH) + schema_from_json = GraphSchema.from_file(JSON_FILE_PATH) + schema_from_yaml = GraphSchema.from_file(YAML_FILE_PATH) - print(f"Entities in JSON schema: {list(schema_from_json.entities.keys())}") - print(f"Entities in YAML schema: {list(schema_from_yaml.entities.keys())}") + print(f"Entities in JSON schema: {list(schema_from_json.entities)}") + print(f"Entities in YAML schema: {list(schema_from_yaml.entities)}") if __name__ == "__main__": diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 838d02605..c07cd048f 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -325,7 +325,9 @@ async def run( lexical_graph = lexical_graph_result.graph elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) - schema = schema or GraphSchema(entities=[], relations=[], potential_schema=[]) + schema = schema or GraphSchema( + entities=frozenset(), relations=frozenset(), potential_schema=frozenset() + ) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) tasks = [ @@ -458,7 +460,9 @@ def _enforce_relationships( rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes ): - logger.debug(f"PRUNING:: {rel} as one of {rel.start_node_id} and {rel.end_node_id} is not in the graph") + logger.debug( + f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" + ) continue start_label = valid_nodes[rel.start_node_id] diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 8c159d87f..458d005e3 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -17,11 +17,17 @@ import json import yaml import logging -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, FrozenSet from pathlib import Path -import neo4j -from pydantic import BaseModel, ValidationError, model_validator, validate_call +from pydantic import ( + BaseModel, + PrivateAttr, + ValidationError, + model_validator, + validate_call, + ConfigDict, +) from typing_extensions import Self from neo4j_graphrag.exceptions import ( @@ -99,21 +105,26 @@ def from_text_or_dict(cls, input: RelationInputType) -> Self: class GraphSchema(DataModel): - entities: list[SchemaEntity] - relations: Optional[list[SchemaRelation]] = None - potential_schema: Optional[List[Tuple[str, str, str]]] = None - # indexes: list[something] = None + entities: FrozenSet[SchemaEntity] + relations: Optional[FrozenSet[SchemaRelation]] = None + potential_schema: Optional[FrozenSet[Tuple[str, str, str]]] = None + + _entity_index: dict[str, SchemaEntity] = PrivateAttr() + _relation_index: dict[str, SchemaRelation] = PrivateAttr() - _entity_index: dict[str, SchemaEntity] - _relation_index: dict[str, SchemaRelation] + model_config = ConfigDict( + frozen=True, + ) @model_validator(mode="after") def check_schema(self) -> Self: self._entity_index = {e.label: e for e in self.entities} - self._relation_index = {r.label: r for r in self.relations} if self.relations else {} + self._relation_index = ( + {r.label: r for r in self.relations} if self.relations else {} + ) - relations = self.relations or [] - potential_schema = self.potential_schema or [] + relations = self.relations or frozenset() + potential_schema = self.potential_schema or frozenset() if potential_schema: if not relations: @@ -142,7 +153,6 @@ def entity_from_label(self, label: str) -> Optional[SchemaEntity]: def relation_from_label(self, label: str) -> Optional[SchemaRelation]: return self._relation_index.get(label) - def store_as_json(self, file_path: str) -> None: """ Save the schema configuration to a JSON file. @@ -179,7 +189,7 @@ def from_file(cls, file_path: Union[str, Path]) -> Self: file_path (Union[str, Path]): The path to the schema configuration file. Returns: - SchemaConfig: The loaded schema configuration. + GraphSchema: The loaded schema configuration. """ file_path = Path(file_path) @@ -204,7 +214,7 @@ def from_json(cls, file_path: Union[str, Path]) -> Self: file_path (Union[str, Path]): The path to the JSON schema configuration file. Returns: - SchemaConfig: The loaded schema configuration. + GraphSchema: The loaded schema configuration. """ with open(file_path, "r") as f: try: @@ -224,7 +234,7 @@ def from_yaml(cls, file_path: Union[str, Path]) -> Self: file_path (Union[str, Path]): The path to the YAML schema configuration file. Returns: - SchemaConfig: The loaded schema configuration. + GraphSchema: The loaded schema configuration. """ with open(file_path, "r") as f: try: @@ -235,12 +245,8 @@ def from_yaml(cls, file_path: Union[str, Path]) -> Self: except ValidationError as e: raise SchemaValidationError(f"Schema validation failed: {e}") -class BaseSchemaBuilder(Component): - async def run(self, **kwargs: Any) -> GraphSchema: - raise NotImplementedError() - -class SchemaBuilder(BaseSchemaBuilder): +class SchemaBuilder(Component): """ A builder class for constructing GraphSchema objects from given entities, relations, and their interrelationships defined in a potential schema. @@ -306,23 +312,25 @@ def create_schema_model( potential_schema: Optional[List[Tuple[str, str, str]]] = None, ) -> GraphSchema: """ - Creates a SchemaConfig object from Lists of Entity and Relation objects + Creates a GraphSchema object from Lists of Entity and Relation objects and a Dictionary defining potential relationships. Args: entities (List[SchemaEntity]): List of Entity objects. - relations (List[SchemaRelation]): List of Relation objects. - potential_schema (Dict[str, List[str]]): Dictionary mapping entity names to Lists of relation names. + relations (Optional[List[SchemaRelation]]): List of Relation objects. + potential_schema (Optional[List[Tuple[str, str, str]]]): Dictionary mapping entity names to Lists of relation names. Returns: - SchemaConfig: A configured schema object. + GraphSchema: A configured schema object. """ try: - return GraphSchema.model_validate(dict( - entities=entities, - relations=relations, - potential_schema=potential_schema, - )) + return GraphSchema.model_validate( + dict( + entities=entities, + relations=relations, + potential_schema=potential_schema, + ) + ) except (ValidationError, SchemaValidationError) as e: raise SchemaValidationError(e) @@ -368,13 +376,13 @@ def __init__( @validate_call async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema: """ - Asynchronously extracts the schema from text and returns a SchemaConfig object. + Asynchronously extracts the schema from text and returns a GraphSchema object. Args: text (str): the text from which the schema will be inferred. examples (str): examples to guide schema extraction. Returns: - SchemaConfig: A configured schema object, extracted automatically and + GraphSchema: A configured schema object, extracted automatically and constructed asynchronously. """ prompt: str = self._prompt_template.format(text=text, examples=examples) @@ -437,16 +445,10 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema f"Invalid schema format return from LLM: {exc}" ) from exc - return GraphSchema( - entities=entities, - relations=relations, - potential_schema=potential_schema, + return GraphSchema.model_validate( + { + "entities": entities, + "relations": relations, + "potential_schema": potential_schema, + } ) - - -class SchemaFromGraphBuilder(BaseSchemaBuilder): - def __init__(self, driver: neo4j.Driver) -> None: - self.driver = driver - - async def run(self, **kwargs: Any) -> GraphSchema: - pass diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 0df0f61e6..109ef793f 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -44,11 +44,10 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaConfig, + GraphSchema, SchemaEntity, SchemaRelation, SchemaFromTextExtractor, - normalize_schema_dict, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -93,7 +92,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None - schema_: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = Field( + schema_: Optional[Union[GraphSchema, dict[str, list[Any]]]] = Field( default=None, alias="schema" ) enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE @@ -203,7 +202,7 @@ def _process_schema_with_precedence( ]: """ Process schema inputs according to precedence rules: - 1. If schema is provided as SchemaConfig object, use it + 1. If schema is provided as GraphSchema object, use it 2. If schema is provided as dictionary, extract from it 3. Otherwise, use individual schema components @@ -212,22 +211,17 @@ def _process_schema_with_precedence( """ if self.schema_ is not None: # schema takes precedence over individual components - if isinstance(self.schema_, SchemaConfig): - # extract components from SchemaConfig - entity_dicts = list(self.schema_.entities.values()) - # convert dict values to SchemaEntity objects - entities = [SchemaEntity.model_validate(e) for e in entity_dicts] + if isinstance(self.schema_, GraphSchema): + # extract components from GraphSchema + entities = list(self.schema_.entities) # handle case where relations could be None if self.schema_.relations is not None: - relation_dicts = list(self.schema_.relations.values()) - relations = [ - SchemaRelation.model_validate(r) for r in relation_dicts - ] + relations = list(self.schema_.relations) else: relations = [] - potential_schema = self.schema_.potential_schema + potential_schema = list(self.schema_.potential_schema) else: entities = [ SchemaEntity.from_text_or_dict(e) diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index c586a7fad..924670b90 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -43,7 +43,7 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.experimental.components.schema import SchemaConfig +from neo4j_graphrag.experimental.components.schema import GraphSchema logger = logging.getLogger(__name__) @@ -57,7 +57,7 @@ class SimpleKGPipeline: llm (LLMInterface): An instance of an LLM to use for entity and relation extraction. driver (neo4j.Driver): A Neo4j driver instance for database connection. embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks. - schema (Optional[Union[SchemaConfig, dict[str, list]]]): A schema configuration defining entities, + schema (Optional[Union[GraphSchema, dict[str, list]]]): A schema configuration defining entities, relations, and potential schema relationships. This is the recommended way to provide schema information. entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): DEPRECATED. A list of either: @@ -92,7 +92,7 @@ def __init__( entities: Optional[Sequence[EntityInputType]] = None, relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, - schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None, + schema: Optional[Union[GraphSchema, dict[str, list[Any]]]] = None, enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 2ed794a4c..f088334da 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -243,13 +243,16 @@ async def test_extractor_no_schema_enforcement() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.NONE ) - schema = GraphSchema.model_validate({ - "entities": [{ - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }], - "relations": [], - "potential_schema": [] + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relations": [], + "potential_schema": [], }, ) @@ -296,13 +299,16 @@ async def test_extractor_schema_enforcement_invalid_nodes() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities": [{ - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }], - "relations": [], - "potential_schema": [] + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relations": [], + "potential_schema": [], }, ) @@ -328,16 +334,19 @@ async def test_extraction_schema_enforcement_invalid_node_properties() -> None: llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities": [{ - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "STRING"} + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "STRING"}, + ], + } ], - }], - "relations": [], - "potential_schema": [] + "relations": [], + "potential_schema": [], }, ) @@ -363,11 +372,15 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> No llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities": [{ - "label": "Person", - }], - }) + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + } + ], + } + ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) result: Neo4jGraph = await extractor.run(chunks, schema=schema) @@ -390,17 +403,21 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> N llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities": [{ - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "STRING"} + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "STRING"}, + ], + } ], - }], - "relations": [{"label": "LIKES"}], - "potential_schema": [] - }) + "relations": [{"label": "LIKES"}], + "potential_schema": [], + } + ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -427,20 +444,22 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities":[ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - ], - "relations":[{"label": "LIVES_IN"}], - "potential_schema":[("Person", "LIVES_IN", "City")], - }) + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + }, + { + "label": "City", + "properties": [{"name": "name", "type": "STRING"}], + }, + ], + "relations": [{"label": "LIVES_IN"}], + "potential_schema": [("Person", "LIVES_IN", "City")], + } + ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -464,21 +483,23 @@ async def test_extractor_schema_enforcement_invalid_relation_properties() -> Non llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relations":[ - { - "label": "LIKES", - "properties": [{"name": "strength", "type": "STRING"}], - } - ], - "potential_schema":[], - }) + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relations": [ + { + "label": "LIKES", + "properties": [{"name": "strength", "type": "STRING"}], + } + ], + "potential_schema": [], + } + ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -505,16 +526,18 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes() - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relations": [{"label": "LIKES"}], - "potential_schema": [("Person", "LIKES", "Person")], - }) + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + "relations": [{"label": "LIKES"}], + "potential_schema": [("Person", "LIKES", "Person")], + } + ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -538,20 +561,22 @@ async def test_extractor_schema_enforcement_inverted_relation_direction() -> Non llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = GraphSchema.model_validate({ - "entities": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - ], - "relations": [{"label": "LIVES_IN"}], - "potential_schema": [("Person", "LIVES_IN", "City")], - }) + schema = GraphSchema.model_validate( + { + "entities": [ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + }, + { + "label": "City", + "properties": [{"name": "name", "type": "STRING"}], + }, + ], + "relations": [{"label": "LIVES_IN"}], + "potential_schema": [("Person", "LIVES_IN", "City")], + } + ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 7400696d3..6ba9ef40a 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -418,9 +418,7 @@ def test_create_schema_model_no_relations_or_potential_schema( assert len(schema_instance.entities) == 2 person = schema_instance.entity_from_label("PERSON") - assert ( - person.description == "An individual human being." - ) + assert person.description == "An individual human being." assert person.properties == [ SchemaProperty( name="name", @@ -431,14 +429,11 @@ def test_create_schema_model_no_relations_or_potential_schema( name="birth date", type="ZONED_DATETIME", description="", - ) + ), ] org = schema_instance.entity_from_label("ORGANIZATION") - assert ( - org.description - == "A structured group of people with a common purpose." - ) + assert org.description == "A structured group of people with a common purpose." age = schema_instance.entity_from_label("AGE") assert age.description == "Age of a person in years." From 70943fdadb85da0de7d3f950acbd0440fd5cdc57 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 10:51:07 +0200 Subject: [PATCH 03/16] Update tests - postpone frozen schema --- .../experimental/components/schema.py | 42 +-- .../template_pipeline/simple_kg_builder.py | 8 - .../experimental/components/test_schema.py | 315 ++---------------- 3 files changed, 49 insertions(+), 316 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 458d005e3..cbd7dd12a 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -17,7 +17,7 @@ import json import yaml import logging -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, FrozenSet +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from pathlib import Path from pydantic import ( @@ -67,6 +67,10 @@ class SchemaProperty(BaseModel): ] description: str = "" + model_config = ConfigDict( + frozen=True, + ) + class SchemaEntity(BaseModel): """ @@ -75,7 +79,7 @@ class SchemaEntity(BaseModel): label: str description: str = "" - properties: List[SchemaProperty] = [] + properties: list[SchemaProperty] = [] @classmethod def from_text_or_dict(cls, input: EntityInputType) -> Self: @@ -93,7 +97,7 @@ class SchemaRelation(BaseModel): label: str description: str = "" - properties: List[SchemaProperty] = [] + properties: list[SchemaProperty] = [] @classmethod def from_text_or_dict(cls, input: RelationInputType) -> Self: @@ -105,17 +109,13 @@ def from_text_or_dict(cls, input: RelationInputType) -> Self: class GraphSchema(DataModel): - entities: FrozenSet[SchemaEntity] - relations: Optional[FrozenSet[SchemaRelation]] = None - potential_schema: Optional[FrozenSet[Tuple[str, str, str]]] = None + entities: list[SchemaEntity] + relations: Optional[list[SchemaRelation]] = None + potential_schema: Optional[list[Tuple[str, str, str]]] = None _entity_index: dict[str, SchemaEntity] = PrivateAttr() _relation_index: dict[str, SchemaRelation] = PrivateAttr() - model_config = ConfigDict( - frozen=True, - ) - @model_validator(mode="after") def check_schema(self) -> Self: self._entity_index = {e.label: e for e in self.entities} @@ -123,8 +123,8 @@ def check_schema(self) -> Self: {r.label: r for r in self.relations} if self.relations else {} ) - relations = self.relations or frozenset() - potential_schema = self.potential_schema or frozenset() + relations = self.relations or [] + potential_schema = self.potential_schema or [] if potential_schema: if not relations: @@ -431,24 +431,10 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema "potential_schema" ) - try: - entities: List[SchemaEntity] = [ - SchemaEntity(**e) for e in extracted_entities - ] - relations: Optional[List[SchemaRelation]] = ( - [SchemaRelation(**r) for r in extracted_relations] - if extracted_relations - else None - ) - except ValidationError as exc: - raise SchemaValidationError( - f"Invalid schema format return from LLM: {exc}" - ) from exc - return GraphSchema.model_validate( { - "entities": entities, - "relations": relations, + "entities": extracted_entities, + "relations": extracted_relations, "potential_schema": potential_schema, } ) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 109ef793f..04c7c3bc9 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -108,14 +108,6 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): model_config = ConfigDict(arbitrary_types_allowed=True) - @model_validator(mode="before") - def normalize_schema_field(cls, data: Dict[str, Any]) -> Dict[str, Any]: - # Normalize the 'schema' field if it is a dict - schema = data.get("schema") - if isinstance(schema, dict): - data["schema"] = normalize_schema_dict(schema) - return data - @model_validator(mode="after") def handle_schema_precedence(self) -> Self: """Handle schema precedence and warnings""" diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 6ba9ef40a..b02ff3ee8 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -25,9 +25,8 @@ SchemaProperty, SchemaRelation, SchemaFromTextExtractor, - SchemaConfig, + GraphSchema, ) -from pydantic import ValidationError import os import tempfile import yaml @@ -110,7 +109,7 @@ def schema_config( valid_entities: list[SchemaEntity], valid_relations: list[SchemaRelation], potential_schema: list[tuple[str, str, str]], -) -> SchemaConfig: +) -> GraphSchema: return schema_builder.create_schema_model( valid_entities, valid_relations, potential_schema ) @@ -126,132 +125,11 @@ def test_create_schema_model_valid_data( valid_entities, valid_relations, potential_schema ) - assert ( - schema_instance.entities["PERSON"]["description"] - == "An individual human being." - ) - assert schema_instance.entities["PERSON"]["properties"] == [ - {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, - {"description": "", "name": "name", "type": "STRING"}, - ] - assert ( - schema_instance.entities["ORGANIZATION"]["description"] - == "A structured group of people with a common purpose." - ) - assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." - - assert schema_instance.relations - assert ( - schema_instance.relations["EMPLOYED_BY"]["description"] - == "Indicates employment relationship." - ) - assert ( - schema_instance.relations["ORGANIZED_BY"]["description"] - == "Indicates organization responsible for an event." - ) - assert ( - schema_instance.relations["ATTENDED_BY"]["description"] - == "Indicates attendance at an event." - ) - assert schema_instance.relations["EMPLOYED_BY"]["properties"] == [ - {"description": "", "name": "start_time", "type": "LOCAL_DATETIME"}, - {"description": "", "name": "end_time", "type": "LOCAL_DATETIME"}, - ] - - assert schema_instance.potential_schema + assert schema_instance.entities == valid_entities + assert schema_instance.relations == valid_relations assert schema_instance.potential_schema == potential_schema -def test_create_schema_model_missing_description( - schema_builder: SchemaBuilder, potential_schema: list[tuple[str, str, str]] -) -> None: - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity(label="ORGANIZATION", description=""), - SchemaEntity(label="AGE", description=""), - ] - relations = [ - SchemaRelation( - label="EMPLOYED_BY", description="Indicates employment relationship." - ), - SchemaRelation(label="ORGANIZED_BY", description=""), - SchemaRelation(label="ATTENDED_BY", description=""), - ] - - schema_instance = schema_builder.create_schema_model( - entities, relations, potential_schema - ) - - assert schema_instance.entities["ORGANIZATION"]["description"] == "" - assert schema_instance.entities["AGE"]["description"] == "" - assert schema_instance.relations - assert schema_instance.relations["ORGANIZED_BY"]["description"] == "" - assert schema_instance.relations["ATTENDED_BY"]["description"] == "" - - -def test_create_schema_model_empty_lists(schema_builder: SchemaBuilder) -> None: - schema_instance = schema_builder.create_schema_model([], [], []) - - assert schema_instance.entities == {} - assert schema_instance.relations == {} - assert schema_instance.potential_schema == [] - - -def test_create_schema_model_invalid_data_types( - schema_builder: SchemaBuilder, potential_schema: list[tuple[str, str, str]] -) -> None: - with pytest.raises(ValidationError): - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( - label="ORGANIZATION", - description="A structured group of people with a common purpose.", - ), - ] - relations = [ - SchemaRelation( - label="EMPLOYED_BY", description="Indicates employment relationship." - ), - SchemaRelation( - label=456, # type: ignore - description="Indicates organization responsible for an event.", - ), - ] - - schema_builder.create_schema_model(entities, relations, potential_schema) - - -def test_create_schema_model_invalid_properties_types( - schema_builder: SchemaBuilder, - potential_schema: list[tuple[str, str, str]], -) -> None: - with pytest.raises(ValidationError): - entities = [ - SchemaEntity( - label="PERSON", - description="An individual human being.", - properties=[42, 1337], # type: ignore - ), - SchemaEntity( - label="ORGANIZATION", - description="A structured group of people with a common purpose.", - ), - ] - relations = [ - SchemaRelation( - label="EMPLOYED_BY", - description="Indicates employment relationship.", - properties=[42, 1337], # type: ignore - ), - SchemaRelation( - label="ORGANIZED_BY", - description="Indicates organization responsible for an event.", - ), - ] - - schema_builder.create_schema_model(entities, relations, potential_schema) - - @pytest.mark.asyncio async def test_run_method( schema_builder: SchemaBuilder, @@ -261,28 +139,8 @@ async def test_run_method( ) -> None: schema = await schema_builder.run(valid_entities, valid_relations, potential_schema) - assert schema.entities["PERSON"]["description"] == "An individual human being." - assert ( - schema.entities["ORGANIZATION"]["description"] - == "A structured group of people with a common purpose." - ) - assert schema.entities["AGE"]["description"] == "Age of a person in years." - - assert schema.relations - assert ( - schema.relations["EMPLOYED_BY"]["description"] - == "Indicates employment relationship." - ) - assert ( - schema.relations["ORGANIZED_BY"]["description"] - == "Indicates organization responsible for an event." - ) - assert ( - schema.relations["ATTENDED_BY"]["description"] - == "Indicates attendance at an event." - ) - - assert schema.potential_schema + assert schema.entities == valid_entities + assert schema.relations == valid_relations assert schema.potential_schema == potential_schema @@ -316,57 +174,6 @@ def test_create_schema_model_invalid_relation( ), "Should fail due to non-existent relation" -def test_create_schema_model_missing_properties( - schema_builder: SchemaBuilder, potential_schema: list[tuple[str, str, str]] -) -> None: - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( - label="ORGANIZATION", - description="A structured group of people with a common purpose.", - ), - SchemaEntity(label="AGE", description="Age of a person in years."), - ] - - relations = [ - SchemaRelation( - label="EMPLOYED_BY", description="Indicates employment relationship." - ), - SchemaRelation( - label="ORGANIZED_BY", - description="Indicates organization responsible for an event.", - ), - SchemaRelation( - label="ATTENDED_BY", description="Indicates attendance at an event." - ), - ] - - schema_instance = schema_builder.create_schema_model( - entities, relations, potential_schema - ) - - assert ( - schema_instance.entities["PERSON"]["properties"] == [] - ), "Expected empty properties for PERSON" - assert ( - schema_instance.entities["ORGANIZATION"]["properties"] == [] - ), "Expected empty properties for ORGANIZATION" - assert ( - schema_instance.entities["AGE"]["properties"] == [] - ), "Expected empty properties for AGE" - - assert schema_instance.relations - assert ( - schema_instance.relations["EMPLOYED_BY"]["properties"] == [] - ), "Expected empty properties for EMPLOYED_BY" - assert ( - schema_instance.relations["ORGANIZED_BY"]["properties"] == [] - ), "Expected empty properties for ORGANIZED_BY" - assert ( - schema_instance.relations["ATTENDED_BY"]["properties"] == [] - ), "Expected empty properties for ATTENDED_BY" - - def test_create_schema_model_no_potential_schema( schema_builder: SchemaBuilder, valid_entities: list[SchemaEntity], @@ -375,38 +182,9 @@ def test_create_schema_model_no_potential_schema( schema_instance = schema_builder.create_schema_model( valid_entities, valid_relations ) - - assert ( - schema_instance.entities["PERSON"]["description"] - == "An individual human being." - ) - assert schema_instance.entities["PERSON"]["properties"] == [ - {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, - {"description": "", "name": "name", "type": "STRING"}, - ] - assert ( - schema_instance.entities["ORGANIZATION"]["description"] - == "A structured group of people with a common purpose." - ) - assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." - - assert schema_instance.relations - assert ( - schema_instance.relations["EMPLOYED_BY"]["description"] - == "Indicates employment relationship." - ) - assert ( - schema_instance.relations["ORGANIZED_BY"]["description"] - == "Indicates organization responsible for an event." - ) - assert ( - schema_instance.relations["ATTENDED_BY"]["description"] - == "Indicates attendance at an event." - ) - assert schema_instance.relations["EMPLOYED_BY"]["properties"] == [ - {"description": "", "name": "start_time", "type": "LOCAL_DATETIME"}, - {"description": "", "name": "end_time", "type": "LOCAL_DATETIME"}, - ] + assert schema_instance.entities == valid_entities + assert schema_instance.relations == valid_relations + assert schema_instance.potential_schema is None def test_create_schema_model_no_relations_or_potential_schema( @@ -415,22 +193,11 @@ def test_create_schema_model_no_relations_or_potential_schema( ) -> None: schema_instance = schema_builder.create_schema_model(valid_entities) - assert len(schema_instance.entities) == 2 + assert len(schema_instance.entities) == 3 person = schema_instance.entity_from_label("PERSON") assert person.description == "An individual human being." - assert person.properties == [ - SchemaProperty( - name="name", - type="STRING", - description="", - ), - SchemaProperty( - name="birth date", - type="ZONED_DATETIME", - description="", - ), - ] + assert len(person.properties) == 2 org = schema_instance.entity_from_label("ORGANIZATION") assert org.description == "A structured group of people with a common purpose." @@ -532,11 +299,11 @@ async def test_schema_from_text_run_valid_response( # verify the schema was correctly extracted assert len(schema_config.entities) == 2 - assert "Person" in schema_config.entities - assert "Organization" in schema_config.entities + assert schema_config.entity_from_label("Person") is not None + assert schema_config.entity_from_label("Organization") is not None assert schema_config.relations is not None - assert "WORKS_FOR" in schema_config.relations + assert schema_config.relation_from_label("WORKS_FOR") is not None assert schema_config.potential_schema is not None assert len(schema_config.potential_schema) == 1 @@ -607,7 +374,7 @@ async def test_schema_from_text_llm_params( @pytest.mark.asyncio -async def test_schema_config_store_as_json(schema_config: SchemaConfig) -> None: +async def test_schema_config_store_as_json(schema_config: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file path json_path = os.path.join(temp_dir, "schema.json") @@ -623,17 +390,11 @@ async def test_schema_config_store_as_json(schema_config: SchemaConfig) -> None: with open(json_path, "r") as f: data = json.load(f) assert "entities" in data - assert "PERSON" in data["entities"] - assert "properties" in data["entities"]["PERSON"] - assert "description" in data["entities"]["PERSON"] - assert ( - data["entities"]["PERSON"]["description"] - == "An individual human being." - ) + assert len(data["entities"]) == 3 @pytest.mark.asyncio -async def test_schema_config_store_as_yaml(schema_config: SchemaConfig) -> None: +async def test_schema_config_store_as_yaml(schema_config: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # Create file path yaml_path = os.path.join(temp_dir, "schema.yaml") @@ -649,17 +410,11 @@ async def test_schema_config_store_as_yaml(schema_config: SchemaConfig) -> None: with open(yaml_path, "r") as f: data = yaml.safe_load(f) assert "entities" in data - assert "PERSON" in data["entities"] - assert "properties" in data["entities"]["PERSON"] - assert "description" in data["entities"]["PERSON"] - assert ( - data["entities"]["PERSON"]["description"] - == "An individual human being." - ) + assert len(data["entities"]) == 3 @pytest.mark.asyncio -async def test_schema_config_from_file(schema_config: SchemaConfig) -> None: +async def test_schema_config_from_file(schema_config: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file paths with different extensions json_path = os.path.join(temp_dir, "schema.json") @@ -672,14 +427,14 @@ async def test_schema_config_from_file(schema_config: SchemaConfig) -> None: schema_config.store_as_yaml(yml_path) # load using from_file which should detect the format based on extension - json_schema = SchemaConfig.from_file(json_path) - yaml_schema = SchemaConfig.from_file(yaml_path) - yml_schema = SchemaConfig.from_file(yml_path) + json_schema = GraphSchema.from_file(json_path) + yaml_schema = GraphSchema.from_file(yaml_path) + yml_schema = GraphSchema.from_file(yml_path) # simple verification that the objects were loaded correctly - assert isinstance(json_schema, SchemaConfig) - assert isinstance(yaml_schema, SchemaConfig) - assert isinstance(yml_schema, SchemaConfig) + assert isinstance(json_schema, GraphSchema) + assert isinstance(yaml_schema, GraphSchema) + assert isinstance(yml_schema, GraphSchema) # verify basic structure is intact assert "entities" in json_schema.model_dump() @@ -691,7 +446,7 @@ async def test_schema_config_from_file(schema_config: SchemaConfig) -> None: schema_config.store_as_json(txt_path) # Store as JSON but with .txt extension with pytest.raises(ValueError, match="Unsupported file format"): - SchemaConfig.from_file(txt_path) + GraphSchema.from_file(txt_path) @pytest.fixture @@ -739,16 +494,16 @@ async def test_schema_from_text_run_valid_json_array( mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json_array) # run the schema extraction - schema_config = await schema_from_text.run(text="Sample text for extraction") + schema = await schema_from_text.run(text="Sample text for extraction") # verify the schema was correctly extracted from the array - assert len(schema_config.entities) == 2 - assert "Person" in schema_config.entities - assert "Organization" in schema_config.entities + assert len(schema.entities) == 2 + assert schema.entity_from_label("Person") is not None + assert schema.entity_from_label("Organization") is not None - assert schema_config.relations is not None - assert "WORKS_FOR" in schema_config.relations + assert schema.relations is not None + assert schema.relation_from_label("WORKS_FOR") is not None - assert schema_config.potential_schema is not None - assert len(schema_config.potential_schema) == 1 - assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") + assert schema.potential_schema is not None + assert len(schema.potential_schema) == 1 + assert schema.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") From a54ddcfaebdf1a7833e2a1ed8e74e7105f1b2486 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 11:10:37 +0200 Subject: [PATCH 04/16] Use tuples instead of lists for immutability (and make the _*_index always consistent) --- .../components/entity_relation_extractor.py | 2 +- .../experimental/components/schema.py | 32 +++-- .../experimental/components/test_schema.py | 115 ++++++++++-------- 3 files changed, 84 insertions(+), 65 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index c07cd048f..c28f0dfb4 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -326,7 +326,7 @@ async def run( elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) schema = schema or GraphSchema( - entities=frozenset(), relations=frozenset(), potential_schema=frozenset() + entities=(), relations=(), potential_schema=() ) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index cbd7dd12a..8dcaec7b7 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -17,7 +17,7 @@ import json import yaml import logging -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence from pathlib import Path from pydantic import ( @@ -109,13 +109,17 @@ def from_text_or_dict(cls, input: RelationInputType) -> Self: class GraphSchema(DataModel): - entities: list[SchemaEntity] - relations: Optional[list[SchemaRelation]] = None - potential_schema: Optional[list[Tuple[str, str, str]]] = None + entities: Tuple[SchemaEntity, ...] + relations: Optional[Tuple[SchemaRelation, ...]] = None + potential_schema: Optional[Tuple[Tuple[str, str, str], ...]] = None _entity_index: dict[str, SchemaEntity] = PrivateAttr() _relation_index: dict[str, SchemaRelation] = PrivateAttr() + model_config = ConfigDict( + frozen=True, + ) + @model_validator(mode="after") def check_schema(self) -> Self: self._entity_index = {e.label: e for e in self.entities} @@ -123,8 +127,8 @@ def check_schema(self) -> Self: {r.label: r for r in self.relations} if self.relations else {} ) - relations = self.relations or [] - potential_schema = self.potential_schema or [] + relations = self.relations or tuple() + potential_schema = self.potential_schema or tuple() if potential_schema: if not relations: @@ -172,6 +176,10 @@ def store_as_yaml(self, file_path: str) -> None: """ # create a copy of the data and convert tuples to lists for YAML compatibility data = self.model_dump() + if data.get("entities"): + data["entities"] = list(data["entities"]) + if data.get("relations"): + data["relations"] = list(data["relations"]) if data.get("potential_schema"): data["potential_schema"] = [list(item) for item in data["potential_schema"]] @@ -307,9 +315,9 @@ class SchemaBuilder(Component): @staticmethod def create_schema_model( - entities: List[SchemaEntity], - relations: Optional[List[SchemaRelation]] = None, - potential_schema: Optional[List[Tuple[str, str, str]]] = None, + entities: Sequence[SchemaEntity], + relations: Optional[Sequence[SchemaRelation]] = None, + potential_schema: Optional[Sequence[Tuple[str, str, str]]] = None, ) -> GraphSchema: """ Creates a GraphSchema object from Lists of Entity and Relation objects @@ -337,9 +345,9 @@ def create_schema_model( @validate_call async def run( self, - entities: List[SchemaEntity], - relations: Optional[List[SchemaRelation]] = None, - potential_schema: Optional[List[Tuple[str, str, str]]] = None, + entities: Sequence[SchemaEntity], + relations: Optional[Sequence[SchemaRelation]] = None, + potential_schema: Optional[Sequence[Tuple[str, str, str]]] = None, ) -> GraphSchema: """ Asynchronously constructs and returns a GraphSchema object. diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index b02ff3ee8..5761ce088 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -18,6 +18,8 @@ from unittest.mock import AsyncMock import pytest +from pydantic import ValidationError + from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, @@ -36,8 +38,8 @@ @pytest.fixture -def valid_entities() -> list[SchemaEntity]: - return [ +def valid_entities() -> tuple[SchemaEntity, ...]: + return ( SchemaEntity( label="PERSON", description="An individual human being.", @@ -51,12 +53,12 @@ def valid_entities() -> list[SchemaEntity]: description="A structured group of people with a common purpose.", ), SchemaEntity(label="AGE", description="Age of a person in years."), - ] + ) @pytest.fixture -def valid_relations() -> list[SchemaRelation]: - return [ +def valid_relations() -> tuple[SchemaRelation, ...]: + return ( SchemaRelation( label="EMPLOYED_BY", description="Indicates employment relationship.", @@ -72,30 +74,30 @@ def valid_relations() -> list[SchemaRelation]: SchemaRelation( label="ATTENDED_BY", description="Indicates attendance at an event." ), - ] + ) @pytest.fixture -def potential_schema() -> list[tuple[str, str, str]]: - return [ +def potential_schema() -> tuple[tuple[str, str, str], ...]: + return ( ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ("ORGANIZATION", "ATTENDED_BY", "PERSON"), - ] + ) @pytest.fixture -def potential_schema_with_invalid_entity() -> list[tuple[str, str, str]]: - return [ +def potential_schema_with_invalid_entity() -> tuple[tuple[str, str, str], ...]: + return ( ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ("NON_EXISTENT_ENTITY", "ATTENDED_BY", "PERSON"), - ] + ) @pytest.fixture -def potential_schema_with_invalid_relation() -> list[tuple[str, str, str]]: - return [ +def potential_schema_with_invalid_relation() -> tuple[tuple[str, str, str], ...]: + return ( ("PERSON", "NON_EXISTENT_RELATION", "ORGANIZATION"), - ] + ) @pytest.fixture @@ -104,25 +106,25 @@ def schema_builder() -> SchemaBuilder: @pytest.fixture -def schema_config( +def graph_schema( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema: list[tuple[str, str, str]], + valid_entities: tuple[SchemaEntity], + valid_relations: tuple[SchemaRelation], + potential_schema: tuple[tuple[str, str, str]], ) -> GraphSchema: return schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema + list(valid_entities), list(valid_relations), list(potential_schema) ) def test_create_schema_model_valid_data( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema: list[tuple[str, str, str]], + valid_entities: tuple[SchemaEntity], + valid_relations: tuple[SchemaRelation], + potential_schema: tuple[tuple[str, str, str]], ) -> None: schema_instance = schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema + list(valid_entities), list(valid_relations), list(potential_schema) ) assert schema_instance.entities == valid_entities @@ -133,11 +135,15 @@ def test_create_schema_model_valid_data( @pytest.mark.asyncio async def test_run_method( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema: list[tuple[str, str, str]], + valid_entities: tuple[SchemaEntity, ...], + valid_relations: tuple[SchemaRelation, ...], + potential_schema: tuple[tuple[str, str, str], ...], ) -> None: - schema = await schema_builder.run(valid_entities, valid_relations, potential_schema) + schema = await schema_builder.run( + list(valid_entities), + list(valid_relations), + list(potential_schema) + ) assert schema.entities == valid_entities assert schema.relations == valid_relations @@ -146,13 +152,15 @@ async def test_run_method( def test_create_schema_model_invalid_entity( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema_with_invalid_entity: list[tuple[str, str, str]], + valid_entities: tuple[SchemaEntity, ...], + valid_relations: tuple[SchemaRelation, ...], + potential_schema_with_invalid_entity: tuple[tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema_with_invalid_entity + list(valid_entities), + list(valid_relations), + list(potential_schema_with_invalid_entity), ) assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str( exc_info.value @@ -161,13 +169,13 @@ def test_create_schema_model_invalid_entity( def test_create_schema_model_invalid_relation( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], - potential_schema_with_invalid_relation: list[tuple[str, str, str]], + valid_entities: tuple[SchemaEntity, ...], + valid_relations: tuple[SchemaRelation, ...], + potential_schema_with_invalid_relation: tuple[tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - valid_entities, valid_relations, potential_schema_with_invalid_relation + list(valid_entities), list(valid_relations), list(potential_schema_with_invalid_relation) ) assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( exc_info.value @@ -176,11 +184,11 @@ def test_create_schema_model_invalid_relation( def test_create_schema_model_no_potential_schema( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - valid_relations: list[SchemaRelation], + valid_entities: tuple[SchemaEntity, ...], + valid_relations: tuple[SchemaRelation, ...], ) -> None: schema_instance = schema_builder.create_schema_model( - valid_entities, valid_relations + list(valid_entities), list(valid_relations) ) assert schema_instance.entities == valid_entities assert schema_instance.relations == valid_relations @@ -189,27 +197,30 @@ def test_create_schema_model_no_potential_schema( def test_create_schema_model_no_relations_or_potential_schema( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], + valid_entities: tuple[SchemaEntity, ...], ) -> None: - schema_instance = schema_builder.create_schema_model(valid_entities) + schema_instance = schema_builder.create_schema_model(list(valid_entities)) assert len(schema_instance.entities) == 3 person = schema_instance.entity_from_label("PERSON") + assert person is not None assert person.description == "An individual human being." assert len(person.properties) == 2 org = schema_instance.entity_from_label("ORGANIZATION") + assert org is not None assert org.description == "A structured group of people with a common purpose." age = schema_instance.entity_from_label("AGE") + assert age is not None assert age.description == "Age of a person in years." def test_create_schema_model_missing_relations( schema_builder: SchemaBuilder, - valid_entities: list[SchemaEntity], - potential_schema: list[tuple[str, str, str]], + valid_entities: tuple[SchemaEntity, ...], + potential_schema: tuple[tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( @@ -374,13 +385,13 @@ async def test_schema_from_text_llm_params( @pytest.mark.asyncio -async def test_schema_config_store_as_json(schema_config: GraphSchema) -> None: +async def test_schema_config_store_as_json(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file path json_path = os.path.join(temp_dir, "schema.json") # store the schema config - schema_config.store_as_json(json_path) + graph_schema.store_as_json(json_path) # verify the file exists and has content assert os.path.exists(json_path) @@ -394,13 +405,13 @@ async def test_schema_config_store_as_json(schema_config: GraphSchema) -> None: @pytest.mark.asyncio -async def test_schema_config_store_as_yaml(schema_config: GraphSchema) -> None: +async def test_schema_config_store_as_yaml(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # Create file path yaml_path = os.path.join(temp_dir, "schema.yaml") # Store the schema config - schema_config.store_as_yaml(yaml_path) + graph_schema.store_as_yaml(yaml_path) # Verify the file exists and has content assert os.path.exists(yaml_path) @@ -414,7 +425,7 @@ async def test_schema_config_store_as_yaml(schema_config: GraphSchema) -> None: @pytest.mark.asyncio -async def test_schema_config_from_file(schema_config: GraphSchema) -> None: +async def test_schema_config_from_file(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file paths with different extensions json_path = os.path.join(temp_dir, "schema.json") @@ -422,9 +433,9 @@ async def test_schema_config_from_file(schema_config: GraphSchema) -> None: yml_path = os.path.join(temp_dir, "schema.yml") # store the schema config in the different formats - schema_config.store_as_json(json_path) - schema_config.store_as_yaml(yaml_path) - schema_config.store_as_yaml(yml_path) + graph_schema.store_as_json(json_path) + graph_schema.store_as_yaml(yaml_path) + graph_schema.store_as_yaml(yml_path) # load using from_file which should detect the format based on extension json_schema = GraphSchema.from_file(json_path) @@ -443,7 +454,7 @@ async def test_schema_config_from_file(schema_config: GraphSchema) -> None: # verify an unsupported extension raises the correct error txt_path = os.path.join(temp_dir, "schema.txt") - schema_config.store_as_json(txt_path) # Store as JSON but with .txt extension + graph_schema.store_as_json(txt_path) # Store as JSON but with .txt extension with pytest.raises(ValueError, match="Unsupported file format"): GraphSchema.from_file(txt_path) From 26e9e13d96673cdbb7bf4331cd67a00e6a7ea1ba Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 11:13:19 +0200 Subject: [PATCH 05/16] Ruff --- .../components/entity_relation_extractor.py | 4 +--- tests/unit/experimental/components/test_schema.py | 13 +++++-------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index c28f0dfb4..7485a77d4 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -325,9 +325,7 @@ async def run( lexical_graph = lexical_graph_result.graph elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) - schema = schema or GraphSchema( - entities=(), relations=(), potential_schema=() - ) + schema = schema or GraphSchema(entities=(), relations=(), potential_schema=()) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) tasks = [ diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 5761ce088..c86525f56 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -18,7 +18,6 @@ from unittest.mock import AsyncMock import pytest -from pydantic import ValidationError from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( @@ -95,9 +94,7 @@ def potential_schema_with_invalid_entity() -> tuple[tuple[str, str, str], ...]: @pytest.fixture def potential_schema_with_invalid_relation() -> tuple[tuple[str, str, str], ...]: - return ( - ("PERSON", "NON_EXISTENT_RELATION", "ORGANIZATION"), - ) + return (("PERSON", "NON_EXISTENT_RELATION", "ORGANIZATION"),) @pytest.fixture @@ -140,9 +137,7 @@ async def test_run_method( potential_schema: tuple[tuple[str, str, str], ...], ) -> None: schema = await schema_builder.run( - list(valid_entities), - list(valid_relations), - list(potential_schema) + list(valid_entities), list(valid_relations), list(potential_schema) ) assert schema.entities == valid_entities @@ -175,7 +170,9 @@ def test_create_schema_model_invalid_relation( ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - list(valid_entities), list(valid_relations), list(potential_schema_with_invalid_relation) + list(valid_entities), + list(valid_relations), + list(potential_schema_with_invalid_relation), ) assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( exc_info.value From 63d30ccb405424079e6d9735d63a795b7aaa3204 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 11:22:04 +0200 Subject: [PATCH 06/16] Update docstrings and CHANGELOG --- CHANGELOG.md | 1 + src/neo4j_graphrag/experimental/components/schema.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2fc5bcd9..d3deaf390 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ ### Changed +- `SchemaConfig` has been deprecated in favor of `GraphSchema` (used in the `SchemaBuilder` and `EntityRelationExtractor` classes). - Strict mode in `SimpleKGPipeline`: now properties and relationships are pruned only if they are defined in the input schema. diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 8dcaec7b7..456688e9d 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -324,9 +324,9 @@ def create_schema_model( and a Dictionary defining potential relationships. Args: - entities (List[SchemaEntity]): List of Entity objects. - relations (Optional[List[SchemaRelation]]): List of Relation objects. - potential_schema (Optional[List[Tuple[str, str, str]]]): Dictionary mapping entity names to Lists of relation names. + entities (Sequence[SchemaEntity]): List or tuple of SchemaEntity objects. + relations (Optional[Sequence[SchemaRelation]]): List or tuple of SchemaRelation objects. + potential_schema (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label). Returns: GraphSchema: A configured schema object. @@ -353,9 +353,9 @@ async def run( Asynchronously constructs and returns a GraphSchema object. Args: - entities (List[SchemaEntity]): List of Entity objects. - relations (List[SchemaRelation]): List of Relation objects. - potential_schema (Dict[str, List[str]]): Dictionary mapping entity names to Lists of relation names. + entities (Sequence[SchemaEntity]): List or tuple of SchemaEntity objects. + relations (Sequence[SchemaRelation]): List or tuple of SchemaRelation objects. + potential_schema (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label). Returns: GraphSchema: A configured schema object, constructed asynchronously. From 29ebecb4f2d6c55bb43271b2a10ac15455bb1c9c Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 11:33:12 +0200 Subject: [PATCH 07/16] Mypy and tests --- .../template_pipeline/simple_kg_builder.py | 32 +++++++++++-------- .../test_simple_kg_builder.py | 8 ++--- .../experimental/pipeline/test_kg_builder.py | 10 ++---- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 04c7c3bc9..a7ee4b917 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -19,7 +19,6 @@ Optional, Sequence, Union, - List, Tuple, Dict, cast, @@ -190,7 +189,9 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: def _process_schema_with_precedence( self, ) -> Tuple[ - List[SchemaEntity], List[SchemaRelation], Optional[List[Tuple[str, str, str]]] + Tuple[SchemaEntity, ...], + Tuple[SchemaRelation, ...], + Optional[Tuple[Tuple[str, str, str], ...]], ]: """ Process schema inputs according to precedence rules: @@ -205,42 +206,45 @@ def _process_schema_with_precedence( # schema takes precedence over individual components if isinstance(self.schema_, GraphSchema): # extract components from GraphSchema - entities = list(self.schema_.entities) + entities = self.schema_.entities # handle case where relations could be None if self.schema_.relations is not None: - relations = list(self.schema_.relations) + relations = self.schema_.relations else: - relations = [] + relations = () - potential_schema = list(self.schema_.potential_schema) + potential_schema = self.schema_.potential_schema else: - entities = [ + entities = tuple( SchemaEntity.from_text_or_dict(e) for e in cast( Dict[str, Any], self.schema_.get("entities", {}) ).values() - ] - relations = [ + ) + relations = tuple( SchemaRelation.from_text_or_dict(r) for r in cast( Dict[str, Any], self.schema_.get("relations", {}) ).values() - ] - potential_schema = self.schema_.get("potential_schema") + ) + ps = self.schema_.get("potential_schema") + potential_schema = tuple(ps) if ps else None else: # use individual components - entities = ( + entities = tuple( [SchemaEntity.from_text_or_dict(e) for e in self.entities] if self.entities else [] ) - relations = ( + relations = tuple( [SchemaRelation.from_text_or_dict(r) for r in self.relations] if self.relations else [] ) - potential_schema = self.potential_schema + potential_schema = ( + tuple(self.potential_schema) if self.potential_schema else None + ) return entities, relations, potential_schema diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 8aa318cd3..1d188a03d 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -142,11 +142,9 @@ def test_simple_kg_pipeline_config_schema_run_params() -> None: potential_schema=[("Person", "KNOWS", "Person")], ) assert config._get_run_params_for_schema() == { - "entities": [SchemaEntity(label="Person")], - "relations": [SchemaRelation(label="KNOWS")], - "potential_schema": [ - ("Person", "KNOWS", "Person"), - ], + "entities": (SchemaEntity(label="Person"),), + "relations": (SchemaRelation(label="KNOWS"),), + "potential_schema": (("Person", "KNOWS", "Person"),), } diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index a6d3d4c42..2de9f1485 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -116,10 +116,6 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: from_pdf=True, ) - # assert kg_builder.entities == entities - # assert kg_builder.relations == relations - # assert kg_builder.potential_schema == potential_schema - file_path = "path/to/test.pdf" internal_entities = [SchemaEntity(label=label) for label in entities] @@ -132,9 +128,9 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: ) as mock_run: await kg_builder.run_async(file_path=file_path) pipe_inputs = mock_run.call_args[1]["data"] - assert pipe_inputs["schema"]["entities"] == internal_entities - assert pipe_inputs["schema"]["relations"] == internal_relations - assert pipe_inputs["schema"]["potential_schema"] == potential_schema + assert pipe_inputs["schema"]["entities"] == tuple(internal_entities) + assert pipe_inputs["schema"]["relations"] == tuple(internal_relations) + assert pipe_inputs["schema"]["potential_schema"] == tuple(potential_schema) def test_simple_kg_pipeline_on_error_invalid_value() -> None: From 5abb0cc5daefdef93e4ad308d5a13ccfc8da9250 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 11:46:00 +0200 Subject: [PATCH 08/16] Update examples --- README.md | 8 +- .../build_graph/simple_kg_builder_from_pdf.py | 8 +- .../simple_kg_builder_from_text.py | 8 +- examples/data/extracted_schema.json | 47 +++----- examples/data/extracted_schema.yaml | 109 +++++++----------- .../template_pipeline/simple_kg_builder.py | 8 +- 6 files changed, 74 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index 60677256c..1c8622619 100644 --- a/README.md +++ b/README.md @@ -128,8 +128,10 @@ kg_builder = SimpleKGPipeline( llm=llm, driver=driver, embedder=embedder, - entities=entities, - relations=relations, + schema={ + "entities": entities, + "relations": relations, + }, on_error="IGNORE", from_pdf=False, ) @@ -365,7 +367,7 @@ When you're finished with your changes, create a pull request (PR) using the fol ## 🧪 Tests -To be able to run all tests, all extra packages needs to be installed. +To be able to run all tests, all extra packages needs to be installed. This is achieved by: ```bash diff --git a/examples/build_graph/simple_kg_builder_from_pdf.py b/examples/build_graph/simple_kg_builder_from_pdf.py index f7ad683da..17a1f71a0 100644 --- a/examples/build_graph/simple_kg_builder_from_pdf.py +++ b/examples/build_graph/simple_kg_builder_from_pdf.py @@ -47,9 +47,11 @@ async def define_and_run_pipeline( llm=llm, driver=neo4j_driver, embedder=OpenAIEmbeddings(), - entities=ENTITIES, - relations=RELATIONS, - potential_schema=POTENTIAL_SCHEMA, + schema={ + "entities": ENTITIES, + "relations": RELATIONS, + "potential_schema": POTENTIAL_SCHEMA, + }, neo4j_database=DATABASE, ) return await kg_builder.run_async(file_path=str(file_path)) diff --git a/examples/build_graph/simple_kg_builder_from_text.py b/examples/build_graph/simple_kg_builder_from_text.py index 79b8c8791..584950309 100644 --- a/examples/build_graph/simple_kg_builder_from_text.py +++ b/examples/build_graph/simple_kg_builder_from_text.py @@ -71,9 +71,11 @@ async def define_and_run_pipeline( llm=llm, driver=neo4j_driver, embedder=OpenAIEmbeddings(), - entities=ENTITIES, - relations=RELATIONS, - potential_schema=POTENTIAL_SCHEMA, + schema={ + "entities": ENTITIES, + "relations": RELATIONS, + "potential_schema": POTENTIAL_SCHEMA, + }, from_pdf=False, neo4j_database=DATABASE, ) diff --git a/examples/data/extracted_schema.json b/examples/data/extracted_schema.json index 0ec66639b..1073de7fc 100644 --- a/examples/data/extracted_schema.json +++ b/examples/data/extracted_schema.json @@ -1,6 +1,6 @@ { - "entities": { - "Company": { + "entities": [ + { "label": "Company", "description": "", "properties": [ @@ -26,7 +26,7 @@ } ] }, - "Person": { + { "label": "Person", "description": "", "properties": [ @@ -41,13 +41,13 @@ "description": "" }, { - "name": "yearJoined", + "name": "startYear", "type": "INTEGER", "description": "" } ] }, - "Product": { + { "label": "Product", "description": "", "properties": [ @@ -67,46 +67,30 @@ "description": "" } ] - }, - "Office": { - "label": "Office", - "description": "", - "properties": [ - { - "name": "location", - "type": "STRING", - "description": "" - } - ] } - }, - "relations": { - "FOUNDED_BY": { + ], + "relations": [ + { "label": "FOUNDED_BY", "description": "", "properties": [] }, - "WORKS_FOR": { + { "label": "WORKS_FOR", "description": "", "properties": [] }, - "MANAGES": { + { "label": "MANAGES", "description": "", "properties": [] }, - "DEVELOPED_BY": { + { "label": "DEVELOPED_BY", "description": "", "properties": [] - }, - "LOCATED_IN": { - "label": "LOCATED_IN", - "description": "", - "properties": [] } - }, + ], "potential_schema": [ [ "Company", @@ -121,17 +105,12 @@ [ "Person", "MANAGES", - "Office" + "Company" ], [ "Product", "DEVELOPED_BY", "Person" - ], - [ - "Company", - "LOCATED_IN", - "Office" ] ] } \ No newline at end of file diff --git a/examples/data/extracted_schema.yaml b/examples/data/extracted_schema.yaml index f2500799f..f0997c24b 100644 --- a/examples/data/extracted_schema.yaml +++ b/examples/data/extracted_schema.yaml @@ -1,74 +1,56 @@ entities: - Company: - label: Company +- label: Company + description: '' + properties: + - name: name + type: STRING description: '' - properties: - - name: name - type: STRING - description: '' - - name: foundedYear - type: INTEGER - description: '' - - name: revenue - type: FLOAT - description: '' - - name: valuation - type: FLOAT - description: '' - Person: - label: Person + - name: foundedYear + type: INTEGER description: '' - properties: - - name: name - type: STRING - description: '' - - name: position - type: STRING - description: '' - - name: yearJoined - type: INTEGER - description: '' - Product: - label: Product + - name: revenue + type: FLOAT description: '' - properties: - - name: name - type: STRING - description: '' - - name: launchYear - type: INTEGER - description: '' - - name: unitsSold - type: INTEGER - description: '' - Office: - label: Office + - name: valuation + type: FLOAT description: '' - properties: - - name: location - type: STRING - description: '' -relations: - FOUNDED_BY: - label: FOUNDED_BY +- label: Person + description: '' + properties: + - name: name + type: STRING + description: '' + - name: position + type: STRING description: '' - properties: [] - WORKS_FOR: - label: WORKS_FOR + - name: startYear + type: INTEGER description: '' - properties: [] - MANAGES: - label: MANAGES +- label: Product + description: '' + properties: + - name: name + type: STRING description: '' - properties: [] - DEVELOPED_BY: - label: DEVELOPED_BY + - name: launchYear + type: INTEGER description: '' - properties: [] - LOCATED_IN: - label: LOCATED_IN + - name: unitsSold + type: INTEGER description: '' - properties: [] +relations: +- label: FOUNDED_BY + description: '' + properties: [] +- label: WORKS_FOR + description: '' + properties: [] +- label: MANAGES + description: '' + properties: [] +- label: DEVELOPED_BY + description: '' + properties: [] potential_schema: - - Company - FOUNDED_BY @@ -78,10 +60,7 @@ potential_schema: - Company - - Person - MANAGES - - Office + - Company - - Product - DEVELOPED_BY - Person -- - Company - - LOCATED_IN - - Office diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index a7ee4b917..d09444a24 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -218,15 +218,11 @@ def _process_schema_with_precedence( else: entities = tuple( SchemaEntity.from_text_or_dict(e) - for e in cast( - Dict[str, Any], self.schema_.get("entities", {}) - ).values() + for e in self.schema_.get("entities", []) ) relations = tuple( SchemaRelation.from_text_or_dict(r) - for r in cast( - Dict[str, Any], self.schema_.get("relations", {}) - ).values() + for r in self.schema_.get("relations", []) ) ps = self.schema_.get("potential_schema") potential_schema = tuple(ps) if ps else None From 842dbe21e0bef42dc0d2506e5e7c52ec5bff99e3 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 16 May 2025 11:49:12 +0200 Subject: [PATCH 09/16] Ruff again --- .../pipeline/config/template_pipeline/simple_kg_builder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index d09444a24..d8b7e9a71 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -20,8 +20,6 @@ Sequence, Union, Tuple, - Dict, - cast, ) import logging import warnings From 67221a8deb2a578006647bbdd71088782c8ada89 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 20 May 2025 12:26:06 +0200 Subject: [PATCH 10/16] Renaming fields after internal discussion --- CHANGELOG.md | 10 +- README.md | 10 +- docs/source/types.rst | 18 +- docs/source/user_guide_kg_builder.rst | 67 ++++--- .../build_graph/simple_kg_builder_from_pdf.py | 14 +- .../simple_kg_builder_from_text.py | 12 +- .../components/schema_builders/schema.py | 33 ++-- .../schema_builders/schema_from_text.py | 14 +- .../pipeline/kg_builder_from_pdf.py | 34 ++-- .../pipeline/kg_builder_from_text.py | 34 ++-- ...builder_two_documents_entity_resolution.py | 36 ++-- ...l_graph_to_entity_graph_single_pipeline.py | 34 ++-- ...cal_graph_to_entity_graph_two_pipelines.py | 34 ++-- examples/data/extracted_schema.json | 67 ++++--- examples/data/extracted_schema.yaml | 48 +++-- examples/kg_builder.py | 32 +-- .../components/entity_relation_extractor.py | 22 ++- .../experimental/components/schema.py | 146 +++++++------- .../template_pipeline/simple_kg_builder.py | 56 +++--- .../experimental/pipeline/kg_builder.py | 8 +- .../experimental/pipeline/types/schema.py | 2 +- src/neo4j_graphrag/generation/prompts.py | 10 +- .../test_kg_builder_pipeline_e2e.py | 64 +++--- .../experimental/test_simplekgpipeline_e2e.py | 6 +- .../test_entity_relation_extractor.py | 50 ++--- .../experimental/components/test_schema.py | 187 +++++++++--------- .../test_simple_kg_builder.py | 10 +- .../experimental/pipeline/test_kg_builder.py | 16 +- 28 files changed, 552 insertions(+), 522 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d3deaf390..378ec3c84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,9 +12,17 @@ ### Changed -- `SchemaConfig` has been deprecated in favor of `GraphSchema` (used in the `SchemaBuilder` and `EntityRelationExtractor` classes). +#### Strict mode + - Strict mode in `SimpleKGPipeline`: now properties and relationships are pruned only if they are defined in the input schema. +#### Schema definition + +- The `SchemaEntity` model has been renamed `NodeType`. +- The `SchemaRelation` model has been renamed `RelationshipType`. +- The `SchemaProperty` model has been renamed `PropertyType`. +- `SchemaConfig` has been removed in favor of `GraphSchema` (used in the `SchemaBuilder` and `EntityRelationExtractor` classes). `entities`, `relations` and `potential_schema` fields have also been renamed `node_types`, `relationship_types` and `patterns` respectively. + ## 1.7.0 diff --git a/README.md b/README.md index 1c8622619..a8bd4ccc0 100644 --- a/README.md +++ b/README.md @@ -102,9 +102,9 @@ NEO4J_PASSWORD = "password" driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) # List the entities and relations the LLM should look for in the text -entities = ["Person", "House", "Planet"] -relations = ["PARENT_OF", "HEIR_OF", "RULES"] -potential_schema = [ +node_types = ["Person", "House", "Planet"] +relationship_types = ["PARENT_OF", "HEIR_OF", "RULES"] +patterns = [ ("Person", "PARENT_OF", "Person"), ("Person", "HEIR_OF", "House"), ("House", "RULES", "Planet"), @@ -129,8 +129,8 @@ kg_builder = SimpleKGPipeline( driver=driver, embedder=embedder, schema={ - "entities": entities, - "relations": relations, + "node_types": node_types, + "relationship_types": relationship_types, }, on_error="IGNORE", from_pdf=False, diff --git a/docs/source/types.rst b/docs/source/types.rst index 8ea5d05df..267e310d3 100644 --- a/docs/source/types.rst +++ b/docs/source/types.rst @@ -75,20 +75,20 @@ KGWriterModel .. autoclass:: neo4j_graphrag.experimental.components.kg_writer.KGWriterModel -SchemaProperty -============== +PropertyType +============ -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaProperty +.. autoclass:: neo4j_graphrag.experimental.components.schema.PropertyType -SchemaEntity -============ +NodeType +======== -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaEntity +.. autoclass:: neo4j_graphrag.experimental.components.schema.NodeType -SchemaRelation -============== +RelationshipType +================ -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaRelation +.. autoclass:: neo4j_graphrag.experimental.components.schema.RelationshipType GraphSchema =========== diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index bb0733170..d29014fca 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -21,7 +21,7 @@ A Knowledge Graph (KG) construction pipeline requires a few components (some of - **Data loader**: extract text from files (PDFs, ...). - **Text splitter**: split the text into smaller pieces of text (chunks), manageable by the LLM context window (token limit). - **Chunk embedder** (optional): compute the chunk embeddings. -- **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG. Schema can be provided manually or extracted automatically using LLMs. +- **Schema builder**: provide a schema to ground the LLM extracted node and relationship types and obtain an easily navigable KG. Schema can be provided manually or extracted automatically using LLMs. - **Lexical graph builder**: build the lexical graph (Document, Chunk and their relationships) (optional). - **Entity and relation extractor**: extract relevant entities and relations from the text. - **Knowledge Graph writer**: save the identified entities and relations. @@ -73,18 +73,18 @@ Customizing the SimpleKGPipeline Graph Schema ------------ -It is possible to guide the LLM by supplying a list of entities, relationships, -and instructions on how to connect them. However, note that the extracted graph +It is possible to guide the LLM by supplying a list of node and relationship types, +and instructions on how to connect them (patterns). However, note that the extracted graph may not fully adhere to these guidelines unless schema enforcement is enabled -(see :ref:`Schema Enforcement Behaviour`). Entities and relationships can be represented +(see :ref:`Schema Enforcement Behaviour`). Node and relationship types can be represented as either simple strings (for their labels) or dictionaries. If using a dictionary, it must include a label key and can optionally include description and properties keys, as shown below: .. code:: python - ENTITIES = [ - # entities can be defined with a simple label... + NODE_TYPES = [ + # node types can be defined with a simple label... "Person", # ... or with a dict if more details are needed, # such as a description: @@ -93,7 +93,7 @@ as shown below: {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, ] # same thing for relationships: - RELATIONS = [ + RELATIONSHIP_TYPES = [ "PARENT_OF", { "label": "HEIR_OF", @@ -102,13 +102,13 @@ as shown below: {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]}, ] -The `potential_schema` is defined by a list of triplet in the format: +The `patterns` is defined by a list of triplet in the format: `(source_node_label, relationship_label, target_node_label)`. For instance: .. code:: python - POTENTIAL_SCHEMA = [ + PATTERNS = [ ("Person", "PARENT_OF", "Person"), ("Person", "HEIR_OF", "House"), ("House", "RULES", "Planet"), @@ -122,15 +122,15 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated kg_builder = SimpleKGPipeline( # ... schema={ - "entities": ENTITIES, - "relations": RELATIONS, - "potential_schema": POTENTIAL_SCHEMA + "node_types": NODE_TYPES, + "relationship_types": RELATIONSHIP_TYPES, + "patterns": PATTERNS }, # ... ) .. note:: - By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (See the :ref:`Automatic Schema Extraction with SchemaFromTextExtractor`). + By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (See the :ref:`Automatic Schema Extraction`). Extra configurations -------------------- @@ -419,9 +419,8 @@ within the configuration file. "neo4j_database": "myDb", "on_error": "IGNORE", "prompt_template": "...", - "schema": { - "entities": [ + "node_types": [ "Person", { "label": "House", @@ -438,7 +437,7 @@ within the configuration file. ] } ], - "relations": [ + "relationship_types": [ "PARENT_OF", { "label": "HEIR_OF", @@ -451,7 +450,7 @@ within the configuration file. ] } ], - "potential_schema": [ + "patterns": [ ["Person", "PARENT_OF", "Person"], ["Person", "HEIR_OF", "House"], ["House", "RULES", "Planet"] @@ -473,7 +472,7 @@ or in YAML: on_error: IGNORE prompt_template: ... schema: - entities: + node_types: - Person - label: House description: Family the person belongs to @@ -486,7 +485,7 @@ or in YAML: type: STRING - name: weather type: STRING - relations: + relationship_types: - PARENT_OF - label: HEIR_OF description: Used for inheritor relationship between father and sons @@ -494,7 +493,7 @@ or in YAML: properties: - name: fromYear type: INTEGER - potential_schema: + patterns: - ["Person", "PARENT_OF", "Person"] - ["Person", "HEIR_OF", "House"] - ["House", "RULES", "Planet"] @@ -747,12 +746,12 @@ Optionally, the document and chunk node labels can be configured using a `Lexica Schema Builder ============== -The schema is used to try and ground the LLM to a list of possible entities and relations of interest. +The schema is used to try and ground the LLM to a list of possible node and relationship types of interest. So far, schema must be manually created by specifying: -- **Entities** the LLM should look for in the text, including their properties (name and type). -- **Relations** of interest between these entities, including the relation properties (name and type). -- **Triplets** to define the start (source) and end (target) entity types for each relation. +- **Node types** the LLM should look for in the text, including their properties (name and type). +- **Relationship types** of interest between these node types, including the relationship properties (name and type). +- **Patterns** (triplets) to define the start (source) and end (target) entity types for each relationship. Here is a code block illustrating these concepts: @@ -760,16 +759,16 @@ Here is a code block illustrating these concepts: from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) schema_builder = SchemaBuilder() await schema_builder.run( - entities=[ - SchemaEntity( + node_types=[ + NodeType( label="Person", properties=[ SchemaProperty(name="name", type="STRING"), @@ -777,7 +776,7 @@ Here is a code block illustrating these concepts: SchemaProperty(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ SchemaProperty(name="name", type="STRING"), @@ -785,15 +784,15 @@ Here is a code block illustrating these concepts: ], ), ], - relations=[ - SchemaRelation( + relationship_types=[ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - possible_schema=[ + patterns=[ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], diff --git a/examples/build_graph/simple_kg_builder_from_pdf.py b/examples/build_graph/simple_kg_builder_from_pdf.py index 17a1f71a0..2cfc85134 100644 --- a/examples/build_graph/simple_kg_builder_from_pdf.py +++ b/examples/build_graph/simple_kg_builder_from_pdf.py @@ -27,11 +27,11 @@ file_path = root_dir / "data" / "Harry Potter and the Chamber of Secrets Summary.pdf" -# Instantiate Entity and Relation objects. This defines the +# Instantiate NodeType and RelationshipType objects. This defines the # entities and relations the LLM will be looking for in the text. -ENTITIES = ["Person", "Organization", "Location"] -RELATIONS = ["SITUATED_AT", "INTERACTS", "LED_BY"] -POTENTIAL_SCHEMA = [ +NODE_TYPES = ["Person", "Organization", "Location"] +RELATIONSHIP_TYPES = ["SITUATED_AT", "INTERACTS", "LED_BY"] +PATTERNS = [ ("Person", "SITUATED_AT", "Location"), ("Person", "INTERACTS", "Person"), ("Organization", "LED_BY", "Person"), @@ -48,9 +48,9 @@ async def define_and_run_pipeline( driver=neo4j_driver, embedder=OpenAIEmbeddings(), schema={ - "entities": ENTITIES, - "relations": RELATIONS, - "potential_schema": POTENTIAL_SCHEMA, + "node_types": NODE_TYPES, + "relationship_types": RELATIONSHIP_TYPES, + "patterns": PATTERNS, }, neo4j_database=DATABASE, ) diff --git a/examples/build_graph/simple_kg_builder_from_text.py b/examples/build_graph/simple_kg_builder_from_text.py index 584950309..548cbd9eb 100644 --- a/examples/build_graph/simple_kg_builder_from_text.py +++ b/examples/build_graph/simple_kg_builder_from_text.py @@ -37,7 +37,7 @@ # Instantiate Entity and Relation objects. This defines the # entities and relations the LLM will be looking for in the text. -ENTITIES: list[EntityInputType] = [ +NODE_TYPES: list[EntityInputType] = [ # entities can be defined with a simple label... "Person", # ... or with a dict if more details are needed, @@ -47,7 +47,7 @@ {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, ] # same thing for relationships: -RELATIONS: list[RelationInputType] = [ +RELATIONSHIP_TYPES: list[RelationInputType] = [ "PARENT_OF", { "label": "HEIR_OF", @@ -55,7 +55,7 @@ }, {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]}, ] -POTENTIAL_SCHEMA = [ +PATTERNS = [ ("Person", "PARENT_OF", "Person"), ("Person", "HEIR_OF", "House"), ("House", "RULES", "Planet"), @@ -72,9 +72,9 @@ async def define_and_run_pipeline( driver=neo4j_driver, embedder=OpenAIEmbeddings(), schema={ - "entities": ENTITIES, - "relations": RELATIONS, - "potential_schema": POTENTIAL_SCHEMA, + "node_types": NODE_TYPES, + "relationship_types": RELATIONSHIP_TYPES, + "patterns": PATTERNS, }, from_pdf=False, neo4j_database=DATABASE, diff --git a/examples/customize/build_graph/components/schema_builders/schema.py b/examples/customize/build_graph/components/schema_builders/schema.py index 6333fdd97..6ca408dee 100644 --- a/examples/customize/build_graph/components/schema_builders/schema.py +++ b/examples/customize/build_graph/components/schema_builders/schema.py @@ -14,43 +14,44 @@ # limitations under the License. from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) async def main() -> None: schema_builder = SchemaBuilder() - await schema_builder.run( - entities=[ - SchemaEntity( + result = await schema_builder.run( + node_types=[ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), ], - relations=[ - SchemaRelation( + relationship_types=[ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - potential_schema=[ + patterns=[ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], ) + print(result) diff --git a/examples/customize/build_graph/components/schema_builders/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py index 7840cbd17..a36ef90ec 100644 --- a/examples/customize/build_graph/components/schema_builders/schema_from_text.py +++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py @@ -92,14 +92,14 @@ async def extract_and_save_schema() -> None: inferred_schema.store_as_yaml(YAML_FILE_PATH) print("\nExtracted Schema Summary:") - print(f"Entities: {list(inferred_schema.entities)}") + print(f"Node types: {list(inferred_schema.node_types)}") print( - f"Relations: {list(inferred_schema.relations if inferred_schema.relations else [])}" + f"Relationship types: {list(inferred_schema.relationship_types if inferred_schema.relationship_types else [])}" ) - if inferred_schema.potential_schema: - print("\nPotential Schema:") - for entity1, relation, entity2 in inferred_schema.potential_schema: + if inferred_schema.patterns: + print("\nPatterns:") + for entity1, relation, entity2 in inferred_schema.patterns: print(f" {entity1} --[{relation}]--> {entity2}") finally: @@ -122,8 +122,8 @@ async def main() -> None: schema_from_json = GraphSchema.from_file(JSON_FILE_PATH) schema_from_yaml = GraphSchema.from_file(YAML_FILE_PATH) - print(f"Entities in JSON schema: {list(schema_from_json.entities)}") - print(f"Entities in YAML schema: {list(schema_from_yaml.entities)}") + print(f"Node types in JSON schema: {list(schema_from_json.node_types)}") + print(f"Node types in YAML schema: {list(schema_from_yaml.node_types)}") if __name__ == "__main__": diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py index ab11206da..ea727fe3c 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_pdf.py @@ -25,8 +25,8 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -45,35 +45,35 @@ async def define_and_run_pipeline( from neo4j_graphrag.experimental.pipeline import Pipeline # Instantiate Entity and Relation objects - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( + node_types = [ + NodeType(label="PERSON", description="An individual human being."), + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", ), - SchemaEntity(label="LOCATION", description="A location or place."), - SchemaEntity( + NodeType(label="LOCATION", description="A location or place."), + NodeType( label="HORCRUX", description="A magical item in the Harry Potter universe.", ), ] - relations = [ - SchemaRelation( + relationship_types = [ + RelationshipType( label="SITUATED_AT", description="Indicates the location of a person." ), - SchemaRelation( + RelationshipType( label="LED_BY", description="Indicates the leader of an organization.", ), - SchemaRelation( + RelationshipType( label="OWNS", description="Indicates the ownership of an item such as a Horcrux.", ), - SchemaRelation( + RelationshipType( label="INTERACTS", description="The interaction between two people." ), ] - potential_schema = [ + patterns = [ ("PERSON", "SITUATED_AT", "LOCATION"), ("PERSON", "INTERACTS", "PERSON"), ("PERSON", "OWNS", "HORCRUX"), @@ -114,12 +114,12 @@ async def define_and_run_pipeline( pipe_inputs = { "pdf_loader": { - "filepath": "examples/pipeline/Harry Potter and the Death Hallows Summary.pdf" + "filepath": "examples/data/Harry Potter and the Death Hallows Summary.pdf" }, "schema": { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, }, } return await pipe.run(pipe_inputs) diff --git a/examples/customize/build_graph/pipeline/kg_builder_from_text.py b/examples/customize/build_graph/pipeline/kg_builder_from_text.py index 907a02825..3a9e30911 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_from_text.py +++ b/examples/customize/build_graph/pipeline/kg_builder_from_text.py @@ -25,9 +25,9 @@ from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -95,38 +95,38 @@ async def define_and_run_pipeline( the University of Bern in Switzerland and the University of Oxford.""" }, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Field", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], diff --git a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py index d6f5e9ae8..eda2b4219 100644 --- a/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py +++ b/examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py @@ -28,9 +28,9 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -92,35 +92,35 @@ async def define_and_run_pipeline( pipe_inputs = { "loader": {}, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_FOR", ), - SchemaRelation( + RelationshipType( label="FRIEND", ), - SchemaRelation( + RelationshipType( label="ENEMY", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_FOR", "Organization"), ("Person", "FRIEND", "Person"), ("Person", "ENEMY", "Person"), @@ -129,8 +129,8 @@ async def define_and_run_pipeline( } # run the pipeline for each documents for document in [ - "examples/pipeline/Harry Potter and the Chamber of Secrets Summary.pdf", - "examples/pipeline/Harry Potter and the Death Hallows Summary.pdf", + "examples/data/Harry Potter and the Chamber of Secrets Summary.pdf", + "examples/data/Harry Potter and the Death Hallows Summary.pdf", ]: pipe_inputs["loader"]["filepath"] = document await pipe.run(pipe_inputs) diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py index 6867d9068..daaab51a5 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py @@ -16,9 +16,9 @@ from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -118,38 +118,38 @@ async def define_and_run_pipeline( }, }, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Field", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], diff --git a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py index 0fd354db4..b5b6b5273 100644 --- a/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py +++ b/examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py @@ -18,9 +18,9 @@ from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -138,38 +138,38 @@ async def read_chunk_and_perform_entity_extraction( "lexical_graph_config": lexical_graph_config, }, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="country", type="STRING"), + PropertyType(name="name", type="STRING"), + PropertyType(name="country", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Field", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="WORKED_ON", ), - SchemaRelation( + RelationshipType( label="WORKED_FOR", ), ], - "potential_schema": [ + "patterns": [ ("Person", "WORKED_ON", "Field"), ("Person", "WORKED_FOR", "Organization"), ], diff --git a/examples/data/extracted_schema.json b/examples/data/extracted_schema.json index 1073de7fc..ce36153d6 100644 --- a/examples/data/extracted_schema.json +++ b/examples/data/extracted_schema.json @@ -1,7 +1,7 @@ { - "entities": [ + "node_types": [ { - "label": "Company", + "label": "Person", "description": "", "properties": [ { @@ -10,24 +10,19 @@ "description": "" }, { - "name": "foundedYear", - "type": "INTEGER", - "description": "" - }, - { - "name": "revenue", - "type": "FLOAT", + "name": "position", + "type": "STRING", "description": "" }, { - "name": "valuation", - "type": "FLOAT", + "name": "startYear", + "type": "INTEGER", "description": "" } ] }, { - "label": "Person", + "label": "Company", "description": "", "properties": [ { @@ -36,13 +31,18 @@ "description": "" }, { - "name": "position", - "type": "STRING", + "name": "foundedYear", + "type": "INTEGER", "description": "" }, { - "name": "startYear", - "type": "INTEGER", + "name": "revenue", + "type": "FLOAT", + "description": "" + }, + { + "name": "valuation", + "type": "FLOAT", "description": "" } ] @@ -67,36 +67,42 @@ "description": "" } ] + }, + { + "label": "Office", + "description": "", + "properties": [ + { + "name": "location", + "type": "STRING", + "description": "" + } + ] } ], - "relations": [ + "relationship_types": [ { - "label": "FOUNDED_BY", + "label": "WORKS_FOR", "description": "", "properties": [] }, { - "label": "WORKS_FOR", + "label": "MANAGES", "description": "", "properties": [] }, { - "label": "MANAGES", + "label": "DEVELOPED_BY", "description": "", "properties": [] }, { - "label": "DEVELOPED_BY", + "label": "LOCATED_IN", "description": "", "properties": [] } ], - "potential_schema": [ - [ - "Company", - "FOUNDED_BY", - "Person" - ], + "patterns": [ [ "Person", "WORKS_FOR", @@ -105,12 +111,17 @@ [ "Person", "MANAGES", - "Company" + "Office" ], [ "Product", "DEVELOPED_BY", "Person" + ], + [ + "Company", + "LOCATED_IN", + "Office" ] ] } \ No newline at end of file diff --git a/examples/data/extracted_schema.yaml b/examples/data/extracted_schema.yaml index f0997c24b..efdd8e733 100644 --- a/examples/data/extracted_schema.yaml +++ b/examples/data/extracted_schema.yaml @@ -1,4 +1,16 @@ -entities: +node_types: +- label: Person + description: '' + properties: + - name: name + type: STRING + description: '' + - name: position + type: STRING + description: '' + - name: startYear + type: INTEGER + description: '' - label: Company description: '' properties: @@ -14,18 +26,6 @@ entities: - name: valuation type: FLOAT description: '' -- label: Person - description: '' - properties: - - name: name - type: STRING - description: '' - - name: position - type: STRING - description: '' - - name: startYear - type: INTEGER - description: '' - label: Product description: '' properties: @@ -38,10 +38,13 @@ entities: - name: unitsSold type: INTEGER description: '' -relations: -- label: FOUNDED_BY +- label: Office description: '' - properties: [] + properties: + - name: location + type: STRING + description: '' +relationship_types: - label: WORKS_FOR description: '' properties: [] @@ -51,16 +54,19 @@ relations: - label: DEVELOPED_BY description: '' properties: [] -potential_schema: -- - Company - - FOUNDED_BY - - Person +- label: LOCATED_IN + description: '' + properties: [] +patterns: - - Person - WORKS_FOR - Company - - Person - MANAGES - - Company + - Office - - Product - DEVELOPED_BY - Person +- - Company + - LOCATED_IN + - Office diff --git a/examples/kg_builder.py b/examples/kg_builder.py index 650473e41..c98f0c069 100644 --- a/examples/kg_builder.py +++ b/examples/kg_builder.py @@ -32,8 +32,8 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -50,35 +50,35 @@ async def define_and_run_pipeline( from neo4j_graphrag.experimental.pipeline import Pipeline # Instantiate Entity and Relation objects - entities = [ - SchemaEntity(label="PERSON", description="An individual human being."), - SchemaEntity( + node_types = [ + NodeType(label="PERSON", description="An individual human being."), + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", ), - SchemaEntity(label="LOCATION", description="A location or place."), - SchemaEntity( + NodeType(label="LOCATION", description="A location or place."), + NodeType( label="HORCRUX", description="A magical item in the Harry Potter universe.", ), ] - relations = [ - SchemaRelation( + relationship_types = [ + RelationshipType( label="SITUATED_AT", description="Indicates the location of a person." ), - SchemaRelation( + RelationshipType( label="LED_BY", description="Indicates the leader of an organization.", ), - SchemaRelation( + RelationshipType( label="OWNS", description="Indicates the ownership of an item such as a Horcrux.", ), - SchemaRelation( + RelationshipType( label="INTERACTS", description="The interaction between two people." ), ] - potential_schema = [ + patterns = [ ("PERSON", "SITUATED_AT", "LOCATION"), ("PERSON", "INTERACTS", "PERSON"), ("PERSON", "OWNS", "HORCRUX"), @@ -121,9 +121,9 @@ async def define_and_run_pipeline( "filepath": "examples/pipeline/Harry Potter and the Death Hallows Summary.pdf" }, "schema": { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, }, } return await pipe.run(pipe_inputs) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 7485a77d4..910dd13a1 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -25,7 +25,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder -from neo4j_graphrag.experimental.components.schema import GraphSchema, SchemaProperty +from neo4j_graphrag.experimental.components.schema import GraphSchema, PropertyType from neo4j_graphrag.experimental.components.types import ( DocumentInfo, LexicalGraphConfig, @@ -325,7 +325,9 @@ async def run( lexical_graph = lexical_graph_result.graph elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) - schema = schema or GraphSchema(entities=(), relations=(), potential_schema=()) + schema = schema or GraphSchema( + node_types=(), relationship_types=(), patterns=() + ) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) tasks = [ @@ -351,7 +353,7 @@ def validate_chunk( - Enforce schema if schema enforcement mode is on and schema is provided """ if self.enforce_schema != SchemaEnforcementMode.NONE: - if not schema or not schema.entities: # schema is not provided + if not schema or not schema.node_types: # schema is not provided logger.warning( "Schema enforcement is ON but the guiding schema is not provided." ) @@ -400,7 +402,7 @@ def _enforce_nodes( valid_nodes = [] for node in extracted_nodes: - schema_entity = schema.entity_from_label(node.label) + schema_entity = schema.node_type_from_label(node.label) if not schema_entity: continue allowed_props = schema_entity.properties or [] @@ -446,10 +448,10 @@ def _enforce_relationships( valid_nodes = {node.id: node.label for node in filtered_nodes} - potential_schema = schema.potential_schema + patterns = schema.patterns for rel in extracted_relationships: - schema_relation = schema.relations.get(rel.type) + schema_relation = schema.relationship_type_from_label(rel.type) if not schema_relation: logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") continue @@ -467,13 +469,13 @@ def _enforce_relationships( end_label = valid_nodes[rel.end_node_id] tuple_valid = True - if potential_schema: - tuple_valid = (start_label, rel.type, end_label) in potential_schema + if patterns: + tuple_valid = (start_label, rel.type, end_label) in patterns reverse_tuple_valid = ( end_label, rel.type, start_label, - ) in potential_schema + ) in patterns if not tuple_valid and not reverse_tuple_valid: logger.debug(f"PRUNING:: {rel} not in the potential schema") @@ -498,7 +500,7 @@ def _enforce_relationships( return valid_rels def _enforce_properties( - self, properties: Dict[str, Any], valid_properties: List[SchemaProperty] + self, properties: Dict[str, Any], valid_properties: List[PropertyType] ) -> Dict[str, Any]: """ Filter properties. diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 456688e9d..ad1ca4db4 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -44,7 +44,7 @@ from neo4j_graphrag.llm import LLMInterface -class SchemaProperty(BaseModel): +class PropertyType(BaseModel): """ Represents a property on a node or relationship in the graph. """ @@ -72,36 +72,36 @@ class SchemaProperty(BaseModel): ) -class SchemaEntity(BaseModel): +class NodeType(BaseModel): """ Represents a possible node in the graph. """ label: str description: str = "" - properties: list[SchemaProperty] = [] + properties: list[PropertyType] = [] @classmethod def from_text_or_dict(cls, input: EntityInputType) -> Self: - if isinstance(input, SchemaEntity): + if isinstance(input, NodeType): return input if isinstance(input, str): return cls(label=input) return cls.model_validate(input) -class SchemaRelation(BaseModel): +class RelationshipType(BaseModel): """ Represents a possible relationship between nodes in the graph. """ label: str description: str = "" - properties: list[SchemaProperty] = [] + properties: list[PropertyType] = [] @classmethod def from_text_or_dict(cls, input: RelationInputType) -> Self: - if isinstance(input, SchemaRelation): + if isinstance(input, RelationshipType): return input if isinstance(input, str): return cls(label=input) @@ -109,12 +109,12 @@ def from_text_or_dict(cls, input: RelationInputType) -> Self: class GraphSchema(DataModel): - entities: Tuple[SchemaEntity, ...] - relations: Optional[Tuple[SchemaRelation, ...]] = None - potential_schema: Optional[Tuple[Tuple[str, str, str], ...]] = None + node_types: Tuple[NodeType, ...] + relationship_types: Optional[Tuple[RelationshipType, ...]] = None + patterns: Optional[Tuple[Tuple[str, str, str], ...]] = None - _entity_index: dict[str, SchemaEntity] = PrivateAttr() - _relation_index: dict[str, SchemaRelation] = PrivateAttr() + _node_type_index: dict[str, NodeType] = PrivateAttr() + _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() model_config = ConfigDict( frozen=True, @@ -122,40 +122,42 @@ class GraphSchema(DataModel): @model_validator(mode="after") def check_schema(self) -> Self: - self._entity_index = {e.label: e for e in self.entities} - self._relation_index = ( - {r.label: r for r in self.relations} if self.relations else {} + self._node_type_index = {node.label: node for node in self.node_types} + self._relationship_type_index = ( + {r.label: r for r in self.relationship_types} + if self.relationship_types + else {} ) - relations = self.relations or tuple() - potential_schema = self.potential_schema or tuple() + relationship_types = self.relationship_types or tuple() + patterns = self.patterns or tuple() - if potential_schema: - if not relations: + if patterns: + if not relationship_types: raise SchemaValidationError( "Relations must also be provided when using a potential schema." ) - for entity1, relation, entity2 in potential_schema: - if entity1 not in self._entity_index: + for entity1, relation, entity2 in patterns: + if entity1 not in self._node_type_index: raise SchemaValidationError( f"Entity '{entity1}' is not defined in the provided entities." ) - if relation not in self._relation_index: + if relation not in self._relationship_type_index: raise SchemaValidationError( f"Relation '{relation}' is not defined in the provided relations." ) - if entity2 not in self._entity_index: + if entity2 not in self._node_type_index: raise SchemaValidationError( f"Entity '{entity2}' is not defined in the provided entities." ) return self - def entity_from_label(self, label: str) -> Optional[SchemaEntity]: - return self._entity_index.get(label) + def node_type_from_label(self, label: str) -> Optional[NodeType]: + return self._node_type_index.get(label) - def relation_from_label(self, label: str) -> Optional[SchemaRelation]: - return self._relation_index.get(label) + def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]: + return self._relationship_type_index.get(label) def store_as_json(self, file_path: str) -> None: """ @@ -176,12 +178,12 @@ def store_as_yaml(self, file_path: str) -> None: """ # create a copy of the data and convert tuples to lists for YAML compatibility data = self.model_dump() - if data.get("entities"): - data["entities"] = list(data["entities"]) - if data.get("relations"): - data["relations"] = list(data["relations"]) - if data.get("potential_schema"): - data["potential_schema"] = [list(item) for item in data["potential_schema"]] + if data.get("node_types"): + data["node_types"] = list(data["node_types"]) + if data.get("relationship_types"): + data["relationship_types"] = list(data["relationship_types"]) + if data.get("patterns"): + data["patterns"] = [list(item) for item in data["patterns"]] with open(file_path, "w") as f: yaml.dump(data, f, default_flow_style=False, sort_keys=False) @@ -265,38 +267,38 @@ class SchemaBuilder(Component): from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.pipeline import Pipeline - entities = [ - SchemaEntity( + node_types = [ + NodeType( label="PERSON", description="An individual human being.", properties=[ - SchemaProperty( + PropertyType( name="name", type="STRING", description="The name of the person" ) ], ), - SchemaEntity( + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", properties=[ - SchemaProperty( + PropertyType( name="name", type="STRING", description="The name of the organization" ) ], ), ] - relations = [ - SchemaRelation( + relationship_types = [ + RelationshipType( label="EMPLOYED_BY", description="Indicates employment relationship." ), ] - potential_schema = [ + patterns = [ ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ] pipe = Pipeline() @@ -304,9 +306,9 @@ class SchemaBuilder(Component): pipe.add_component(schema_builder, "schema_builder") pipe_inputs = { "schema": { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, }, ... } @@ -315,18 +317,18 @@ class SchemaBuilder(Component): @staticmethod def create_schema_model( - entities: Sequence[SchemaEntity], - relations: Optional[Sequence[SchemaRelation]] = None, - potential_schema: Optional[Sequence[Tuple[str, str, str]]] = None, + node_types: Sequence[NodeType], + relationship_types: Optional[Sequence[RelationshipType]] = None, + patterns: Optional[Sequence[Tuple[str, str, str]]] = None, ) -> GraphSchema: """ Creates a GraphSchema object from Lists of Entity and Relation objects and a Dictionary defining potential relationships. Args: - entities (Sequence[SchemaEntity]): List or tuple of SchemaEntity objects. - relations (Optional[Sequence[SchemaRelation]]): List or tuple of SchemaRelation objects. - potential_schema (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label). + node_types (Sequence[NodeType]): List or tuple of NodeType objects. + relationship_types (Optional[Sequence[RelationshipType]]): List or tuple of RelationshipType objects. + patterns (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label). Returns: GraphSchema: A configured schema object. @@ -334,9 +336,9 @@ def create_schema_model( try: return GraphSchema.model_validate( dict( - entities=entities, - relations=relations, - potential_schema=potential_schema, + node_types=node_types, + relationship_types=relationship_types, + patterns=patterns, ) ) except (ValidationError, SchemaValidationError) as e: @@ -345,22 +347,22 @@ def create_schema_model( @validate_call async def run( self, - entities: Sequence[SchemaEntity], - relations: Optional[Sequence[SchemaRelation]] = None, - potential_schema: Optional[Sequence[Tuple[str, str, str]]] = None, + node_types: Sequence[NodeType], + relationship_types: Optional[Sequence[RelationshipType]] = None, + patterns: Optional[Sequence[Tuple[str, str, str]]] = None, ) -> GraphSchema: """ Asynchronously constructs and returns a GraphSchema object. Args: - entities (Sequence[SchemaEntity]): List or tuple of SchemaEntity objects. - relations (Sequence[SchemaRelation]): List or tuple of SchemaRelation objects. - potential_schema (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label). + node_types (Sequence[NodeType]): Sequence of NodeType objects. + relationship_types (Sequence[RelationshipType]): Sequence of RelationshipType objects. + patterns (Optional[Sequence[Tuple[str, str, str]]]): Sequence of triplets: (source_entity_label, relation_label, target_entity_label). Returns: GraphSchema: A configured schema object, constructed asynchronously. """ - return self.create_schema_model(entities, relations, potential_schema) + return self.create_schema_model(node_types, relationship_types, patterns) class SchemaFromTextExtractor(Component): @@ -429,20 +431,20 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema except json.JSONDecodeError as exc: raise SchemaExtractionError("LLM response is not valid JSON.") from exc - extracted_entities: List[Dict[str, Any]] = ( - extracted_schema.get("entities") or [] + extracted_node_types: List[Dict[str, Any]] = ( + extracted_schema.get("node_types") or [] ) - extracted_relations: Optional[List[Dict[str, Any]]] = extracted_schema.get( - "relations" + extracted_relationship_types: Optional[List[Dict[str, Any]]] = ( + extracted_schema.get("relationship_types") ) - potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( - "potential_schema" + extracted_patterns: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( + "patterns" ) return GraphSchema.model_validate( { - "entities": extracted_entities, - "relations": extracted_relations, - "potential_schema": potential_schema, + "node_types": extracted_node_types, + "relationship_types": extracted_relationship_types, + "patterns": extracted_patterns, } ) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index d8b7e9a71..ac4fc04b8 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -42,8 +42,8 @@ from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, GraphSchema, - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, SchemaFromTextExtractor, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter @@ -187,8 +187,8 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: def _process_schema_with_precedence( self, ) -> Tuple[ - Tuple[SchemaEntity, ...], - Tuple[SchemaRelation, ...], + Tuple[NodeType, ...], + Tuple[RelationshipType, ...], Optional[Tuple[Tuple[str, str, str], ...]], ]: """ @@ -198,49 +198,47 @@ def _process_schema_with_precedence( 3. Otherwise, use individual schema components Returns: - Tuple of (entities, relations, potential_schema) + Tuple of (node_types, relationship_types, patterns) """ if self.schema_ is not None: # schema takes precedence over individual components if isinstance(self.schema_, GraphSchema): # extract components from GraphSchema - entities = self.schema_.entities + node_types = self.schema_.node_types # handle case where relations could be None - if self.schema_.relations is not None: - relations = self.schema_.relations + if self.schema_.relationship_types is not None: + relationship_types = self.schema_.relationship_types else: - relations = () + relationship_types = () - potential_schema = self.schema_.potential_schema + patterns = self.schema_.patterns else: - entities = tuple( - SchemaEntity.from_text_or_dict(e) - for e in self.schema_.get("entities", []) + node_types = tuple( + NodeType.from_text_or_dict(e) + for e in self.schema_.get("node_types", ()) ) - relations = tuple( - SchemaRelation.from_text_or_dict(r) - for r in self.schema_.get("relations", []) + relationship_types = tuple( + RelationshipType.from_text_or_dict(r) + for r in self.schema_.get("node_types", ()) ) ps = self.schema_.get("potential_schema") - potential_schema = tuple(ps) if ps else None + patterns = tuple(ps) if ps else None else: # use individual components - entities = tuple( - [SchemaEntity.from_text_or_dict(e) for e in self.entities] + node_types = tuple( + [NodeType.from_text_or_dict(e) for e in self.entities] if self.entities else [] ) - relations = tuple( - [SchemaRelation.from_text_or_dict(r) for r in self.relations] + relationship_types = tuple( + [RelationshipType.from_text_or_dict(r) for r in self.relations] if self.relations else [] ) - potential_schema = ( - tuple(self.potential_schema) if self.potential_schema else None - ) + patterns = tuple(self.potential_schema) if self.potential_schema else None - return entities, relations, potential_schema + return node_types, relationship_types, patterns def _get_run_params_for_schema(self) -> dict[str, Any]: if not self.has_user_provided_schema(): @@ -248,14 +246,14 @@ def _get_run_params_for_schema(self) -> dict[str, Any]: return {} else: # process schema components according to precedence rules - entities, relations, potential_schema = ( + node_types, relationship_types, patterns = ( self._process_schema_with_precedence() ) return { - "entities": entities, - "relations": relations, - "potential_schema": potential_schema, + "node_types": node_types, + "relationship_types": relationship_types, + "patterns": patterns, } def _get_extractor(self) -> EntityRelationExtractor: diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 924670b90..d46ddc046 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -60,15 +60,15 @@ class SimpleKGPipeline: schema (Optional[Union[GraphSchema, dict[str, list]]]): A schema configuration defining entities, relations, and potential schema relationships. This is the recommended way to provide schema information. - entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): DEPRECATED. A list of either: + entities (Optional[List[Union[str, dict[str, str], NodeType]]]): DEPRECATED. A list of either: - str: entity labels - - dict: following the SchemaEntity schema, ie with label, description and properties keys + - dict: following the NodeType schema, ie with label, description and properties keys - relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): DEPRECATED. A list of either: + relations (Optional[List[Union[str, dict[str, str], RelationshipType]]]): DEPRECATED. A list of either: - str: relation label - - dict: following the SchemaRelation schema, ie with label, description and properties keys + - dict: following the RelationshipType schema, ie with label, description and properties keys potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships. enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "None" or "STRICT". diff --git a/src/neo4j_graphrag/experimental/pipeline/types/schema.py b/src/neo4j_graphrag/experimental/pipeline/types/schema.py index 626c99841..3bc8a7446 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types/schema.py +++ b/src/neo4j_graphrag/experimental/pipeline/types/schema.py @@ -19,7 +19,7 @@ EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] -"""Types derived from the SchemaEntity and SchemaRelation types, +"""Types derived from the NodeType and RelationshipType types, so the possible types for dict values are: - str (for label and description) - list[dict[str, str]] (for properties) diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 96bcaf8de..24de870fb 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -204,7 +204,7 @@ def format( class SchemaExtractionTemplate(PromptTemplate): DEFAULT_TEMPLATE = """ -You are a top-tier algorithm designed for extracting a labeled property graph schema in +You are a top-tier algorithm designed for extracting a labeled property graph schema in structured formats. Generate a generalized graph schema based on the input text. Identify key entity types, @@ -219,12 +219,12 @@ class SchemaExtractionTemplate(PromptTemplate): 6. Do not create entity types that aren't clearly mentioned in the text. 7. Keep your schema minimal and focused on clearly identifiable patterns in the text. -Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, +Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME. Return a valid JSON object that follows this precise structure: {{ - "entities": [ + "node_types": [ {{ "label": "Person", "properties": [ @@ -236,13 +236,13 @@ class SchemaExtractionTemplate(PromptTemplate): }}, ... ], - "relations": [ + "relationship_types": [ {{ "label": "WORKS_FOR" }}, ... ], - "potential_schema": [ + "patterns": [ ["Person", "WORKS_FOR", "Company"], ... ] diff --git a/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py b/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py index bf74470cf..89f5ae62c 100644 --- a/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py +++ b/tests/e2e/experimental/test_kg_builder_pipeline_e2e.py @@ -33,9 +33,9 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -187,49 +187,49 @@ async def test_pipeline_builder_happy_path( pipe_inputs = { "splitter": {"text": harry_potter_text}, "schema": { - "entities": [ - SchemaEntity( + "node_types": [ + NodeType( label="Person", properties=[ - SchemaProperty(name="name", type="STRING"), - SchemaProperty(name="place_of_birth", type="STRING"), - SchemaProperty(name="date_of_birth", type="DATE"), + PropertyType(name="name", type="STRING"), + PropertyType(name="place_of_birth", type="STRING"), + PropertyType(name="date_of_birth", type="DATE"), ], ), - SchemaEntity( + NodeType( label="Organization", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Potion", properties=[ - SchemaProperty(name="name", type="STRING"), + PropertyType(name="name", type="STRING"), ], ), - SchemaEntity( + NodeType( label="Location", properties=[ - SchemaProperty(name="address", type="STRING"), + PropertyType(name="address", type="STRING"), ], ), ], - "relations": [ - SchemaRelation( + "relationship_types": [ + RelationshipType( label="KNOWS", ), - SchemaRelation( + RelationshipType( label="PART_OF", ), - SchemaRelation( + RelationshipType( label="LED_BY", ), - SchemaRelation( + RelationshipType( label="DRINKS", ), ], - "potential_schema": [ + "patterns": [ ("Person", "KNOWS", "Person"), ("Person", "DRINKS", "Potion"), ("Person", "PART_OF", "Organization"), @@ -356,9 +356,9 @@ async def test_pipeline_builder_failing_chunk_raise( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } with pytest.raises(LLMGenerationError): @@ -434,9 +434,9 @@ async def test_pipeline_builder_failing_chunk_do_not_raise( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } kg_builder_pipeline.get_node_by_name( @@ -575,9 +575,9 @@ async def test_pipeline_builder_two_documents( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } pipe_inputs_2 = { @@ -585,9 +585,9 @@ async def test_pipeline_builder_two_documents( # note: schema not used in this test because # we are mocking the LLM "schema": { - "entities": [], - "relations": [], - "potential_schema": [], + "node_types": (), + "relationship_types": (), + "patterns": (), }, } await kg_builder_pipeline.run(pipe_inputs_1) diff --git a/tests/e2e/experimental/test_simplekgpipeline_e2e.py b/tests/e2e/experimental/test_simplekgpipeline_e2e.py index 3bd72dd41..76c532659 100644 --- a/tests/e2e/experimental/test_simplekgpipeline_e2e.py +++ b/tests/e2e/experimental/test_simplekgpipeline_e2e.py @@ -305,7 +305,7 @@ async def test_pipeline_builder_with_automatic_schema_extraction( # first call - schema extraction response LLMResponse( content="""{ - "entities": [ + "node_types": [ { "label": "Person", "description": "A character in the story", @@ -322,14 +322,14 @@ async def test_pipeline_builder_with_automatic_schema_extraction( ] } ], - "relations": [ + "relationship_types": [ { "label": "LOCATED_AT", "description": "Indicates where a person is located", "properties": [] } ], - "potential_schema": [ + "patterns": [ ["Person", "LOCATED_AT", "Location"] ] }""" diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index f088334da..375fe0174 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -245,14 +245,14 @@ async def test_extractor_no_schema_enforcement() -> None: schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], } ], - "relations": [], - "potential_schema": [], + "relationship_types": [], + "patterns": [], }, ) @@ -301,14 +301,14 @@ async def test_extractor_schema_enforcement_invalid_nodes() -> None: schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], } ], - "relations": [], - "potential_schema": [], + "relationship_types": [], + "patterns": [], }, ) @@ -336,7 +336,7 @@ async def test_extraction_schema_enforcement_invalid_node_properties() -> None: schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [ @@ -345,8 +345,8 @@ async def test_extraction_schema_enforcement_invalid_node_properties() -> None: ], } ], - "relations": [], - "potential_schema": [], + "relationship_types": [], + "patterns": [], }, ) @@ -374,7 +374,7 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> No schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", } @@ -405,7 +405,7 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> N schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [ @@ -414,8 +414,8 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> N ], } ], - "relations": [{"label": "LIKES"}], - "potential_schema": [], + "relationship_types": [{"label": "LIKES"}], + "patterns": [], } ) @@ -446,7 +446,7 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], @@ -456,8 +456,8 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() "properties": [{"name": "name", "type": "STRING"}], }, ], - "relations": [{"label": "LIVES_IN"}], - "potential_schema": [("Person", "LIVES_IN", "City")], + "relationship_types": [{"label": "LIVES_IN"}], + "patterns": [("Person", "LIVES_IN", "City")], } ) @@ -485,19 +485,19 @@ async def test_extractor_schema_enforcement_invalid_relation_properties() -> Non schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], } ], - "relations": [ + "relationship_types": [ { "label": "LIKES", "properties": [{"name": "strength", "type": "STRING"}], } ], - "potential_schema": [], + "patterns": [], } ) @@ -528,14 +528,14 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes() - schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], } ], - "relations": [{"label": "LIKES"}], - "potential_schema": [("Person", "LIKES", "Person")], + "relationship_types": [{"label": "LIKES"}], + "patterns": [("Person", "LIKES", "Person")], } ) @@ -563,7 +563,7 @@ async def test_extractor_schema_enforcement_inverted_relation_direction() -> Non schema = GraphSchema.model_validate( { - "entities": [ + "node_types": [ { "label": "Person", "properties": [{"name": "name", "type": "STRING"}], @@ -573,8 +573,8 @@ async def test_extractor_schema_enforcement_inverted_relation_direction() -> Non "properties": [{"name": "name", "type": "STRING"}], }, ], - "relations": [{"label": "LIVES_IN"}], - "potential_schema": [("Person", "LIVES_IN", "City")], + "relationship_types": [{"label": "LIVES_IN"}], + "patterns": [("Person", "LIVES_IN", "City")], } ) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index c86525f56..e8fc670c2 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -15,6 +15,7 @@ from __future__ import annotations import json +from typing import Tuple from unittest.mock import AsyncMock import pytest @@ -22,9 +23,9 @@ from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaProperty, - SchemaRelation, + NodeType, + PropertyType, + RelationshipType, SchemaFromTextExtractor, GraphSchema, ) @@ -37,47 +38,47 @@ @pytest.fixture -def valid_entities() -> tuple[SchemaEntity, ...]: +def valid_node_types() -> tuple[NodeType, ...]: return ( - SchemaEntity( + NodeType( label="PERSON", description="An individual human being.", properties=[ - SchemaProperty(name="birth date", type="ZONED_DATETIME"), - SchemaProperty(name="name", type="STRING"), + PropertyType(name="birth date", type="ZONED_DATETIME"), + PropertyType(name="name", type="STRING"), ], ), - SchemaEntity( + NodeType( label="ORGANIZATION", description="A structured group of people with a common purpose.", ), - SchemaEntity(label="AGE", description="Age of a person in years."), + NodeType(label="AGE", description="Age of a person in years."), ) @pytest.fixture -def valid_relations() -> tuple[SchemaRelation, ...]: +def valid_relationship_types() -> tuple[RelationshipType, ...]: return ( - SchemaRelation( + RelationshipType( label="EMPLOYED_BY", description="Indicates employment relationship.", properties=[ - SchemaProperty(name="start_time", type="LOCAL_DATETIME"), - SchemaProperty(name="end_time", type="LOCAL_DATETIME"), + PropertyType(name="start_time", type="LOCAL_DATETIME"), + PropertyType(name="end_time", type="LOCAL_DATETIME"), ], ), - SchemaRelation( + RelationshipType( label="ORGANIZED_BY", description="Indicates organization responsible for an event.", ), - SchemaRelation( + RelationshipType( label="ATTENDED_BY", description="Indicates attendance at an event." ), ) @pytest.fixture -def potential_schema() -> tuple[tuple[str, str, str], ...]: +def valid_patterns() -> tuple[tuple[str, str, str], ...]: return ( ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ("ORGANIZATION", "ATTENDED_BY", "PERSON"), @@ -85,7 +86,7 @@ def potential_schema() -> tuple[tuple[str, str, str], ...]: @pytest.fixture -def potential_schema_with_invalid_entity() -> tuple[tuple[str, str, str], ...]: +def patterns_with_invalid_entity() -> tuple[tuple[str, str, str], ...]: return ( ("PERSON", "EMPLOYED_BY", "ORGANIZATION"), ("NON_EXISTENT_ENTITY", "ATTENDED_BY", "PERSON"), @@ -93,7 +94,7 @@ def potential_schema_with_invalid_entity() -> tuple[tuple[str, str, str], ...]: @pytest.fixture -def potential_schema_with_invalid_relation() -> tuple[tuple[str, str, str], ...]: +def patterns_with_invalid_relation() -> tuple[tuple[str, str, str], ...]: return (("PERSON", "NON_EXISTENT_RELATION", "ORGANIZATION"),) @@ -105,57 +106,57 @@ def schema_builder() -> SchemaBuilder: @pytest.fixture def graph_schema( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity], - valid_relations: tuple[SchemaRelation], - potential_schema: tuple[tuple[str, str, str]], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> GraphSchema: return schema_builder.create_schema_model( - list(valid_entities), list(valid_relations), list(potential_schema) + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) def test_create_schema_model_valid_data( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity], - valid_relations: tuple[SchemaRelation], - potential_schema: tuple[tuple[str, str, str]], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: schema_instance = schema_builder.create_schema_model( - list(valid_entities), list(valid_relations), list(potential_schema) + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) - assert schema_instance.entities == valid_entities - assert schema_instance.relations == valid_relations - assert schema_instance.potential_schema == potential_schema + assert schema_instance.node_types == valid_node_types + assert schema_instance.relationship_types == valid_relationship_types + assert schema_instance.patterns == valid_patterns @pytest.mark.asyncio async def test_run_method( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity, ...], - valid_relations: tuple[SchemaRelation, ...], - potential_schema: tuple[tuple[str, str, str], ...], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: schema = await schema_builder.run( - list(valid_entities), list(valid_relations), list(potential_schema) + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) - assert schema.entities == valid_entities - assert schema.relations == valid_relations - assert schema.potential_schema == potential_schema + assert schema.node_types == valid_node_types + assert schema.relationship_types == valid_relationship_types + assert schema.patterns == valid_patterns def test_create_schema_model_invalid_entity( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity, ...], - valid_relations: tuple[SchemaRelation, ...], - potential_schema_with_invalid_entity: tuple[tuple[str, str, str], ...], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + patterns_with_invalid_entity: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - list(valid_entities), - list(valid_relations), - list(potential_schema_with_invalid_entity), + list(valid_node_types), + list(valid_relationship_types), + list(patterns_with_invalid_entity), ) assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str( exc_info.value @@ -164,15 +165,15 @@ def test_create_schema_model_invalid_entity( def test_create_schema_model_invalid_relation( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity, ...], - valid_relations: tuple[SchemaRelation, ...], - potential_schema_with_invalid_relation: tuple[tuple[str, str, str], ...], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], + patterns_with_invalid_relation: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - list(valid_entities), - list(valid_relations), - list(potential_schema_with_invalid_relation), + list(valid_node_types), + list(valid_relationship_types), + list(patterns_with_invalid_relation), ) assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( exc_info.value @@ -181,47 +182,47 @@ def test_create_schema_model_invalid_relation( def test_create_schema_model_no_potential_schema( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity, ...], - valid_relations: tuple[SchemaRelation, ...], + valid_node_types: Tuple[NodeType, ...], + valid_relationship_types: Tuple[RelationshipType, ...], ) -> None: schema_instance = schema_builder.create_schema_model( - list(valid_entities), list(valid_relations) + list(valid_node_types), list(valid_relationship_types) ) - assert schema_instance.entities == valid_entities - assert schema_instance.relations == valid_relations - assert schema_instance.potential_schema is None + assert schema_instance.node_types == valid_node_types + assert schema_instance.relationship_types == valid_relationship_types + assert schema_instance.patterns is None def test_create_schema_model_no_relations_or_potential_schema( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity, ...], + valid_node_types: Tuple[NodeType, ...], ) -> None: - schema_instance = schema_builder.create_schema_model(list(valid_entities)) + schema_instance = schema_builder.create_schema_model(list(valid_node_types)) - assert len(schema_instance.entities) == 3 - person = schema_instance.entity_from_label("PERSON") + assert len(schema_instance.node_types) == 3 + person = schema_instance.node_type_from_label("PERSON") assert person is not None assert person.description == "An individual human being." assert len(person.properties) == 2 - org = schema_instance.entity_from_label("ORGANIZATION") + org = schema_instance.node_type_from_label("ORGANIZATION") assert org is not None assert org.description == "A structured group of people with a common purpose." - age = schema_instance.entity_from_label("AGE") + age = schema_instance.node_type_from_label("AGE") assert age is not None assert age.description == "Age of a person in years." def test_create_schema_model_missing_relations( schema_builder: SchemaBuilder, - valid_entities: tuple[SchemaEntity, ...], - potential_schema: tuple[tuple[str, str, str], ...], + valid_node_types: Tuple[NodeType, ...], + valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: with pytest.raises(SchemaValidationError) as exc_info: schema_builder.create_schema_model( - entities=valid_entities, potential_schema=potential_schema + node_types=valid_node_types, patterns=valid_patterns ) assert "Relations must also be provided when using a potential schema." in str( exc_info.value @@ -239,7 +240,7 @@ def mock_llm() -> AsyncMock: def valid_schema_json() -> str: return """ { - "entities": [ + "node_types": [ { "label": "Person", "properties": [ @@ -253,7 +254,7 @@ def valid_schema_json() -> str: ] } ], - "relations": [ + "relationship_types": [ { "label": "WORKS_FOR", "properties": [ @@ -261,7 +262,7 @@ def valid_schema_json() -> str: ] } ], - "potential_schema": [ + "patterns": [ ["Person", "WORKS_FOR", "Organization"] ] } @@ -272,7 +273,7 @@ def valid_schema_json() -> str: def invalid_schema_json() -> str: return """ { - "entities": [ + "node_types": [ { "label": "Person", }, @@ -306,16 +307,16 @@ async def test_schema_from_text_run_valid_response( assert "Sample text for extraction" in prompt_arg # verify the schema was correctly extracted - assert len(schema_config.entities) == 2 - assert schema_config.entity_from_label("Person") is not None - assert schema_config.entity_from_label("Organization") is not None + assert len(schema_config.node_types) == 2 + assert schema_config.node_type_from_label("Person") is not None + assert schema_config.node_type_from_label("Organization") is not None - assert schema_config.relations is not None - assert schema_config.relation_from_label("WORKS_FOR") is not None + assert schema_config.relationship_types is not None + assert schema_config.relationship_type_from_label("WORKS_FOR") is not None - assert schema_config.potential_schema is not None - assert len(schema_config.potential_schema) == 1 - assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") + assert schema_config.patterns is not None + assert len(schema_config.patterns) == 1 + assert schema_config.patterns[0] == ("Person", "WORKS_FOR", "Organization") @pytest.mark.asyncio @@ -397,8 +398,8 @@ async def test_schema_config_store_as_json(graph_schema: GraphSchema) -> None: # verify the content is valid JSON and contains expected data with open(json_path, "r") as f: data = json.load(f) - assert "entities" in data - assert len(data["entities"]) == 3 + assert "node_types" in data + assert len(data["node_types"]) == 3 @pytest.mark.asyncio @@ -417,8 +418,8 @@ async def test_schema_config_store_as_yaml(graph_schema: GraphSchema) -> None: # Verify the content is valid YAML and contains expected data with open(yaml_path, "r") as f: data = yaml.safe_load(f) - assert "entities" in data - assert len(data["entities"]) == 3 + assert "node_types" in data + assert len(data["node_types"]) == 3 @pytest.mark.asyncio @@ -445,9 +446,9 @@ async def test_schema_config_from_file(graph_schema: GraphSchema) -> None: assert isinstance(yml_schema, GraphSchema) # verify basic structure is intact - assert "entities" in json_schema.model_dump() - assert "entities" in yaml_schema.model_dump() - assert "entities" in yml_schema.model_dump() + assert "node_types" in json_schema.model_dump() + assert "node_types" in yaml_schema.model_dump() + assert "node_types" in yml_schema.model_dump() # verify an unsupported extension raises the correct error txt_path = os.path.join(temp_dir, "schema.txt") @@ -462,7 +463,7 @@ def valid_schema_json_array() -> str: return """ [ { - "entities": [ + "node_types": [ { "label": "Person", "properties": [ @@ -476,7 +477,7 @@ def valid_schema_json_array() -> str: ] } ], - "relations": [ + "relationship_types": [ { "label": "WORKS_FOR", "properties": [ @@ -484,7 +485,7 @@ def valid_schema_json_array() -> str: ] } ], - "potential_schema": [ + "patterns": [ ["Person", "WORKS_FOR", "Organization"] ] } @@ -505,13 +506,13 @@ async def test_schema_from_text_run_valid_json_array( schema = await schema_from_text.run(text="Sample text for extraction") # verify the schema was correctly extracted from the array - assert len(schema.entities) == 2 - assert schema.entity_from_label("Person") is not None - assert schema.entity_from_label("Organization") is not None + assert len(schema.node_types) == 2 + assert schema.node_type_from_label("Person") is not None + assert schema.node_type_from_label("Organization") is not None - assert schema.relations is not None - assert schema.relation_from_label("WORKS_FOR") is not None + assert schema.relationship_types is not None + assert schema.relationship_type_from_label("WORKS_FOR") is not None - assert schema.potential_schema is not None - assert len(schema.potential_schema) == 1 - assert schema.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 1d188a03d..773a84486 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -26,8 +26,8 @@ from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, SchemaFromTextExtractor, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -142,9 +142,9 @@ def test_simple_kg_pipeline_config_schema_run_params() -> None: potential_schema=[("Person", "KNOWS", "Person")], ) assert config._get_run_params_for_schema() == { - "entities": (SchemaEntity(label="Person"),), - "relations": (SchemaRelation(label="KNOWS"),), - "potential_schema": (("Person", "KNOWS", "Person"),), + "node_types": (NodeType(label="Person"),), + "relationship_types": (RelationshipType(label="KNOWS"),), + "patterns": (("Person", "KNOWS", "Person"),), } diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index 2de9f1485..b4ece857a 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -19,8 +19,8 @@ import pytest from neo4j_graphrag.embeddings import Embedder from neo4j_graphrag.experimental.components.schema import ( - SchemaEntity, - SchemaRelation, + NodeType, + RelationshipType, ) from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError @@ -118,8 +118,8 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: file_path = "path/to/test.pdf" - internal_entities = [SchemaEntity(label=label) for label in entities] - internal_relations = [SchemaRelation(label=label) for label in relations] + internal_node_types = [NodeType(label=label) for label in entities] + internal_relationship_types = [RelationshipType(label=label) for label in relations] with patch.object( kg_builder.runner.pipeline, @@ -128,9 +128,11 @@ async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: ) as mock_run: await kg_builder.run_async(file_path=file_path) pipe_inputs = mock_run.call_args[1]["data"] - assert pipe_inputs["schema"]["entities"] == tuple(internal_entities) - assert pipe_inputs["schema"]["relations"] == tuple(internal_relations) - assert pipe_inputs["schema"]["potential_schema"] == tuple(potential_schema) + assert pipe_inputs["schema"]["node_types"] == tuple(internal_node_types) + assert pipe_inputs["schema"]["relationship_types"] == tuple( + internal_relationship_types + ) + assert pipe_inputs["schema"]["patterns"] == tuple(potential_schema) def test_simple_kg_pipeline_on_error_invalid_value() -> None: From 5df3efcc46ea13ea4946e4625597e394dfd4d0d5 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 20 May 2025 15:41:41 +0200 Subject: [PATCH 11/16] Fix remaining SchemaConfig --- .../experimental/components/schema.py | 2 +- .../test_entity_relation_extractor.py | 40 ++++++++++--------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index ad1ca4db4..1136c6db3 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -367,7 +367,7 @@ async def run( class SchemaFromTextExtractor(Component): """ - A component for constructing SchemaConfig objects from the output of an LLM after + A component for constructing GraphSchema objects from the output of an LLM after automatic schema extraction from text. """ diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 375fe0174..019e10d2e 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -603,15 +603,17 @@ async def test_extractor_schema_enforcement_none_relationships_in_schema() -> No llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations=None, - potential_schema=None, + schema = GraphSchema.model_validate( + dict( + node_types=[ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + relationship_types=None, + patterns=None, + ) ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -638,15 +640,17 @@ async def test_extractor_schema_enforcement_empty_relationships_in_schema() -> N llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT ) - schema = SchemaConfig( - entities={ - "Person": { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - }, - relations={}, - potential_schema=None, + schema = GraphSchema.model_validate( + dict( + node_types=[ + { + "label": "Person", + "properties": [{"name": "name", "type": "STRING"}], + } + ], + relationship_types=None, + patterns=None, + ) ) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) From 142dca9d15bafb2589a71e1b1c444bb56e6e55a5 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 20 May 2025 15:47:29 +0200 Subject: [PATCH 12/16] Mypy --- .../experimental/components/entity_relation_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 910dd13a1..1d29fefe2 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -441,7 +441,7 @@ def _enforce_relationships( if self.enforce_schema != SchemaEnforcementMode.STRICT: return extracted_relationships - if schema.relations is None: + if schema.relationship_types is None: return extracted_relationships valid_rels = [] From 2f6a4cf28f57493085fc44b960eb0c312a385360 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 20 May 2025 15:59:36 +0200 Subject: [PATCH 13/16] Fix bad test update --- .../experimental/components/test_entity_relation_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 019e10d2e..64ac0d42d 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -648,7 +648,7 @@ async def test_extractor_schema_enforcement_empty_relationships_in_schema() -> N "properties": [{"name": "name", "type": "STRING"}], } ], - relationship_types=None, + relationship_types=[], patterns=None, ) ) From 8c666a0bb1f36a4587851a0d531b55ca2bb700da Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 11:01:48 +0200 Subject: [PATCH 14/16] Fix README/doc/bug --- README.md | 1 + docs/source/user_guide_kg_builder.rst | 2 +- .../template_pipeline/simple_kg_builder.py | 4 +- .../test_simple_kg_builder.py | 165 ++++++++++++++++++ 4 files changed, 169 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a8bd4ccc0..85726c0aa 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,7 @@ kg_builder = SimpleKGPipeline( schema={ "node_types": node_types, "relationship_types": relationship_types, + "patterns": patterns, }, on_error="IGNORE", from_pdf=False, diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index d29014fca..d7455a6f8 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -102,7 +102,7 @@ as shown below: {"label": "RULES", "properties": [{"name": "fromYear", "type": "INTEGER"}]}, ] -The `patterns` is defined by a list of triplet in the format: +The `patterns` are defined by a list of triplet in the format: `(source_node_label, relationship_label, target_node_label)`. For instance: diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index ac4fc04b8..d8ea45fe4 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -220,9 +220,9 @@ def _process_schema_with_precedence( ) relationship_types = tuple( RelationshipType.from_text_or_dict(r) - for r in self.schema_.get("node_types", ()) + for r in self.schema_.get("relationship_types", ()) ) - ps = self.schema_.get("potential_schema") + ps = self.schema_.get("patterns") patterns = tuple(ps) if ps else None else: # use individual components diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 773a84486..7dc331049 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -29,6 +29,7 @@ NodeType, RelationshipType, SchemaFromTextExtractor, + GraphSchema, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -314,3 +315,167 @@ def test_simple_kg_pipeline_config_run_params_both_file_and_text() -> None: "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." in str(excinfo) ) + + +def test_simple_kg_pipeline_config_process_schema_with_precedence_legacy() -> None: + entities = [ + "Person", + { + "label": "Organization", + "description": "A group of persons", + "properties": [ + { + "name": "name", + "type": "STRING", + } + ], + }, + ] + relations = [ + "WORKS_FOR", + { + "label": "CREATED", + "description": "A person created an organization", + "properties": [ + { + "name": "date", + "description": "The date the organization was created", + "type": "DATE", + }, + {"name": "isActive", "type": "BOOLEAN"}, + ], + }, + ] + potential_schema = [ + ("Person", "WORKS_FOR", "Organization"), + ("Person", "CREATED", "Organization"), + ] + config = SimpleKGPipelineConfig( + entities=entities, + relations=relations, + potential_schema=potential_schema, + ) + node_types, relationship_types, patterns = config._process_schema_with_precedence() + assert len(node_types) == 2 + assert node_types[0].label == "Person" + assert len(node_types[0].properties) == 0 + assert node_types[1].label == "Organization" + assert len(node_types[1].properties) == 1 + assert len(relationship_types) == 2 + assert relationship_types[0].label == "WORKS_FOR" + assert len(relationship_types[0].properties) == 0 + assert relationship_types[1].label == "CREATED" + assert len(relationship_types[1].properties) == 2 + assert len(patterns) == 2 + + +def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_dict() -> None: + entities = [ + "Person", + { + "label": "Organization", + "description": "A group of persons", + "properties": [ + { + "name": "name", + "type": "STRING", + } + ], + }, + ] + relations = [ + "WORKS_FOR", + { + "label": "CREATED", + "description": "A person created an organization", + "properties": [ + { + "name": "date", + "description": "The date the organization was created", + "type": "DATE", + }, + {"name": "isActive", "type": "BOOLEAN"}, + ], + }, + ] + potential_schema = [ + ("Person", "WORKS_FOR", "Organization"), + ("Person", "CREATED", "Organization"), + ] + config = SimpleKGPipelineConfig( + schema={ + "node_types": entities, + "relationship_types": relations, + "patterns": potential_schema, + } + ) + node_types, relationship_types, patterns = config._process_schema_with_precedence() + assert len(node_types) == 2 + assert node_types[0].label == "Person" + assert len(node_types[0].properties) == 0 + assert node_types[1].label == "Organization" + assert len(node_types[1].properties) == 1 + assert len(relationship_types) == 2 + assert relationship_types[0].label == "WORKS_FOR" + assert len(relationship_types[0].properties) == 0 + assert relationship_types[1].label == "CREATED" + assert len(relationship_types[1].properties) == 2 + assert len(patterns) == 2 + + +def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_object() -> ( + None +): + entities = [ + {"label": "Person"}, + { + "label": "Organization", + "description": "A group of persons", + "properties": [ + { + "name": "name", + "type": "STRING", + } + ], + }, + ] + relations = [ + {"label": "WORKS_FOR"}, + { + "label": "CREATED", + "description": "A person created an organization", + "properties": [ + { + "name": "date", + "description": "The date the organization was created", + "type": "DATE", + }, + {"name": "isActive", "type": "BOOLEAN"}, + ], + }, + ] + potential_schema = [ + ("Person", "WORKS_FOR", "Organization"), + ("Person", "CREATED", "Organization"), + ] + config = SimpleKGPipelineConfig( + schema=GraphSchema.model_validate( + { + "node_types": entities, + "relationship_types": relations, + "patterns": potential_schema, + } + ) + ) + node_types, relationship_types, patterns = config._process_schema_with_precedence() + assert len(node_types) == 2 + assert node_types[0].label == "Person" + assert len(node_types[0].properties) == 0 + assert node_types[1].label == "Organization" + assert len(node_types[1].properties) == 1 + assert len(relationship_types) == 2 + assert relationship_types[0].label == "WORKS_FOR" + assert len(relationship_types[0].properties) == 0 + assert relationship_types[1].label == "CREATED" + assert len(relationship_types[1].properties) == 2 + assert len(patterns) == 2 From 854d5f3859096ed06f3e67a68c15a8d94fafc208 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 11:07:34 +0200 Subject: [PATCH 15/16] Mypy --- .../config/template_pipeline/test_simple_kg_builder.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 7dc331049..4b4988435 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -39,6 +39,8 @@ SimpleKGPipelineConfig, ) from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.experimental.pipeline.types.schema import EntityInputType, \ + RelationInputType from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm import LLMInterface @@ -318,7 +320,7 @@ def test_simple_kg_pipeline_config_run_params_both_file_and_text() -> None: def test_simple_kg_pipeline_config_process_schema_with_precedence_legacy() -> None: - entities = [ + entities: list[EntityInputType] = [ "Person", { "label": "Organization", @@ -331,7 +333,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_legacy() -> No ], }, ] - relations = [ + relations: list[RelationInputType] = [ "WORKS_FOR", { "label": "CREATED", @@ -366,6 +368,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_legacy() -> No assert len(relationship_types[0].properties) == 0 assert relationship_types[1].label == "CREATED" assert len(relationship_types[1].properties) == 2 + assert patterns is not None assert len(patterns) == 2 @@ -420,6 +423,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_dict() assert len(relationship_types[0].properties) == 0 assert relationship_types[1].label == "CREATED" assert len(relationship_types[1].properties) == 2 + assert patterns is not None assert len(patterns) == 2 @@ -478,4 +482,5 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_object( assert len(relationship_types[0].properties) == 0 assert relationship_types[1].label == "CREATED" assert len(relationship_types[1].properties) == 2 + assert patterns is not None assert len(patterns) == 2 From 5c0a5e9bdb880b7365ce072d327a6de97fb924f8 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 11:11:50 +0200 Subject: [PATCH 16/16] Ruff --- .../config/template_pipeline/test_simple_kg_builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 4b4988435..8bcdd1f37 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -39,8 +39,10 @@ SimpleKGPipelineConfig, ) from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError -from neo4j_graphrag.experimental.pipeline.types.schema import EntityInputType, \ - RelationInputType +from neo4j_graphrag.experimental.pipeline.types.schema import ( + EntityInputType, + RelationInputType, +) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm import LLMInterface