diff --git a/CHANGELOG.md b/CHANGELOG.md
index 53c88b143..df446f7da 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,10 +2,15 @@
 
 ## Next
 
+### Added
+
+- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, automatic schema extraction is enabled by default when no schema is provided.
+
 ### Fixed
 
 - Fixed a bug where `spacy` and `rapidfuzz` needed to be installed even if not using the relevant entity resolvers.
 
+
 ## 1.7.0
 
 ### Added
diff --git a/docs/source/api.rst b/docs/source/api.rst
index e895cd5dd..55a5d1cc4 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -77,6 +77,12 @@ SchemaBuilder
 .. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaBuilder
     :members: run
 
+SchemaFromTextExtractor
+-----------------------
+
+.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaFromTextExtractor
+    :members: run
+
 EntityRelationExtractor
 =======================
 
@@ -362,6 +368,13 @@ ERExtractionTemplate
     :members:
     :exclude-members: format
 
+SchemaExtractionTemplate
+------------------------
+
+.. autoclass:: neo4j_graphrag.generation.prompts.SchemaExtractionTemplate
+    :members:
+    :exclude-members: format
+
 Text2CypherTemplate
 --------------------
 
diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst
index 11a6d1741..30d478667 100644
--- a/docs/source/user_guide_kg_builder.rst
+++ b/docs/source/user_guide_kg_builder.rst
@@ -21,7 +21,7 @@ A Knowledge Graph (KG) construction pipeline requires a few components (some of
 - **Data loader**: extract text from files (PDFs, ...).
 - **Text splitter**: split the text into smaller pieces of text (chunks), manageable by the LLM context window (token limit).
 - **Chunk embedder** (optional): compute the chunk embeddings.
-- **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG.
+- **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG. The schema can be provided manually or extracted automatically using an LLM.
 - **Lexical graph builder**: build the lexical graph (Document, Chunk and their relationships) (optional).
 - **Entity and relation extractor**: extract relevant entities and relations from the text.
 - **Knowledge Graph writer**: save the identified entities and relations.
@@ -75,10 +75,11 @@ Graph Schema
 
 It is possible to guide the LLM by supplying a list of entities, relationships,
 and instructions on how to connect them. However, note that the extracted graph
-may not fully adhere to these guidelines. Entities and relationships can be
-represented as either simple strings (for their labels) or dictionaries. If using
-a dictionary, it must include a label key and can optionally include description
-and properties keys, as shown below:
+may not fully adhere to these guidelines unless schema enforcement is enabled
+(see :ref:`Schema Enforcement Behaviour <schema-enforcement-behaviour>`).
+Entities and relationships can be represented as either simple strings (for their
+labels) or dictionaries. If using a dictionary, it must include a label key and
+can optionally include description and properties keys, as shown below:
 
 .. code:: python
 
@@ -117,14 +118,20 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated
 
 .. code:: python
 
+    # Using the schema parameter (recommended approach)
     kg_builder = SimpleKGPipeline(
         # ...
-        entities=ENTITIES,
-        relations=RELATIONS,
-        potential_schema=POTENTIAL_SCHEMA,
+        schema={
+            "entities": ENTITIES,
+            "relations": RELATIONS,
+            "potential_schema": POTENTIAL_SCHEMA
+        },
         # ...
     )
 
+.. note::
+    By default, if no schema is provided to the `SimpleKGPipeline`, automatic schema
+    extraction is performed using the LLM (see :ref:`Automatic Schema Extraction`).
+
 Extra configurations
 --------------------
 
@@ -412,41 +419,44 @@ within the configuration file.
         "neo4j_database": "myDb",
         "on_error": "IGNORE",
         "prompt_template": "...",
-        "entities": [
-            "Person",
-            {
-                "label": "House",
-                "description": "Family the person belongs to",
-                "properties": [
-                    {"name": "name", "type": "STRING"}
-                ]
-            },
-            {
-                "label": "Planet",
-                "properties": [
-                    {"name": "name", "type": "STRING"},
-                    {"name": "weather", "type": "STRING"}
-                ]
-            }
-        ],
-        "relations": [
-            "PARENT_OF",
-            {
-                "label": "HEIR_OF",
-                "description": "Used for inheritor relationship between father and sons"
-            },
-            {
-                "label": "RULES",
-                "properties": [
-                    {"name": "fromYear", "type": "INTEGER"}
-                ]
-            }
-        ],
-        "potential_schema": [
-            ["Person", "PARENT_OF", "Person"],
-            ["Person", "HEIR_OF", "House"],
-            ["House", "RULES", "Planet"]
-        ],
+
+        "schema": {
+            "entities": [
+                "Person",
+                {
+                    "label": "House",
+                    "description": "Family the person belongs to",
+                    "properties": [
+                        {"name": "name", "type": "STRING"}
+                    ]
+                },
+                {
+                    "label": "Planet",
+                    "properties": [
+                        {"name": "name", "type": "STRING"},
+                        {"name": "weather", "type": "STRING"}
+                    ]
+                }
+            ],
+            "relations": [
+                "PARENT_OF",
+                {
+                    "label": "HEIR_OF",
+                    "description": "Used for inheritor relationship between father and sons"
+                },
+                {
+                    "label": "RULES",
+                    "properties": [
+                        {"name": "fromYear", "type": "INTEGER"}
+                    ]
+                }
+            ],
+            "potential_schema": [
+                ["Person", "PARENT_OF", "Person"],
+                ["Person", "HEIR_OF", "House"],
+                ["House", "RULES", "Planet"]
+            ]
+        },
         "lexical_graph_config": {
             "chunk_node_label": "TextPart"
         }
@@ -462,31 +472,32 @@ or in YAML:
 
     neo4j_database: myDb
     on_error: IGNORE
     prompt_template: ...
-    entities:
-      - label: Person
-      - label: House
-        description: Family the person belongs to
-        properties:
-          - name: name
-            type: STRING
-      - label: Planet
-        properties:
-          - name: name
-            type: STRING
-          - name: weather
-            type: STRING
-    relations:
-      - label: PARENT_OF
-      - label: HEIR_OF
-        description: Used for inheritor relationship between father and sons
-      - label: RULES
-        properties:
-          - name: fromYear
-            type: INTEGER
-    potential_schema:
-      - ["Person", "PARENT_OF", "Person"]
-      - ["Person", "HEIR_OF", "House"]
-      - ["House", "RULES", "Planet"]
+    schema:
+      entities:
+        - Person
+        - label: House
+          description: Family the person belongs to
+          properties:
+            - name: name
+              type: STRING
+        - label: Planet
+          properties:
+            - name: name
+              type: STRING
+            - name: weather
+              type: STRING
+      relations:
+        - PARENT_OF
+        - label: HEIR_OF
+          description: Used for inheritor relationship between father and sons
+        - label: RULES
+          properties:
+            - name: fromYear
+              type: INTEGER
+      potential_schema:
+        - ["Person", "PARENT_OF", "Person"]
+        - ["Person", "HEIR_OF", "House"]
+        - ["House", "RULES", "Planet"]
     lexical_graph_config:
         chunk_node_label: TextPart
 
@@ -791,6 +802,44 @@ Here is a code block illustrating these concepts:
 
 After validation, this schema is saved in a `SchemaConfig` object, whose dict
 representation is passed to the LLM.
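+To make this concrete, here is a minimal sketch of that flow: building a
+`SchemaConfig` directly with `SchemaBuilder` and printing the dict representation
+that is passed to the LLM. The labels below are illustrative placeholders, and
+optional fields (descriptions, properties) are left at their defaults:
+
+.. code:: python
+
+    from neo4j_graphrag.experimental.components.schema import (
+        SchemaBuilder,
+        SchemaEntity,
+        SchemaRelation,
+    )
+
+    # validate the inputs and build the SchemaConfig object
+    schema_config = SchemaBuilder.create_schema_model(
+        entities=[SchemaEntity(label="Person"), SchemaEntity(label="Company")],
+        relations=[SchemaRelation(label="WORKS_FOR")],
+        potential_schema=[("Person", "WORKS_FOR", "Company")],
+    )
+
+    # the dict representation forwarded to the entity/relation extraction prompt
+    print(schema_config.model_dump())
+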
+Automatic Schema Extraction
+---------------------------
+
+Instead of manually defining the schema, you can use the `SchemaFromTextExtractor` component to automatically extract a schema from your text using an LLM:
+
+.. code:: python
+
+    from neo4j_graphrag.experimental.components.schema import SchemaFromTextExtractor
+    from neo4j_graphrag.llm import OpenAILLM
+
+    # Instantiate the automatic schema extractor component
+    schema_extractor = SchemaFromTextExtractor(
+        llm=OpenAILLM(
+            model_name="gpt-4o",
+            model_params={
+                "max_tokens": 2000,
+                "response_format": {"type": "json_object"},
+            },
+        )
+    )
+
+    # Extract the schema from the text
+    extracted_schema = await schema_extractor.run(text="Some text")
+
+The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `SchemaConfig` object that can be used in the same way as a manually defined schema.
+
+You can also save and reload the extracted schema:
+
+.. code:: python
+
+    # Save the extracted schema to JSON or YAML files
+    extracted_schema.store_as_json("my_schema.json")
+    extracted_schema.store_as_yaml("my_schema.yaml")
+
+    # Later, reload the schema from file
+    from neo4j_graphrag.experimental.components.schema import SchemaConfig
+    restored_schema = SchemaConfig.from_file("my_schema.json")  # or my_schema.yaml
+
 Entity and Relation Extractor
 =============================
 
@@ -832,6 +881,8 @@ The LLM to use can be customized, the only constraint is that it obeys the :ref:
+.. _schema-enforcement-behaviour:
+
 Schema Enforcement Behaviour
 ----------------------------
 
 By default, even if a schema is provided to guide the LLM in the entity and relation extraction,
 the LLM response is not validated against that schema.
 This behaviour can be changed by using the `enforce_schema` flag in the `LLMEntityRelationExtractor` constructor:
diff --git a/examples/README.md b/examples/README.md
index 7feb71f3a..fa8bb945e 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -3,6 +3,7 @@
 This folder contains examples usage for the different features
 supported by the `neo4j-graphrag` package:
 
+- [Automatic Schema Extraction](#schema-extraction) from PDF or text
 - [Build Knowledge Graph](#build-knowledge-graph) from PDF or text
 - [Retrieve](#retrieve) information from the graph
 - [Question Answering](#answer-graphrag) (Q&A)
@@ -122,6 +123,7 @@ are listed in [the last section of this file](#customize).
   - [Chunk embedder]()
 - Schema Builder:
   - [User-defined](./customize/build_graph/components/schema_builders/schema.py)
+  - [Automatic schema extraction](./customize/build_graph/components/schema_builders/schema_from_text.py)
 - Entity Relation Extractor:
   - [LLM-based](./customize/build_graph/components/extractors/llm_entity_relation_extractor.py)
   - [LLM-based with custom prompt](./customize/build_graph/components/extractors/llm_entity_relation_extractor_with_custom_prompt.py)
diff --git a/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_pdf.py b/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_pdf.py
new file mode 100644
index 000000000..639f0b93d
--- /dev/null
+++ b/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_pdf.py
@@ -0,0 +1,97 @@
+"""This example demonstrates how to use SimpleKGPipeline with automatic schema extraction
+from a PDF file. When no schema is provided to SimpleKGPipeline, automatic schema extraction
+is performed using the LLM.
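+The extracted schema is then used to ground the entity and relation extraction
+step for the rest of the pipeline run.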
+ +Note: This example requires an OpenAI API key to be set in the .env file. +""" + +import asyncio +import logging +import os +from pathlib import Path +from dotenv import load_dotenv +import neo4j + +from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.embeddings import OpenAIEmbeddings + +# Load environment variables from .env file +load_dotenv() + +# Configure logging +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.INFO) + +# PDF file path +root_dir = Path(__file__).parents[2] +PDF_FILE = str( + root_dir / "data" / "Harry Potter and the Chamber of Secrets Summary.pdf" +) + + +async def run_kg_pipeline_with_auto_schema() -> None: + """Run the SimpleKGPipeline with automatic schema extraction from a PDF file.""" + + # Define Neo4j connection + uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687") + user = os.getenv("NEO4J_USER", "neo4j") + password = os.getenv("NEO4J_PASSWORD", "password") + + # Define LLM parameters + llm_model_params = { + "max_tokens": 2000, + "response_format": {"type": "json_object"}, + "temperature": 0, # Lower temperature for more consistent output + } + + # Initialize the Neo4j driver + driver = neo4j.GraphDatabase.driver(uri, auth=(user, password)) + + # Create the LLM instance + llm = OpenAILLM( + model_name="gpt-4o", + model_params=llm_model_params, + ) + + # Create the embedder instance + embedder = OpenAIEmbeddings() + + try: + # Create a SimpleKGPipeline instance without providing a schema + # This will trigger automatic schema extraction + kg_builder = SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + from_pdf=True, + ) + + print(f"Processing PDF file: {PDF_FILE}") + # Run the pipeline on the PDF file + await kg_builder.run_async(file_path=PDF_FILE) + + finally: + # Close connections + await llm.async_client.close() + driver.close() + + +async def main() -> None: + """Run the example.""" + # Create data directory if it doesn't exist + data_dir = root_dir / "data" + data_dir.mkdir(exist_ok=True) + + # Check if the PDF file exists + if not Path(PDF_FILE).exists(): + print(f"Warning: PDF file not found at {PDF_FILE}") + print("Please replace with a valid PDF file path.") + return + + # Run the pipeline + await run_kg_pipeline_with_auto_schema() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_text.py b/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_text.py new file mode 100644 index 000000000..75b076306 --- /dev/null +++ b/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_text.py @@ -0,0 +1,102 @@ +"""This example demonstrates how to use SimpleKGPipeline with automatic schema extraction +from a text input. When no schema is provided to SimpleKGPipeline, automatic schema extraction +is performed using the LLM. + +Note: This example requires an OpenAI API key to be set in the .env file. 
+""" + +import asyncio +import logging +import os +from dotenv import load_dotenv +import neo4j + +from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.embeddings import OpenAIEmbeddings + +# Load environment variables from .env file +load_dotenv() + +# Configure logging +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG) + +# Sample text to build a knowledge graph from +TEXT = """ +Acme Corporation was founded in 1985 by John Smith in New York City. +The company specializes in manufacturing high-quality widgets and gadgets +for the consumer electronics industry. + +Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to +Engineering Director in 2015. She oversees a team of 12 engineers working on +next-generation products. Sarah holds a PhD in Electrical Engineering from MIT +and has filed 5 patents during her time at Acme. + +The company expanded to international markets in 2012, opening offices in London, +Tokyo, and Berlin. Each office is managed by a regional director who reports +directly to the CEO, Michael Brown, who took over leadership in 2008. + +Acme's most successful product, the SuperWidget X1, was launched in 2018 and +has sold over 2 million units worldwide. The product was developed by a team led +by Robert Chen, who joined the company in 2016 after working at TechGiant for 8 years. + +The company currently employs 250 people across its 4 locations and had a revenue +of $75 million in the last fiscal year. Acme is planning to go public in 2024 +with an estimated valuation of $500 million. +""" + + +async def run_kg_pipeline_with_auto_schema() -> None: + """Run the SimpleKGPipeline with automatic schema extraction from text input.""" + + # Define Neo4j connection + uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687") + user = os.getenv("NEO4J_USER", "neo4j") + password = os.getenv("NEO4J_PASSWORD", "password") + + # Define LLM parameters + llm_model_params = { + "max_tokens": 2000, + "response_format": {"type": "json_object"}, + "temperature": 0, # Lower temperature for more consistent output + } + + # Initialize the Neo4j driver + driver = neo4j.GraphDatabase.driver(uri, auth=(user, password)) + + # Create the LLM instance + llm = OpenAILLM( + model_name="gpt-4o", + model_params=llm_model_params, + ) + + # Create the embedder instance + embedder = OpenAIEmbeddings() + + try: + # Create a SimpleKGPipeline instance without providing a schema + # This will trigger automatic schema extraction + kg_builder = SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + from_pdf=False, # Using raw text input, not PDF + ) + + # Run the pipeline on the text + await kg_builder.run_async(text=TEXT) + + finally: + # Close connections + await llm.async_client.close() + driver.close() + + +async def main() -> None: + """Run the example.""" + await run_kg_pipeline_with_auto_schema() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/customize/build_graph/components/schema_builders/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py new file mode 100644 index 000000000..4396f3fb6 --- /dev/null +++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py @@ -0,0 +1,130 @@ +"""This example demonstrates how to use the SchemaFromTextExtractor component +to automatically extract a schema from text and save it to JSON and YAML files. 
+ +The SchemaFromTextExtractor component uses an LLM to analyze the text and identify entities, +relations, and their properties. + +Note: This example requires an OpenAI API key to be set in the .env file. +""" + +import asyncio +import logging +from pathlib import Path +from dotenv import load_dotenv + +from neo4j_graphrag.experimental.components.schema import ( + SchemaFromTextExtractor, + SchemaConfig, +) +from neo4j_graphrag.llm import OpenAILLM + +# Load environment variables from .env file +load_dotenv() + +# Configure logging +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.INFO) + +# Sample text to extract schema from - it's about a company and its employees +TEXT = """ +Acme Corporation was founded in 1985 by John Smith in New York City. +The company specializes in manufacturing high-quality widgets and gadgets +for the consumer electronics industry. + +Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to +Engineering Director in 2015. She oversees a team of 12 engineers working on +next-generation products. Sarah holds a PhD in Electrical Engineering from MIT +and has filed 5 patents during her time at Acme. + +The company expanded to international markets in 2012, opening offices in London, +Tokyo, and Berlin. Each office is managed by a regional director who reports +directly to the CEO, Michael Brown, who took over leadership in 2008. + +Acme's most successful product, the SuperWidget X1, was launched in 2018 and +has sold over 2 million units worldwide. The product was developed by a team led +by Robert Chen, who joined the company in 2016 after working at TechGiant for 8 years. + +The company currently employs 250 people across its 4 locations and had a revenue +of $75 million in the last fiscal year. Acme is planning to go public in 2024 +with an estimated valuation of $500 million. 
+""" + +# Define the file paths for saving the schema +root_dir = Path(__file__).parents[4] +OUTPUT_DIR = str(root_dir / "data") +JSON_FILE_PATH = str(root_dir / "data" / "extracted_schema.json") +YAML_FILE_PATH = str(root_dir / "data" / "extracted_schema.yaml") + + +async def extract_and_save_schema() -> None: + """Extract schema from text and save it to JSON and YAML files.""" + + # Define LLM parameters + llm_model_params = { + "max_tokens": 2000, + "response_format": {"type": "json_object"}, + "temperature": 0, # Lower temperature for more consistent output + } + + # Create the LLM instance + llm = OpenAILLM( + model_name="gpt-4o", + model_params=llm_model_params, + ) + + try: + # Create a SchemaFromTextExtractor component with the default template + schema_extractor = SchemaFromTextExtractor(llm=llm) + + print("Extracting schema from text...") + # Extract schema from text + inferred_schema = await schema_extractor.run(text=TEXT) + + # Ensure the output directory exists + Path(OUTPUT_DIR).mkdir(exist_ok=True) + + print(f"Saving schema to JSON file: {JSON_FILE_PATH}") + # Save the schema to JSON file + inferred_schema.store_as_json(JSON_FILE_PATH) + + print(f"Saving schema to YAML file: {YAML_FILE_PATH}") + # Save the schema to YAML file + inferred_schema.store_as_yaml(YAML_FILE_PATH) + + print("\nExtracted Schema Summary:") + print(f"Entities: {list(inferred_schema.entities.keys())}") + print( + f"Relations: {list(inferred_schema.relations.keys() if inferred_schema.relations else [])}" + ) + + if inferred_schema.potential_schema: + print("\nPotential Schema:") + for entity1, relation, entity2 in inferred_schema.potential_schema: + print(f" {entity1} --[{relation}]--> {entity2}") + + finally: + # Close the LLM client + await llm.async_client.close() + + +async def main() -> None: + """Run the example.""" + + # extract schema and save to files + await extract_and_save_schema() + + print("\nSchema files have been saved to:") + print(f" - JSON: {JSON_FILE_PATH}") + print(f" - YAML: {YAML_FILE_PATH}") + + # load schema from files + print("\nLoading schemas from saved files:") + schema_from_json = SchemaConfig.from_file(JSON_FILE_PATH) + schema_from_yaml = SchemaConfig.from_file(YAML_FILE_PATH) + + print(f"Entities in JSON schema: {list(schema_from_json.entities.keys())}") + print(f"Entities in YAML schema: {list(schema_from_yaml.entities.keys())}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/data/extracted_schema.json b/examples/data/extracted_schema.json new file mode 100644 index 000000000..0ec66639b --- /dev/null +++ b/examples/data/extracted_schema.json @@ -0,0 +1,137 @@ +{ + "entities": { + "Company": { + "label": "Company", + "description": "", + "properties": [ + { + "name": "name", + "type": "STRING", + "description": "" + }, + { + "name": "foundedYear", + "type": "INTEGER", + "description": "" + }, + { + "name": "revenue", + "type": "FLOAT", + "description": "" + }, + { + "name": "valuation", + "type": "FLOAT", + "description": "" + } + ] + }, + "Person": { + "label": "Person", + "description": "", + "properties": [ + { + "name": "name", + "type": "STRING", + "description": "" + }, + { + "name": "position", + "type": "STRING", + "description": "" + }, + { + "name": "yearJoined", + "type": "INTEGER", + "description": "" + } + ] + }, + "Product": { + "label": "Product", + "description": "", + "properties": [ + { + "name": "name", + "type": "STRING", + "description": "" + }, + { + "name": "launchYear", + "type": "INTEGER", + "description": "" + }, + { + 
"name": "unitsSold", + "type": "INTEGER", + "description": "" + } + ] + }, + "Office": { + "label": "Office", + "description": "", + "properties": [ + { + "name": "location", + "type": "STRING", + "description": "" + } + ] + } + }, + "relations": { + "FOUNDED_BY": { + "label": "FOUNDED_BY", + "description": "", + "properties": [] + }, + "WORKS_FOR": { + "label": "WORKS_FOR", + "description": "", + "properties": [] + }, + "MANAGES": { + "label": "MANAGES", + "description": "", + "properties": [] + }, + "DEVELOPED_BY": { + "label": "DEVELOPED_BY", + "description": "", + "properties": [] + }, + "LOCATED_IN": { + "label": "LOCATED_IN", + "description": "", + "properties": [] + } + }, + "potential_schema": [ + [ + "Company", + "FOUNDED_BY", + "Person" + ], + [ + "Person", + "WORKS_FOR", + "Company" + ], + [ + "Person", + "MANAGES", + "Office" + ], + [ + "Product", + "DEVELOPED_BY", + "Person" + ], + [ + "Company", + "LOCATED_IN", + "Office" + ] + ] +} \ No newline at end of file diff --git a/examples/data/extracted_schema.yaml b/examples/data/extracted_schema.yaml new file mode 100644 index 000000000..f2500799f --- /dev/null +++ b/examples/data/extracted_schema.yaml @@ -0,0 +1,87 @@ +entities: + Company: + label: Company + description: '' + properties: + - name: name + type: STRING + description: '' + - name: foundedYear + type: INTEGER + description: '' + - name: revenue + type: FLOAT + description: '' + - name: valuation + type: FLOAT + description: '' + Person: + label: Person + description: '' + properties: + - name: name + type: STRING + description: '' + - name: position + type: STRING + description: '' + - name: yearJoined + type: INTEGER + description: '' + Product: + label: Product + description: '' + properties: + - name: name + type: STRING + description: '' + - name: launchYear + type: INTEGER + description: '' + - name: unitsSold + type: INTEGER + description: '' + Office: + label: Office + description: '' + properties: + - name: location + type: STRING + description: '' +relations: + FOUNDED_BY: + label: FOUNDED_BY + description: '' + properties: [] + WORKS_FOR: + label: WORKS_FOR + description: '' + properties: [] + MANAGES: + label: MANAGES + description: '' + properties: [] + DEVELOPED_BY: + label: DEVELOPED_BY + description: '' + properties: [] + LOCATED_IN: + label: LOCATED_IN + description: '' + properties: [] +potential_schema: +- - Company + - FOUNDED_BY + - Person +- - Person + - WORKS_FOR + - Company +- - Person + - MANAGES + - Office +- - Product + - DEVELOPED_BY + - Person +- - Company + - LOCATED_IN + - Office diff --git a/src/neo4j_graphrag/exceptions.py b/src/neo4j_graphrag/exceptions.py index 3c0fdc0b3..681b20eec 100644 --- a/src/neo4j_graphrag/exceptions.py +++ b/src/neo4j_graphrag/exceptions.py @@ -116,6 +116,12 @@ class SchemaValidationError(Neo4jGraphRagError): pass +class SchemaExtractionError(Neo4jGraphRagError): + """Exception raised for errors in automatic schema extraction.""" + + pass + + class PdfLoaderError(Neo4jGraphRagError): """Custom exception for errors in PDF loader.""" diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 96d7d466b..a58b0b105 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -14,17 +14,27 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Dict, List, Literal, Optional, Tuple +import json +import yaml +import logging +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from pathlib import Path from pydantic import BaseModel, ValidationError, model_validator, validate_call from typing_extensions import Self -from neo4j_graphrag.exceptions import SchemaValidationError +from neo4j_graphrag.exceptions import ( + SchemaValidationError, + LLMGenerationError, + SchemaExtractionError, +) from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.schema import ( EntityInputType, RelationInputType, ) +from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate +from neo4j_graphrag.llm import LLMInterface class SchemaProperty(BaseModel): @@ -123,6 +133,98 @@ def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]: return data + def store_as_json(self, file_path: str) -> None: + """ + Save the schema configuration to a JSON file. + + Args: + file_path (str): The path where the schema configuration will be saved. + """ + with open(file_path, "w") as f: + json.dump(self.model_dump(), f, indent=2) + + def store_as_yaml(self, file_path: str) -> None: + """ + Save the schema configuration to a YAML file. + + Args: + file_path (str): The path where the schema configuration will be saved. + """ + # create a copy of the data and convert tuples to lists for YAML compatibility + data = self.model_dump() + if data.get("potential_schema"): + data["potential_schema"] = [list(item) for item in data["potential_schema"]] + + with open(file_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + + @classmethod + def from_file(cls, file_path: Union[str, Path]) -> Self: + """ + Load a schema configuration from a file (either JSON or YAML). + + The file format is automatically detected based on the file extension. + + Args: + file_path (Union[str, Path]): The path to the schema configuration file. + + Returns: + SchemaConfig: The loaded schema configuration. + """ + file_path = Path(file_path) + + if not file_path.exists(): + raise FileNotFoundError(f"Schema file not found: {file_path}") + + if file_path.suffix.lower() in [".json"]: + return cls.from_json(file_path) + elif file_path.suffix.lower() in [".yaml", ".yml"]: + return cls.from_yaml(file_path) + else: + raise ValueError( + f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml" + ) + + @classmethod + def from_json(cls, file_path: Union[str, Path]) -> Self: + """ + Load a schema configuration from a JSON file. + + Args: + file_path (Union[str, Path]): The path to the JSON schema configuration file. + + Returns: + SchemaConfig: The loaded schema configuration. + """ + with open(file_path, "r") as f: + try: + data = json.load(f) + return cls.model_validate(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON file: {e}") + except ValidationError as e: + raise SchemaValidationError(f"Schema validation failed: {e}") + + @classmethod + def from_yaml(cls, file_path: Union[str, Path]) -> Self: + """ + Load a schema configuration from a YAML file. + + Args: + file_path (Union[str, Path]): The path to the YAML schema configuration file. + + Returns: + SchemaConfig: The loaded schema configuration. 
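+
+        Raises:
+            ValueError: If the file is not valid YAML.
+            SchemaValidationError: If the loaded data fails schema validation.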
+ """ + with open(file_path, "r") as f: + try: + data = yaml.safe_load(f) + return cls.model_validate(data) + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML file: {e}") + except ValidationError as e: + raise SchemaValidationError(f"Schema validation failed: {e}") + class SchemaBuilder(Component): """ @@ -236,3 +338,128 @@ async def run( SchemaConfig: A configured schema object, constructed asynchronously. """ return self.create_schema_model(entities, relations, potential_schema) + + +class SchemaFromTextExtractor(Component): + """ + A component for constructing SchemaConfig objects from the output of an LLM after + automatic schema extraction from text. + """ + + def __init__( + self, + llm: LLMInterface, + prompt_template: Optional[PromptTemplate] = None, + llm_params: Optional[Dict[str, Any]] = None, + ) -> None: + self._llm: LLMInterface = llm + self._prompt_template: PromptTemplate = ( + prompt_template or SchemaExtractionTemplate() + ) + self._llm_params: dict[str, Any] = llm_params or {} + + @validate_call + async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfig: + """ + Asynchronously extracts the schema from text and returns a SchemaConfig object. + + Args: + text (str): the text from which the schema will be inferred. + examples (str): examples to guide schema extraction. + Returns: + SchemaConfig: A configured schema object, extracted automatically and + constructed asynchronously. + """ + prompt: str = self._prompt_template.format(text=text, examples=examples) + + try: + response = await self._llm.ainvoke(prompt, **self._llm_params) + content: str = response.content + except LLMGenerationError as e: + # Re-raise the LLMGenerationError + raise LLMGenerationError("Failed to generate schema from text") from e + + try: + extracted_schema: Dict[str, Any] = json.loads(content) + + # handle dictionary + if isinstance(extracted_schema, dict): + pass # Keep as is + # handle list + elif isinstance(extracted_schema, list): + if len(extracted_schema) > 0 and isinstance(extracted_schema[0], dict): + extracted_schema = extracted_schema[0] + elif len(extracted_schema) == 0: + logging.warning( + "LLM returned an empty list for schema. Falling back to empty schema." + ) + extracted_schema = {} + else: + raise SchemaExtractionError( + f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}" + ) + # any other types + else: + raise SchemaExtractionError( + f"Unexpected schema format returned from LLM: {type(extracted_schema)}. Expected a dictionary or list of dictionaries." 
+ ) + except json.JSONDecodeError as exc: + raise SchemaExtractionError("LLM response is not valid JSON.") from exc + + extracted_entities: List[Dict[str, Any]] = ( + extracted_schema.get("entities") or [] + ) + extracted_relations: Optional[List[Dict[str, Any]]] = extracted_schema.get( + "relations" + ) + potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( + "potential_schema" + ) + + try: + entities: List[SchemaEntity] = [ + SchemaEntity(**e) for e in extracted_entities + ] + relations: Optional[List[SchemaRelation]] = ( + [SchemaRelation(**r) for r in extracted_relations] + if extracted_relations + else None + ) + except ValidationError as exc: + raise SchemaValidationError( + f"Invalid schema format return from LLM: {exc}" + ) from exc + + return SchemaBuilder.create_schema_model( + entities=entities, + relations=relations, + potential_schema=potential_schema, + ) + + +def normalize_schema_dict(schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Normalize a user-provided schema dictionary to the canonical format expected by the pipeline. + + - Converts 'entities' and 'relations' from lists (of strings, dicts, or model objects) to dicts keyed by label. + - Ensures required keys ('entities', 'relations', 'potential_schema') are present. + - Does not mutate the input; returns a new dict. + + Args: + schema_dict (dict): The user-provided schema dictionary, possibly with lists or missing keys. + + Returns: + dict: A normalized schema dictionary with the correct structure for pipeline and Pydantic validation. + """ + norm_schema_dict = dict(schema_dict) + for key, cls in [("entities", SchemaEntity), ("relations", SchemaRelation)]: + if key in norm_schema_dict and isinstance(norm_schema_dict[key], list): + norm_schema_dict[key] = { + cls.from_text_or_dict(e).label: cls.from_text_or_dict(e).model_dump() # type: ignore[attr-defined] + for e in norm_schema_dict[key] + } + if "relations" not in norm_schema_dict: + norm_schema_dict["relations"] = {} + if "potential_schema" not in norm_schema_dict: + norm_schema_dict["potential_schema"] = None + return norm_schema_dict diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 306c4eb32..0df0f61e6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -12,9 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, ClassVar, Literal, Optional, Sequence, Union +from typing import ( + Any, + ClassVar, + Literal, + Optional, + Sequence, + Union, + List, + Tuple, + Dict, + cast, +) +import logging +import warnings -from pydantic import ConfigDict +from pydantic import ConfigDict, Field, model_validator +from typing_extensions import Self from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -30,8 +44,11 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, + SchemaConfig, SchemaEntity, SchemaRelation, + SchemaFromTextExtractor, + normalize_schema_dict, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -54,6 +71,8 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate +logger = logging.getLogger(__name__) + class SimpleKGPipelineConfig(TemplatePipelineConfig): COMPONENTS: ClassVar[list[str]] = [ @@ -74,6 +93,9 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None + schema_: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = Field( + default=None, alias="schema" + ) enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() @@ -87,6 +109,57 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): model_config = ConfigDict(arbitrary_types_allowed=True) + @model_validator(mode="before") + def normalize_schema_field(cls, data: Dict[str, Any]) -> Dict[str, Any]: + # Normalize the 'schema' field if it is a dict + schema = data.get("schema") + if isinstance(schema, dict): + data["schema"] = normalize_schema_dict(schema) + return data + + @model_validator(mode="after") + def handle_schema_precedence(self) -> Self: + """Handle schema precedence and warnings""" + self._process_schema_parameters() + return self + + def _process_schema_parameters(self) -> None: + """ + Process schema parameters and handle precedence between 'schema' parameter and individual components. + Also logs warnings for deprecated usage. + """ + # check if both schema and individual components are provided + has_individual_schema_components = any( + [self.entities, self.relations, self.potential_schema] + ) + + if has_individual_schema_components and self.schema_ is not None: + warnings.warn( + "Both 'schema' and individual schema components (entities, relations, potential_schema) " + "were provided. The 'schema' parameter takes precedence. In the future, individual " + "components will be removed. Please use only the 'schema' parameter.", + DeprecationWarning, + stacklevel=2, + ) + + elif has_individual_schema_components: + warnings.warn( + "The 'entities', 'relations', and 'potential_schema' parameters are deprecated " + "and will be removed in a future version. 
" + "Please use the 'schema' parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + + def has_user_provided_schema(self) -> bool: + """Check if the user has provided schema information""" + return bool( + self.entities + or self.relations + or self.potential_schema + or self.schema_ is not None + ) + def _get_pdf_loader(self) -> Optional[PdfLoader]: if not self.from_pdf: return None @@ -114,15 +187,92 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]: def _get_chunk_embedder(self) -> TextChunkEmbedder: return TextChunkEmbedder(embedder=self.get_default_embedder()) - def _get_schema(self) -> SchemaBuilder: + def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: + """ + Get the appropriate schema component based on configuration. + Return SchemaFromTextExtractor for automatic extraction or SchemaBuilder for manual schema. + """ + if not self.has_user_provided_schema(): + return SchemaFromTextExtractor(llm=self.get_default_llm()) return SchemaBuilder() + def _process_schema_with_precedence( + self, + ) -> Tuple[ + List[SchemaEntity], List[SchemaRelation], Optional[List[Tuple[str, str, str]]] + ]: + """ + Process schema inputs according to precedence rules: + 1. If schema is provided as SchemaConfig object, use it + 2. If schema is provided as dictionary, extract from it + 3. Otherwise, use individual schema components + + Returns: + Tuple of (entities, relations, potential_schema) + """ + if self.schema_ is not None: + # schema takes precedence over individual components + if isinstance(self.schema_, SchemaConfig): + # extract components from SchemaConfig + entity_dicts = list(self.schema_.entities.values()) + # convert dict values to SchemaEntity objects + entities = [SchemaEntity.model_validate(e) for e in entity_dicts] + + # handle case where relations could be None + if self.schema_.relations is not None: + relation_dicts = list(self.schema_.relations.values()) + relations = [ + SchemaRelation.model_validate(r) for r in relation_dicts + ] + else: + relations = [] + + potential_schema = self.schema_.potential_schema + else: + entities = [ + SchemaEntity.from_text_or_dict(e) + for e in cast( + Dict[str, Any], self.schema_.get("entities", {}) + ).values() + ] + relations = [ + SchemaRelation.from_text_or_dict(r) + for r in cast( + Dict[str, Any], self.schema_.get("relations", {}) + ).values() + ] + potential_schema = self.schema_.get("potential_schema") + else: + # use individual components + entities = ( + [SchemaEntity.from_text_or_dict(e) for e in self.entities] + if self.entities + else [] + ) + relations = ( + [SchemaRelation.from_text_or_dict(r) for r in self.relations] + if self.relations + else [] + ) + potential_schema = self.potential_schema + + return entities, relations, potential_schema + def _get_run_params_for_schema(self) -> dict[str, Any]: - return { - "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], - "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], - "potential_schema": self.potential_schema, - } + if not self.has_user_provided_schema(): + # for automatic extraction, the text parameter is needed (will flow through the pipeline connections) + return {} + else: + # process schema components according to precedence rules + entities, relations, potential_schema = ( + self._process_schema_with_precedence() + ) + + return { + "entities": entities, + "relations": relations, + "potential_schema": potential_schema, + } def _get_extractor(self) -> EntityRelationExtractor: return 
LLMEntityRelationExtractor( @@ -163,6 +313,17 @@ def _get_connections(self) -> list[ConnectionDefinition]: input_config={"text": "pdf_loader.text"}, ) ) + + # handle automatic schema extraction + if not self.has_user_provided_schema(): + connections.append( + ConnectionDefinition( + start="pdf_loader", + end="schema", + input_config={"text": "pdf_loader.text"}, + ) + ) + connections.append( ConnectionDefinition( start="schema", @@ -178,9 +339,7 @@ def _get_connections(self) -> list[ConnectionDefinition]: ConnectionDefinition( start="schema", end="extractor", - input_config={ - "schema": "schema", - }, + input_config={"schema": "schema"}, ) ) connections.append( @@ -247,4 +406,7 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: "Expected 'text' argument when 'from_pdf' is False." ) run_params["splitter"] = {"text": text} + # Add full text to schema component for automatic schema extraction + if not self.has_user_provided_schema(): + run_params["schema"] = {"text": text} return run_params diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index ec5ed5218..c586a7fad 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -15,7 +15,8 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, Any +import logging import neo4j from pydantic import ValidationError @@ -42,6 +43,9 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.experimental.components.schema import SchemaConfig + +logger = logging.getLogger(__name__) class SimpleKGPipeline: @@ -53,17 +57,20 @@ class SimpleKGPipeline: llm (LLMInterface): An instance of an LLM to use for entity and relation extraction. driver (neo4j.Driver): A Neo4j driver instance for database connection. embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks. - entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): A list of either: + schema (Optional[Union[SchemaConfig, dict[str, list]]]): A schema configuration defining entities, + relations, and potential schema relationships. + This is the recommended way to provide schema information. + entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): DEPRECATED. A list of either: - str: entity labels - dict: following the SchemaEntity schema, ie with label, description and properties keys - relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): A list of either: + relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): DEPRECATED. A list of either: - str: relation label - dict: following the SchemaRelation schema, ie with label, description and properties keys - potential_schema (Optional[List[tuple]]): A list of potential schema relationships. + potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships. enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "None" or "STRICT". from_pdf (bool): Determines whether to include the PdfLoader in the pipeline. If True, expects `file_path` input in `run` methods. 
@@ -85,6 +92,7 @@ def __init__( entities: Optional[Sequence[EntityInputType]] = None, relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, + schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None, enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, @@ -105,6 +113,7 @@ def __init__( entities=entities or [], relations=relations or [], potential_schema=potential_schema, + schema=schema, enforce_schema=SchemaEnforcementMode(enforce_schema), from_pdf=from_pdf, pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, diff --git a/src/neo4j_graphrag/generation/__init__.py b/src/neo4j_graphrag/generation/__init__.py index ff3f69f3a..816fe327a 100644 --- a/src/neo4j_graphrag/generation/__init__.py +++ b/src/neo4j_graphrag/generation/__init__.py @@ -1,8 +1,4 @@ from .graphrag import GraphRAG -from .prompts import PromptTemplate, RagTemplate +from .prompts import PromptTemplate, RagTemplate, SchemaExtractionTemplate -__all__ = [ - "GraphRAG", - "PromptTemplate", - "RagTemplate", -] +__all__ = ["GraphRAG", "PromptTemplate", "RagTemplate", "SchemaExtractionTemplate"] diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index ade302720..96bcaf8de 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -200,3 +200,65 @@ def format( text: str = "", ) -> str: return super().format(text=text, schema=schema, examples=examples) + + +class SchemaExtractionTemplate(PromptTemplate): + DEFAULT_TEMPLATE = """ +You are a top-tier algorithm designed for extracting a labeled property graph schema in +structured formats. + +Generate a generalized graph schema based on the input text. Identify key entity types, +their relationship types, and property types. + +IMPORTANT RULES: +1. Return only abstract schema information, not concrete instances. +2. Use singular PascalCase labels for entity types (e.g., Person, Company, Product). +3. Use UPPER_SNAKE_CASE for relationship types (e.g., WORKS_FOR, MANAGES). +4. Include property definitions only when the type can be confidently inferred, otherwise omit them. +5. When defining potential_schema, ensure that every entity and relation mentioned exists in your entities and relations lists. +6. Do not create entity types that aren't clearly mentioned in the text. +7. Keep your schema minimal and focused on clearly identifiable patterns in the text. + +Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, +LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME. + +Return a valid JSON object that follows this precise structure: +{{ + "entities": [ + {{ + "label": "Person", + "properties": [ + {{ + "name": "name", + "type": "STRING" + }} + ] + }}, + ... + ], + "relations": [ + {{ + "label": "WORKS_FOR" + }}, + ... + ], + "potential_schema": [ + ["Person", "WORKS_FOR", "Company"], + ... 
+ ] +}} + +Examples: +{examples} + +Input text: +{text} +""" + EXPECTED_INPUTS = ["text"] + + def format( + self, + text: str = "", + examples: str = "", + ) -> str: + return super().format(text=text, examples=examples) diff --git a/tests/e2e/experimental/test_simplekgpipeline_e2e.py b/tests/e2e/experimental/test_simplekgpipeline_e2e.py index d30ec3a66..3bd72dd41 100644 --- a/tests/e2e/experimental/test_simplekgpipeline_e2e.py +++ b/tests/e2e/experimental/test_simplekgpipeline_e2e.py @@ -181,6 +181,8 @@ async def test_pipeline_builder_two_documents( driver=driver, embedder=embedder, from_pdf=False, + # provide minimal schema to bypass automatic schema extraction + entities=["Person"], # in order to have 2 chunks: text_splitter=FixedSizeSplitter(chunk_size=400, chunk_overlap=5), ) @@ -261,6 +263,8 @@ async def test_pipeline_builder_same_document_two_runs( driver=driver, embedder=embedder, from_pdf=False, + # provide minimal schema to bypass automatic schema extraction + entities=["Person"], # in order to have 2 chunks: text_splitter=FixedSizeSplitter(chunk_size=400, chunk_overlap=5), ) @@ -280,3 +284,120 @@ async def test_pipeline_builder_same_document_two_runs( "MATCH (chunk:Chunk)<-[rel:FROM_CHUNK]-(entity:__Entity__) RETURN chunk, rel, entity" ) assert len(records) == 2 # two entities according to mocked LLMResponse + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_neo4j_for_kg_construction") +async def test_pipeline_builder_with_automatic_schema_extraction( + harry_potter_text_part1: str, + llm: MagicMock, + embedder: MagicMock, + driver: neo4j.Driver, +) -> None: + """Test pipeline with automatic schema extraction (no schema provided). + This test verifies that the pipeline correctly handles automatic schema extraction. + """ + driver.execute_query("MATCH (n) DETACH DELETE n") + embedder.embed_query.return_value = [1, 2, 3] + + # set up mock LLM responses for both schema extraction and entity extraction + llm.ainvoke.side_effect = [ + # first call - schema extraction response + LLMResponse( + content="""{ + "entities": [ + { + "label": "Person", + "description": "A character in the story", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "INTEGER"} + ] + }, + { + "label": "Location", + "description": "A place in the story", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relations": [ + { + "label": "LOCATED_AT", + "description": "Indicates where a person is located", + "properties": [] + } + ], + "potential_schema": [ + ["Person", "LOCATED_AT", "Location"] + ] + }""" + ), + # second call - entity extraction for first chunk + LLMResponse( + content="""{ + "nodes": [ + { + "id": "0", + "label": "Person", + "properties": { + "name": "Harry Potter" + } + }, + { + "id": "1", + "label": "Location", + "properties": { + "name": "Hogwarts" + } + } + ], + "relationships": [ + { + "type": "LOCATED_AT", + "start_node_id": "0", + "end_node_id": "1" + } + ] + }""" + ), + # third call - entity extraction for second chunk (if text is split) + LLMResponse(content='{"nodes": [], "relationships": []}'), + ] + + # create an instance of the SimpleKGPipeline with NO schema provided + kg_builder_text = SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + from_pdf=False, + # use smaller chunk size to ensure we have at least 2 chunks + text_splitter=FixedSizeSplitter(chunk_size=400, chunk_overlap=5), + ) + + # run the knowledge graph building process with text input + await kg_builder_text.run_async(text=harry_potter_text_part1) 
+ + # verify LLM was called for schema extraction + assert llm.ainvoke.call_count >= 2 + + # verify entities were created + records, _, _ = driver.execute_query("MATCH (n:Person) RETURN n") + assert len(records) == 1 + + # verify locations were created + records, _, _ = driver.execute_query("MATCH (n:Location) RETURN n") + assert len(records) == 1 + + # verify relationships were created + records, _, _ = driver.execute_query( + "MATCH (p:Person)-[r:LOCATED_AT]->(l:Location) RETURN p, r, l" + ) + assert len(records) == 1 + + # verify chunks and relationships to entities + records, _, _ = driver.execute_query( + "MATCH (c:Chunk)<-[:FROM_CHUNK]-(e) RETURN c, e" + ) + assert len(records) >= 1 diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 6ff257a12..be7bbd958 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -14,15 +14,26 @@ # limitations under the License. from __future__ import annotations +import json +from unittest.mock import AsyncMock + import pytest -from neo4j_graphrag.exceptions import SchemaValidationError +from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, SchemaEntity, SchemaProperty, SchemaRelation, + SchemaFromTextExtractor, + SchemaConfig, ) from pydantic import ValidationError +import os +import tempfile +import yaml + +from neo4j_graphrag.generation import PromptTemplate +from neo4j_graphrag.llm.types import LLMResponse @pytest.fixture @@ -93,6 +104,18 @@ def schema_builder() -> SchemaBuilder: return SchemaBuilder() +@pytest.fixture +def schema_config( + schema_builder: SchemaBuilder, + valid_entities: list[SchemaEntity], + valid_relations: list[SchemaRelation], + potential_schema: list[tuple[str, str, str]], +) -> SchemaConfig: + return schema_builder.create_schema_model( + valid_entities, valid_relations, potential_schema + ) + + def test_create_schema_model_valid_data( schema_builder: SchemaBuilder, valid_entities: list[SchemaEntity], @@ -419,3 +442,304 @@ def test_create_schema_model_missing_relations( assert "Relations must also be provided when using a potential schema." 
in str(
        exc_info.value
    ), "Should fail due to missing relations"
+
+
+@pytest.fixture
+def mock_llm() -> AsyncMock:
+    mock = AsyncMock()
+    mock.ainvoke = AsyncMock()
+    return mock
+
+
+@pytest.fixture
+def valid_schema_json() -> str:
+    return """
+    {
+        "entities": [
+            {
+                "label": "Person",
+                "properties": [
+                    {"name": "name", "type": "STRING"}
+                ]
+            },
+            {
+                "label": "Organization",
+                "properties": [
+                    {"name": "name", "type": "STRING"}
+                ]
+            }
+        ],
+        "relations": [
+            {
+                "label": "WORKS_FOR",
+                "properties": [
+                    {"name": "since", "type": "DATE"}
+                ]
+            }
+        ],
+        "potential_schema": [
+            ["Person", "WORKS_FOR", "Organization"]
+        ]
+    }
+    """
+
+
+@pytest.fixture
+def invalid_schema_json() -> str:
+    return """
+    {
+        "entities": [
+            {
+                "label": "Person",
+            },
+        ],
+        invalid json content
+    }
+    """
+
+
+@pytest.fixture
+def schema_from_text(mock_llm: AsyncMock) -> SchemaFromTextExtractor:
+    return SchemaFromTextExtractor(llm=mock_llm)
+
+
+@pytest.mark.asyncio
+async def test_schema_from_text_run_valid_response(
+    schema_from_text: SchemaFromTextExtractor,
+    mock_llm: AsyncMock,
+    valid_schema_json: str,
+) -> None:
+    # configure the mock LLM to return a valid schema JSON
+    mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json)
+
+    # run the schema extraction
+    schema_config = await schema_from_text.run(text="Sample text for extraction")
+
+    # verify the LLM was called with a prompt
+    mock_llm.ainvoke.assert_called_once()
+    prompt_arg = mock_llm.ainvoke.call_args[0][0]
+    assert isinstance(prompt_arg, str)
+    assert "Sample text for extraction" in prompt_arg
+
+    # verify the schema was correctly extracted
+    assert len(schema_config.entities) == 2
+    assert "Person" in schema_config.entities
+    assert "Organization" in schema_config.entities
+
+    assert schema_config.relations is not None
+    assert "WORKS_FOR" in schema_config.relations
+
+    assert schema_config.potential_schema is not None
+    assert len(schema_config.potential_schema) == 1
+    assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization")
+
+
+@pytest.mark.asyncio
+async def test_schema_from_text_run_invalid_json(
+    schema_from_text: SchemaFromTextExtractor,
+    mock_llm: AsyncMock,
+    invalid_schema_json: str,
+) -> None:
+    # configure the mock LLM to return invalid JSON
+    mock_llm.ainvoke.return_value = LLMResponse(content=invalid_schema_json)
+
+    # verify that running with invalid JSON raises a SchemaExtractionError
+    with pytest.raises(SchemaExtractionError) as exc_info:
+        await schema_from_text.run(text="Sample text for extraction")
+
+    assert "not valid JSON" in str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_schema_from_text_custom_template(
+    mock_llm: AsyncMock, valid_schema_json: str
+) -> None:
+    # create a custom template
+    custom_prompt = "This is a custom prompt with text: {text}"
+    custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"])
+
+    # create SchemaFromTextExtractor with the custom template
+    schema_from_text = SchemaFromTextExtractor(
+        llm=mock_llm, prompt_template=custom_template
+    )
+
+    # configure mock LLM to return valid JSON and capture the prompt that was sent to it
+    mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json)
+
+    # run the schema extraction
+    await schema_from_text.run(text="Sample text")
+
+    # verify the custom prompt was passed to the LLM
+    prompt_sent_to_llm = mock_llm.ainvoke.call_args[0][0]
+    assert "This is a custom prompt with text" in prompt_sent_to_llm
+
+
+@pytest.mark.asyncio
+async def 
+async def test_schema_from_text_llm_params(
+    mock_llm: AsyncMock, valid_schema_json: str
+) -> None:
+    # configure custom LLM parameters
+    llm_params = {"temperature": 0.1, "max_tokens": 500}
+
+    # create SchemaFromTextExtractor with custom LLM parameters
+    schema_from_text = SchemaFromTextExtractor(llm=mock_llm, llm_params=llm_params)
+
+    # configure the mock LLM to return a valid schema JSON
+    mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json)
+
+    # run the schema extraction
+    await schema_from_text.run(text="Sample text")
+
+    # verify the LLM was called with the custom parameters
+    mock_llm.ainvoke.assert_called_once()
+    call_kwargs = mock_llm.ainvoke.call_args[1]
+    assert call_kwargs["temperature"] == 0.1
+    assert call_kwargs["max_tokens"] == 500
+
+
+@pytest.mark.asyncio
+async def test_schema_config_store_as_json(schema_config: SchemaConfig) -> None:
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # create file path
+        json_path = os.path.join(temp_dir, "schema.json")
+
+        # store the schema config
+        schema_config.store_as_json(json_path)
+
+        # verify the file exists and has content
+        assert os.path.exists(json_path)
+        assert os.path.getsize(json_path) > 0
+
+        # verify the content is valid JSON and contains expected data
+        with open(json_path, "r") as f:
+            data = json.load(f)
+            assert "entities" in data
+            assert "PERSON" in data["entities"]
+            assert "properties" in data["entities"]["PERSON"]
+            assert "description" in data["entities"]["PERSON"]
+            assert (
+                data["entities"]["PERSON"]["description"]
+                == "An individual human being."
+            )
+
+
+@pytest.mark.asyncio
+async def test_schema_config_store_as_yaml(schema_config: SchemaConfig) -> None:
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # create file path
+        yaml_path = os.path.join(temp_dir, "schema.yaml")
+
+        # store the schema config
+        schema_config.store_as_yaml(yaml_path)
+
+        # verify the file exists and has content
+        assert os.path.exists(yaml_path)
+        assert os.path.getsize(yaml_path) > 0
+
+        # verify the content is valid YAML and contains expected data
+        with open(yaml_path, "r") as f:
+            data = yaml.safe_load(f)
+            assert "entities" in data
+            assert "PERSON" in data["entities"]
+            assert "properties" in data["entities"]["PERSON"]
+            assert "description" in data["entities"]["PERSON"]
+            assert (
+                data["entities"]["PERSON"]["description"]
+                == "An individual human being."
+            )
+
+
+@pytest.mark.asyncio
+async def test_schema_config_from_file(schema_config: SchemaConfig) -> None:
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # create file paths with different extensions
+        json_path = os.path.join(temp_dir, "schema.json")
+        yaml_path = os.path.join(temp_dir, "schema.yaml")
+        yml_path = os.path.join(temp_dir, "schema.yml")
+
+        # store the schema config in the different formats
+        schema_config.store_as_json(json_path)
+        schema_config.store_as_yaml(yaml_path)
+        schema_config.store_as_yaml(yml_path)
+
+        # load using from_file which should detect the format based on extension
+        json_schema = SchemaConfig.from_file(json_path)
+        yaml_schema = SchemaConfig.from_file(yaml_path)
+        yml_schema = SchemaConfig.from_file(yml_path)
+
+        # simple verification that the objects were loaded correctly
+        assert isinstance(json_schema, SchemaConfig)
+        assert isinstance(yaml_schema, SchemaConfig)
+        assert isinstance(yml_schema, SchemaConfig)
+
+        # verify basic structure is intact
+        assert "entities" in json_schema.model_dump()
+        assert "entities" in yaml_schema.model_dump()
+        assert "entities" in yml_schema.model_dump()
+
+        # verify an unsupported extension raises the correct error
+        txt_path = os.path.join(temp_dir, "schema.txt")
+        schema_config.store_as_json(txt_path)  # store as JSON but with .txt extension
+
+        with pytest.raises(ValueError, match="Unsupported file format"):
+            SchemaConfig.from_file(txt_path)
+
+
+@pytest.fixture
+def valid_schema_json_array() -> str:
+    return """
+    [
+        {
+            "entities": [
+                {
+                    "label": "Person",
+                    "properties": [
+                        {"name": "name", "type": "STRING"}
+                    ]
+                },
+                {
+                    "label": "Organization",
+                    "properties": [
+                        {"name": "name", "type": "STRING"}
+                    ]
+                }
+            ],
+            "relations": [
+                {
+                    "label": "WORKS_FOR",
+                    "properties": [
+                        {"name": "since", "type": "DATE"}
+                    ]
+                }
+            ],
+            "potential_schema": [
+                ["Person", "WORKS_FOR", "Organization"]
+            ]
+        }
+    ]
+    """
+
+
+@pytest.mark.asyncio
+async def test_schema_from_text_run_valid_json_array(
+    schema_from_text: SchemaFromTextExtractor,
+    mock_llm: AsyncMock,
+    valid_schema_json_array: str,
+) -> None:
+    # configure the mock LLM to return a valid JSON array
+    mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json_array)
+
+    # run the schema extraction
+    schema_config = await schema_from_text.run(text="Sample text for extraction")
+
+    # verify the schema was correctly extracted from the array
+    assert len(schema_config.entities) == 2
+    assert "Person" in schema_config.entities
+    assert "Organization" in schema_config.entities
+
+    assert schema_config.relations is not None
+    assert "WORKS_FOR" in schema_config.relations
+
+    assert schema_config.potential_schema is not None
+    assert len(schema_config.potential_schema) == 1
+    assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization")
diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
index ef0365849..8aa318cd3 100644
--- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
+++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py
@@ -28,6 +28,7 @@
     SchemaBuilder,
     SchemaEntity,
     SchemaRelation,
+    SchemaFromTextExtractor,
 )
 from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
     FixedSizeSplitter,
 )
@@ -116,8 +117,21 @@ def test_simple_kg_pipeline_config_chunk_embedder(
     assert chunk_embedder._embedder == embedder
 
 
-def test_simple_kg_pipeline_config_schema() -> None:
+@patch(
+    "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
+)
+def test_simple_kg_pipeline_config_automatic_schema(
+    mock_llm: Mock, llm: LLMInterface
+) -> None:
+    mock_llm.return_value = llm
     config = SimpleKGPipelineConfig()
+    schema = config._get_schema()
+    assert isinstance(schema, SchemaFromTextExtractor)
+    assert schema._llm == llm
+
+
+def test_simple_kg_pipeline_config_manual_schema() -> None:
+    config = SimpleKGPipelineConfig(entities=["Person"])
     assert isinstance(config._get_schema(), SchemaBuilder)
 
 
@@ -205,9 +219,10 @@ def test_simple_kg_pipeline_config_connections_from_pdf() -> None:
         perform_entity_resolution=False,
     )
     connections = config._get_connections()
-    assert len(connections) == 5
+    assert len(connections) == 6
     expected_connections = [
         ("pdf_loader", "splitter"),
+        ("pdf_loader", "schema"),
         ("schema", "extractor"),
         ("splitter", "chunk_embedder"),
         ("chunk_embedder", "extractor"),
@@ -240,9 +255,10 @@ def test_simple_kg_pipeline_config_connections_with_er() -> None:
         perform_entity_resolution=True,
     )
     connections = config._get_connections()
-    assert len(connections) == 6
+    assert len(connections) == 7
     expected_connections = [
         ("pdf_loader", "splitter"),
+        ("pdf_loader", "schema"),
         ("schema", "extractor"),
         ("splitter", "chunk_embedder"),
         ("chunk_embedder", "extractor"),
@@ -263,7 +279,8 @@ def test_simple_kg_pipeline_config_run_params_from_pdf_file_path() -> None:
 def test_simple_kg_pipeline_config_run_params_from_text_text() -> None:
     config = SimpleKGPipelineConfig(from_pdf=False)
     assert config.get_run_params({"text": "my text"}) == {
-        "splitter": {"text": "my text"}
+        "splitter": {"text": "my text"},
+        "schema": {"text": "my text"},
     }
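Note for reviewers: outside the mocked unit tests above, the new component can be exercised end to end. The sketch below is illustrative only — the `OpenAILLM` wrapper and model name are stand-ins for any `LLMInterface` implementation — and mirrors the API exercised by these tests (`SchemaFromTextExtractor.run`, `SchemaConfig.store_as_yaml`, `SchemaConfig.from_file`).

.. code:: python

    import asyncio

    from neo4j_graphrag.experimental.components.schema import (
        SchemaConfig,
        SchemaFromTextExtractor,
    )
    from neo4j_graphrag.llm import OpenAILLM  # illustrative; any LLMInterface works

    llm = OpenAILLM(model_name="gpt-4o")

    # extract a schema from raw text, as SimpleKGPipeline now does
    # automatically when no schema is provided
    extractor = SchemaFromTextExtractor(llm=llm)
    schema_config = asyncio.run(
        extractor.run(text="Alice works for Acme Corp in London.")
    )

    # persist the extracted schema and reload it later; from_file infers
    # the format from the file extension (.json, .yaml, or .yml)
    schema_config.store_as_yaml("schema.yaml")
    reloaded = SchemaConfig.from_file("schema.yaml")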