From afefff8ccec37cc1a9b5f8a14719aa6cefa2cfa3 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 1 Jul 2025 21:41:35 +0200 Subject: [PATCH 1/8] Update default prompt for schema extraction --- src/neo4j_graphrag/generation/prompts.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 8c2bc470e..d9045a944 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -207,16 +207,16 @@ class SchemaExtractionTemplate(PromptTemplate): You are a top-tier algorithm designed for extracting a labeled property graph schema in structured formats. -Generate a generalized graph schema based on the input text. Identify key entity types, +Generate a generalized graph schema based on the input text. Identify key node types, their relationship types, and property types. IMPORTANT RULES: 1. Return only abstract schema information, not concrete instances. -2. Use singular PascalCase labels for entity types (e.g., Person, Company, Product). -3. Use UPPER_SNAKE_CASE for relationship types (e.g., WORKS_FOR, MANAGES). +2. Use singular PascalCase labels for node types (e.g., Person, Company, Product). +3. Use UPPER_SNAKE_CASE labels for relationship types (e.g., WORKS_FOR, MANAGES). 4. Include property definitions only when the type can be confidently inferred, otherwise omit them. -5. When defining potential_schema, ensure that every entity and relation mentioned exists in your entities and relations lists. -6. Do not create entity types that aren't clearly mentioned in the text. +5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types. +6. Do not create node types that aren't clearly mentioned in the text. 7. Keep your schema minimal and focused on clearly identifiable patterns in the text. Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, From e5ace2b7e412db662a46672f953e4388760d9a2c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 1 Jul 2025 23:17:00 +0200 Subject: [PATCH 2/8] Filter invalid patterns from extracted schema --- .../experimental/components/schema.py | 85 ++++++++++++++++++- 1 file changed, 82 insertions(+), 3 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 87d7ecf93..fc9e490d7 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -430,6 +430,79 @@ def __init__( ) self._llm_params: dict[str, Any] = llm_params or {} + def _filter_invalid_patterns( + self, + patterns: List[Tuple[str, str, str]], + node_types: List[Dict[str, Any]], + relationship_types: Optional[List[Dict[str, Any]]] = None, + ) -> List[Tuple[str, str, str]]: + """ + Filter out patterns that reference undefined node types or relationship types. + + Args: + patterns: List of patterns to filter. + node_types: List of node type definitions. + relationship_types: Optional list of relationship type definitions. + + Returns: + Filtered list of patterns containing only valid references. + """ + # Early returns for missing required types + if not node_types: + logging.warning( + "Filtering out all patterns because no node types are defined. " + "Patterns reference node types that must be defined." + ) + return [] + + if not relationship_types: + logging.warning( + "Filtering out all patterns because no relationship types are defined. " + "GraphSchema validation requires relationship_types when patterns are provided." + ) + return [] + + # Create sets of valid labels + valid_node_labels = { + node_type["label"] for node_type in node_types if node_type.get("label") + } + + valid_relationship_labels = { + rel_type["label"] + for rel_type in relationship_types + if rel_type.get("label") + } + + # Filter patterns + filtered_patterns = [] + for pattern in patterns: + if not (isinstance(pattern, (list, tuple)) and len(pattern) == 3): + continue + + entity1, relation, entity2 = pattern + + # Check if all components are valid + if ( + entity1 in valid_node_labels + and entity2 in valid_node_labels + and relation in valid_relationship_labels + ): + filtered_patterns.append(pattern) + else: + # Log invalid pattern with validation details + entity1_valid = entity1 in valid_node_labels + entity2_valid = entity2 in valid_node_labels + relation_valid = relation in valid_relationship_labels + + logging.warning( + f"Filtering out invalid pattern: {pattern}. " + f"Entity1 '{entity1}' valid: {entity1_valid}, " + f"Entity2 '{entity2}' valid: {entity2_valid}, " + f"Relation '{relation}' valid: {relation_valid}" + ) + + return filtered_patterns + @validate_call async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema: """ @@ -459,13 +532,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema pass # Keep as is # handle list elif isinstance(extracted_schema, list): - if len(extracted_schema) > 0 and isinstance(extracted_schema[0], dict): - extracted_schema = extracted_schema[0] - elif len(extracted_schema) == 0: + if len(extracted_schema) == 0: logging.warning( "LLM returned an empty list for schema. Falling back to empty schema." ) extracted_schema = {} + elif isinstance(extracted_schema[0], dict): + extracted_schema = extracted_schema[0] else: raise SchemaExtractionError( f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}" @@ -488,6 +561,12 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema "patterns" ) + # Filter out invalid patterns before validation + if extracted_patterns: + extracted_patterns = self._filter_invalid_patterns( + extracted_patterns, extracted_node_types, extracted_relationship_types + ) + return GraphSchema.model_validate( { "node_types": extracted_node_types, From b761be1229850da8c3ed346db37707e91453240c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 1 Jul 2025 23:17:14 +0200 Subject: [PATCH 3/8] Add unit tests --- .../experimental/components/test_schema.py | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 79e7712d9..5ada2ee09 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -700,3 +700,113 @@ async def test_schema_from_text_run_valid_json_array( assert schema.patterns is not None assert len(schema.patterns) == 1 assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.fixture +def schema_json_with_invalid_node_patterns() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"], + ["Person", "WORKS_FOR", "UndefinedNode"], + ["UndefinedNode", "WORKS_FOR", "Organization"] + ] + } + """ + + +@pytest.fixture +def schema_json_with_invalid_relationship_patterns() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"], + ["Person", "UNDEFINED_RELATION", "Organization"], + ["Organization", "ANOTHER_UNDEFINED_RELATION", "Person"] + ] + } + """ + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_invalid_node_patterns( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_invalid_node_patterns: str, +) -> None: + # configure the mock LLM to return schema with invalid node patterns + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_invalid_node_patterns + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that invalid node patterns were filtered out (2 out of 3 patterns should be removed) + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_invalid_relationship_patterns( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_invalid_relationship_patterns: str, +) -> None: + # configure the mock LLM to return schema with invalid relationship patterns + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_invalid_relationship_patterns + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that invalid relationship patterns were filtered out (2 out of 3 patterns should be removed) + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") From 3ca02d31d6ebe5fecca6814987bbe7f41424a6d2 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 1 Jul 2025 23:23:03 +0200 Subject: [PATCH 4/8] Improve documentation --- docs/source/user_guide_kg_builder.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 4da75e4b1..df4ee92e7 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -837,7 +837,7 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto # Extract the schema from the text extracted_schema = await schema_extractor.run(text="Some text") -The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema. +The `SchemaFromTextExtractor` component analyzes the text and identifies node types, relationship types, their property types, and the patterns connecting them. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema. You can also save and reload the extracted schema: From 4e588204cdd6c92e973838d0a3b4cb5e3544c051 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 2 Jul 2025 11:43:57 +0200 Subject: [PATCH 5/8] Change warning to info --- src/neo4j_graphrag/experimental/components/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index fc9e490d7..0adbfbf61 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -494,7 +494,7 @@ def _filter_invalid_patterns( entity2_valid = entity2 in valid_node_labels relation_valid = relation in valid_relationship_labels - logging.warning( + logging.info( f"Filtering out invalid pattern: {pattern}. " f"Entity1 '{entity1}' valid: {entity1_valid}, " f"Entity2 '{entity2}' valid: {entity2_valid}, " From ea022364a1e7032f6986c1f187b72e39b06893b9 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 2 Jul 2025 12:05:06 +0200 Subject: [PATCH 6/8] Correct other logging level --- src/neo4j_graphrag/experimental/components/schema.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 0adbfbf61..17ef53c2e 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -449,14 +449,14 @@ def _filter_invalid_patterns( """ # Early returns for missing required types if not node_types: - logging.warning( + logging.info( "Filtering out all patterns because no node types are defined. " "Patterns reference node types that must be defined." ) return [] if not relationship_types: - logging.warning( + logging.info( "Filtering out all patterns because no relationship types are defined. " "GraphSchema validation requires relationship_types when patterns are provided." ) @@ -533,7 +533,7 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema # handle list elif isinstance(extracted_schema, list): if len(extracted_schema) == 0: - logging.warning( + logging.info( "LLM returned an empty list for schema. Falling back to empty schema." ) extracted_schema = {} From 820e153b7d7df08abe855d7c2d09d586cb675320 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 2 Jul 2025 12:05:53 +0200 Subject: [PATCH 7/8] Filter out nodes and rels without labels --- .../experimental/components/schema.py | 60 ++++++++++++++++--- 1 file changed, 53 insertions(+), 7 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 17ef53c2e..c242d8d1f 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -463,14 +463,9 @@ def _filter_invalid_patterns( return [] # Create sets of valid labels - valid_node_labels = { - node_type["label"] for node_type in node_types if node_type.get("label") - } - + valid_node_labels = {node_type["label"] for node_type in node_types} valid_relationship_labels = { - rel_type["label"] - for rel_type in relationship_types - if rel_type.get("label") + rel_type["label"] for rel_type in relationship_types } # Filter patterns @@ -503,6 +498,50 @@ def _filter_invalid_patterns( return filtered_patterns + def _filter_nodes_without_labels( + self, node_types: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Filter out node types that have no labels. + + Args: + node_types: List of node type definitions. + + Returns: + Filtered list of node types containing only those with valid labels. + """ + filtered_nodes = [] + for node_type in node_types: + if node_type.get("label"): + filtered_nodes.append(node_type) + else: + logging.info(f"Filtering out node type with missing label: {node_type}") + + return filtered_nodes + + def _filter_relationships_without_labels( + self, relationship_types: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Filter out relationship types that have no labels. + + Args: + relationship_types: List of relationship type definitions. + + Returns: + Filtered list of relationship types containing only those with valid labels. + """ + filtered_relationships = [] + for rel_type in relationship_types: + if rel_type.get("label"): + filtered_relationships.append(rel_type) + else: + logging.info( + f"Filtering out relationship type with missing label: {rel_type}" + ) + + return filtered_relationships + @validate_call async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema: """ @@ -561,6 +600,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema "patterns" ) + # Filter out nodes and relationships without labels + extracted_node_types = self._filter_nodes_without_labels(extracted_node_types) + if extracted_relationship_types: + extracted_relationship_types = self._filter_relationships_without_labels( + extracted_relationship_types + ) + # Filter out invalid patterns before validation if extracted_patterns: extracted_patterns = self._filter_invalid_patterns( From aefd9beefa061a35d03e68ef4bfd6f25b4ddc7b3 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 2 Jul 2025 12:24:33 +0200 Subject: [PATCH 8/8] Add more unit tests --- .../experimental/components/test_schema.py | 147 ++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 5ada2ee09..0a27208bf 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -772,6 +772,101 @@ def schema_json_with_invalid_relationship_patterns() -> str: """ +@pytest.fixture +def schema_json_with_nodes_without_labels() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"] + ] + } + """ + + +@pytest.fixture +def schema_json_with_relationships_without_labels() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + }, + { + "properties": [ + {"name": "since", "type": "DATE"} + ] + }, + { + "label": "", + "properties": [ + {"name": "since", "type": "DATE"} + ] + }, + { + "label": "MANAGES", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"], + ["Person", "MANAGES", "Organization"] + ] + } + """ + + @pytest.mark.asyncio async def test_schema_from_text_filters_invalid_node_patterns( schema_from_text: SchemaFromTextExtractor, @@ -810,3 +905,55 @@ async def test_schema_from_text_filters_invalid_relationship_patterns( assert schema.patterns is not None assert len(schema.patterns) == 1 assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_nodes_without_labels( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_nodes_without_labels: str, +) -> None: + # configure the mock LLM to return schema with nodes without labels + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_nodes_without_labels + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that nodes without labels were filtered out (2 out of 4 nodes should be removed) + assert len(schema.node_types) == 2 + assert schema.node_type_from_label("Person") is not None + assert schema.node_type_from_label("Organization") is not None + + # verify that the pattern is still valid with the remaining nodes + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_relationships_without_labels( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_relationships_without_labels: str, +) -> None: + # configure the mock LLM to return schema with relationships without labels + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_relationships_without_labels + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that relationships without labels were filtered out (2 out of 4 relationships should be removed) + assert schema.relationship_types is not None + assert len(schema.relationship_types) == 2 + assert schema.relationship_type_from_label("WORKS_FOR") is not None + assert schema.relationship_type_from_label("MANAGES") is not None + + # verify that the patterns are still valid with the remaining relationships + assert schema.patterns is not None + assert len(schema.patterns) == 2 + assert ("Person", "WORKS_FOR", "Organization") in schema.patterns + assert ("Person", "MANAGES", "Organization") in schema.patterns