diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 4da75e4b1..df4ee92e7 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -837,7 +837,7 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto # Extract the schema from the text extracted_schema = await schema_extractor.run(text="Some text") -The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema. +The `SchemaFromTextExtractor` component analyzes the text and identifies node types, relationship types, their property types, and the patterns connecting them. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema. You can also save and reload the extracted schema: diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 87d7ecf93..c242d8d1f 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -430,6 +430,118 @@ def __init__( ) self._llm_params: dict[str, Any] = llm_params or {} + def _filter_invalid_patterns( + self, + patterns: List[Tuple[str, str, str]], + node_types: List[Dict[str, Any]], + relationship_types: Optional[List[Dict[str, Any]]] = None, + ) -> List[Tuple[str, str, str]]: + """ + Filter out patterns that reference undefined node types or relationship types. + + Args: + patterns: List of patterns to filter. + node_types: List of node type definitions. + relationship_types: Optional list of relationship type definitions. + + Returns: + Filtered list of patterns containing only valid references. + """ + # Early returns for missing required types + if not node_types: + logging.info( + "Filtering out all patterns because no node types are defined. " + "Patterns reference node types that must be defined." + ) + return [] + + if not relationship_types: + logging.info( + "Filtering out all patterns because no relationship types are defined. " + "GraphSchema validation requires relationship_types when patterns are provided." + ) + return [] + + # Create sets of valid labels + valid_node_labels = {node_type["label"] for node_type in node_types} + valid_relationship_labels = { + rel_type["label"] for rel_type in relationship_types + } + + # Filter patterns + filtered_patterns = [] + for pattern in patterns: + if not (isinstance(pattern, (list, tuple)) and len(pattern) == 3): + continue + + entity1, relation, entity2 = pattern + + # Check if all components are valid + if ( + entity1 in valid_node_labels + and entity2 in valid_node_labels + and relation in valid_relationship_labels + ): + filtered_patterns.append(pattern) + else: + # Log invalid pattern with validation details + entity1_valid = entity1 in valid_node_labels + entity2_valid = entity2 in valid_node_labels + relation_valid = relation in valid_relationship_labels + + logging.info( + f"Filtering out invalid pattern: {pattern}. " + f"Entity1 '{entity1}' valid: {entity1_valid}, " + f"Entity2 '{entity2}' valid: {entity2_valid}, " + f"Relation '{relation}' valid: {relation_valid}" + ) + + return filtered_patterns + + def _filter_nodes_without_labels( + self, node_types: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Filter out node types that have no labels. + + Args: + node_types: List of node type definitions. + + Returns: + Filtered list of node types containing only those with valid labels. + """ + filtered_nodes = [] + for node_type in node_types: + if node_type.get("label"): + filtered_nodes.append(node_type) + else: + logging.info(f"Filtering out node type with missing label: {node_type}") + + return filtered_nodes + + def _filter_relationships_without_labels( + self, relationship_types: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Filter out relationship types that have no labels. + + Args: + relationship_types: List of relationship type definitions. + + Returns: + Filtered list of relationship types containing only those with valid labels. + """ + filtered_relationships = [] + for rel_type in relationship_types: + if rel_type.get("label"): + filtered_relationships.append(rel_type) + else: + logging.info( + f"Filtering out relationship type with missing label: {rel_type}" + ) + + return filtered_relationships + @validate_call async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema: """ @@ -459,13 +571,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema pass # Keep as is # handle list elif isinstance(extracted_schema, list): - if len(extracted_schema) > 0 and isinstance(extracted_schema[0], dict): - extracted_schema = extracted_schema[0] - elif len(extracted_schema) == 0: - logging.warning( + if len(extracted_schema) == 0: + logging.info( "LLM returned an empty list for schema. Falling back to empty schema." ) extracted_schema = {} + elif isinstance(extracted_schema[0], dict): + extracted_schema = extracted_schema[0] else: raise SchemaExtractionError( f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}" @@ -488,6 +600,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema "patterns" ) + # Filter out nodes and relationships without labels + extracted_node_types = self._filter_nodes_without_labels(extracted_node_types) + if extracted_relationship_types: + extracted_relationship_types = self._filter_relationships_without_labels( + extracted_relationship_types + ) + + # Filter out invalid patterns before validation + if extracted_patterns: + extracted_patterns = self._filter_invalid_patterns( + extracted_patterns, extracted_node_types, extracted_relationship_types + ) + return GraphSchema.model_validate( { "node_types": extracted_node_types, diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 8c2bc470e..d9045a944 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -207,16 +207,16 @@ class SchemaExtractionTemplate(PromptTemplate): You are a top-tier algorithm designed for extracting a labeled property graph schema in structured formats. -Generate a generalized graph schema based on the input text. Identify key entity types, +Generate a generalized graph schema based on the input text. Identify key node types, their relationship types, and property types. IMPORTANT RULES: 1. Return only abstract schema information, not concrete instances. -2. Use singular PascalCase labels for entity types (e.g., Person, Company, Product). -3. Use UPPER_SNAKE_CASE for relationship types (e.g., WORKS_FOR, MANAGES). +2. Use singular PascalCase labels for node types (e.g., Person, Company, Product). +3. Use UPPER_SNAKE_CASE labels for relationship types (e.g., WORKS_FOR, MANAGES). 4. Include property definitions only when the type can be confidently inferred, otherwise omit them. -5. When defining potential_schema, ensure that every entity and relation mentioned exists in your entities and relations lists. -6. Do not create entity types that aren't clearly mentioned in the text. +5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types. +6. Do not create node types that aren't clearly mentioned in the text. 7. Keep your schema minimal and focused on clearly identifiable patterns in the text. Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 79e7712d9..0a27208bf 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -700,3 +700,260 @@ async def test_schema_from_text_run_valid_json_array( assert schema.patterns is not None assert len(schema.patterns) == 1 assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.fixture +def schema_json_with_invalid_node_patterns() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"], + ["Person", "WORKS_FOR", "UndefinedNode"], + ["UndefinedNode", "WORKS_FOR", "Organization"] + ] + } + """ + + +@pytest.fixture +def schema_json_with_invalid_relationship_patterns() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"], + ["Person", "UNDEFINED_RELATION", "Organization"], + ["Organization", "ANOTHER_UNDEFINED_RELATION", "Person"] + ] + } + """ + + +@pytest.fixture +def schema_json_with_nodes_without_labels() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"] + ] + } + """ + + +@pytest.fixture +def schema_json_with_relationships_without_labels() -> str: + return """ + { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relationship_types": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + }, + { + "properties": [ + {"name": "since", "type": "DATE"} + ] + }, + { + "label": "", + "properties": [ + {"name": "since", "type": "DATE"} + ] + }, + { + "label": "MANAGES", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "patterns": [ + ["Person", "WORKS_FOR", "Organization"], + ["Person", "MANAGES", "Organization"] + ] + } + """ + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_invalid_node_patterns( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_invalid_node_patterns: str, +) -> None: + # configure the mock LLM to return schema with invalid node patterns + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_invalid_node_patterns + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that invalid node patterns were filtered out (2 out of 3 patterns should be removed) + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_invalid_relationship_patterns( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_invalid_relationship_patterns: str, +) -> None: + # configure the mock LLM to return schema with invalid relationship patterns + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_invalid_relationship_patterns + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that invalid relationship patterns were filtered out (2 out of 3 patterns should be removed) + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_nodes_without_labels( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_nodes_without_labels: str, +) -> None: + # configure the mock LLM to return schema with nodes without labels + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_nodes_without_labels + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that nodes without labels were filtered out (2 out of 4 nodes should be removed) + assert len(schema.node_types) == 2 + assert schema.node_type_from_label("Person") is not None + assert schema.node_type_from_label("Organization") is not None + + # verify that the pattern is still valid with the remaining nodes + assert schema.patterns is not None + assert len(schema.patterns) == 1 + assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.mark.asyncio +async def test_schema_from_text_filters_relationships_without_labels( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + schema_json_with_relationships_without_labels: str, +) -> None: + # configure the mock LLM to return schema with relationships without labels + mock_llm.ainvoke.return_value = LLMResponse( + content=schema_json_with_relationships_without_labels + ) + + # run the schema extraction + schema = await schema_from_text.run(text="Sample text for extraction") + + # verify that relationships without labels were filtered out (2 out of 4 relationships should be removed) + assert schema.relationship_types is not None + assert len(schema.relationship_types) == 2 + assert schema.relationship_type_from_label("WORKS_FOR") is not None + assert schema.relationship_type_from_label("MANAGES") is not None + + # verify that the patterns are still valid with the remaining relationships + assert schema.patterns is not None + assert len(schema.patterns) == 2 + assert ("Person", "WORKS_FOR", "Organization") in schema.patterns + assert ("Person", "MANAGES", "Organization") in schema.patterns