From 7c831de0efc9789e2c6efb0f3e7475143c473b92 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 23 Apr 2025 14:17:00 +0300 Subject: [PATCH 01/36] Add schema extraction prompt template --- src/neo4j_graphrag/generation/prompts.py | 38 ++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index ade302720..991d20fa6 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -200,3 +200,41 @@ def format( text: str = "", ) -> str: return super().format(text=text, schema=schema, examples=examples) + + +class SchemaExtractionTemplate(PromptTemplate): + DEFAULT_TEMPLATE = """ +You are a top-tier algorithm designed for extracting a labeled property graph schema in +structured formats. + +Generate the generalized graph schema based on input text. Identify key entity types, +their relationship types, and property types whenever it is possible. Return only +abstract schema information, no concrete instances. Use singular PascalCase labels for +entity types and UPPER_SNAKE_CASE for relationship types. Include property definitions +only when the type can be confidently inferred, otherwise omit the properties. +Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, +LOCAL DATETIME, LOCAL TIME, POINT, STRING, ZONED DATETIME, ZONED TIME. +Do not add extra keys or explanatory text. Return a valid JSON object without +back‑ticks, markdown, or comments. + +For example, if the text says "Alice lives in London", the output JSON object should +adhere to the following format: +{"entities": [{"label": "Person", "properties": [{"name": "name", "type": "STRING"}]}, +{"label": "City", "properties":[{"name": "name", "type": "STRING"}]}], +"relations": [{"label": "LIVES_IN"}], +"potential_schema":[[ "Person", "LIVES_IN", "City"]]} + +More examples: +{examples} + +Input text: +{text} +""" + EXPECTED_INPUTS = ["text"] + + def format( + self, + examples: str, + text: str = "", + ) -> str: + return super().format(text=text, examples=examples) From baf9302c54747aebe4ae3dfa5eccdc32d46b3c13 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 23 Apr 2025 16:37:53 +0300 Subject: [PATCH 02/36] Add schema from text using an LLM --- .../experimental/components/schema.py | 63 +++++++++++++++++++ src/neo4j_graphrag/generation/__init__.py | 3 +- 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 96d7d466b..8f9bc8944 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -14,9 +14,11 @@ # limitations under the License. from __future__ import annotations +import json from typing import Any, Dict, List, Literal, Optional, Tuple from pydantic import BaseModel, ValidationError, model_validator, validate_call +from requests.exceptions import InvalidJSONError from typing_extensions import Self from neo4j_graphrag.exceptions import SchemaValidationError @@ -25,6 +27,8 @@ EntityInputType, RelationInputType, ) +from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate +from neo4j_graphrag.llm import LLMInterface class SchemaProperty(BaseModel): @@ -236,3 +240,62 @@ async def run( SchemaConfig: A configured schema object, constructed asynchronously. 
""" return self.create_schema_model(entities, relations, potential_schema) + + +class SchemaFromText(SchemaBuilder): + """ + A builder class for constructing SchemaConfig objects from the output of an LLM after + automatic schema extraction from text. + """ + + def __init__( + self, + llm: LLMInterface, + prompt_template: Optional[PromptTemplate] = None, + llm_params: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self._llm: LLMInterface = llm + self._prompt_template: PromptTemplate = prompt_template or SchemaExtractionTemplate() + self._llm_params: dict[str, Any] = llm_params or {} + + @validate_call + async def run(self, text: str, **kwargs: Any) -> SchemaConfig: + """ + Asynchronously extracts the schema from text and returns a SchemaConfig object. + + Args: + text (str): the text from which the schema will be inferred. + + Returns: + SchemaConfig: A configured schema object, extracted automatically and + constructed asynchronously. + """ + prompt: str = self._prompt_template.format(text=text) + + response = await self._llm.invoke(prompt, **self._llm_params) + content: str = ( + response if isinstance(response, str) else getattr(response, "content", str(response)) + ) + + try: + extracted_schema: Dict[str, Any] = json.loads(content) + except json.JSONDecodeError as exc: + raise InvalidJSONError( + "LLM response is not valid JSON." + ) from exc + + extracted_entities: List[dict] = extracted_schema.get("entities", []) + extracted_relations: Optional[List[dict]] = extracted_schema.get("relations") + potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get("potential_schema") + + entities: List[SchemaEntity] = [SchemaEntity(**e) for e in extracted_entities] + relations: Optional[List[SchemaRelation]] = ( + [SchemaRelation(**r) for r in extracted_relations] if extracted_relations else None + ) + + return await super().run( + entities=entities, + relations=relations, + potential_schema=potential_schema, + ) diff --git a/src/neo4j_graphrag/generation/__init__.py b/src/neo4j_graphrag/generation/__init__.py index ff3f69f3a..359d57f96 100644 --- a/src/neo4j_graphrag/generation/__init__.py +++ b/src/neo4j_graphrag/generation/__init__.py @@ -1,8 +1,9 @@ from .graphrag import GraphRAG -from .prompts import PromptTemplate, RagTemplate +from .prompts import PromptTemplate, RagTemplate, SchemaExtractionTemplate __all__ = [ "GraphRAG", "PromptTemplate", "RagTemplate", + "SchemaExtractionTemplate" ] From 2b14541d762097922887b168dff4cd99cf261b3b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 24 Apr 2025 16:31:00 +0300 Subject: [PATCH 03/36] Update SimpleKGPipeline for automatic schema extraction --- .../template_pipeline/simple_kg_builder.py | 54 +++++++++++--- .../experimental/pipeline/kg_builder.py | 71 +++++++++++++++++-- 2 files changed, 110 insertions(+), 15 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 306c4eb32..12f1ab49b 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -32,6 +32,7 @@ SchemaBuilder, SchemaEntity, SchemaRelation, + SchemaFromText, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -80,6 +81,7 @@ class 
SimpleKGPipelineConfig(TemplatePipelineConfig): perform_entity_resolution: bool = True lexical_graph_config: Optional[LexicalGraphConfig] = None neo4j_database: Optional[str] = None + auto_schema_extraction: bool = False pdf_loader: Optional[ComponentType] = None kg_writer: Optional[ComponentType] = None @@ -87,6 +89,10 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): model_config = ConfigDict(arbitrary_types_allowed=True) + def has_user_provided_schema(self) -> bool: + """Check if the user has provided schema information""" + return bool(self.entities or self.relations or self.potential_schema) + def _get_pdf_loader(self) -> Optional[PdfLoader]: if not self.from_pdf: return None @@ -114,15 +120,26 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]: def _get_chunk_embedder(self) -> TextChunkEmbedder: return TextChunkEmbedder(embedder=self.get_default_embedder()) - def _get_schema(self) -> SchemaBuilder: + def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]: + """ + Get the appropriate schema component based on configuration. + Return SchemaFromText for automatic extraction or SchemaBuilder for manual schema. + """ + if self.auto_schema_extraction and not self.has_user_provided_schema(): + return SchemaFromText(llm=self.get_default_llm()) return SchemaBuilder() def _get_run_params_for_schema(self) -> dict[str, Any]: - return { - "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], - "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], - "potential_schema": self.potential_schema, - } + if self.auto_schema_extraction and not self.has_user_provided_schema(): + # for automatic extraction, the text parameter is needed (will flow through the pipeline connections) + return {} + else: + # for manual schema, use the provided entities/relations/potential_schema + return { + "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], + "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], + "potential_schema": self.potential_schema, + } def _get_extractor(self) -> EntityRelationExtractor: return LLMEntityRelationExtractor( @@ -163,6 +180,17 @@ def _get_connections(self) -> list[ConnectionDefinition]: input_config={"text": "pdf_loader.text"}, ) ) + + # handle automatic schema extraction + if self.auto_schema_extraction and not self.has_user_provided_schema(): + connections.append( + ConnectionDefinition( + start="pdf_loader", + end="schema", + input_config={"text": "pdf_loader.text"}, + ) + ) + connections.append( ConnectionDefinition( start="schema", @@ -174,13 +202,21 @@ def _get_connections(self) -> list[ConnectionDefinition]: ) ) else: + # handle automatic schema extraction for direct text input: ensure schema extraction uses the complete text + if self.auto_schema_extraction and not self.has_user_provided_schema(): + connections.append( + ConnectionDefinition( + start="__input__", # connection to pipeline input + end="schema", + input_config={"text": "text"}, # use the original text input + ) + ) + connections.append( ConnectionDefinition( start="schema", end="extractor", - input_config={ - "schema": "schema", - }, + input_config={"schema": "schema"}, ) ) connections.append( diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index ec5ed5218..9ff222d32 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -16,6 +16,8 @@ from __future__ import 
annotations from typing import List, Optional, Sequence, Union +import logging +import warnings import neo4j from pydantic import ValidationError @@ -42,7 +44,9 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.experimental.components.schema import SchemaConfig, SchemaBuilder +logger = logging.getLogger(__name__) class SimpleKGPipeline: """ @@ -53,17 +57,20 @@ class SimpleKGPipeline: llm (LLMInterface): An instance of an LLM to use for entity and relation extraction. driver (neo4j.Driver): A Neo4j driver instance for database connection. embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks. - entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): A list of either: + schema (Optional[Union[SchemaConfig, dict[str, list]]]): A schema configuration defining entities, + relations, and potential schema relationships. + This is the recommended way to provide schema information. + entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): DEPRECATED. A list of either: - str: entity labels - dict: following the SchemaEntity schema, ie with label, description and properties keys - relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): A list of either: + relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): DEPRECATED. A list of either: - str: relation label - dict: following the SchemaRelation schema, ie with label, description and properties keys - potential_schema (Optional[List[tuple]]): A list of potential schema relationships. + potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships. enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "None" or "STRICT". from_pdf (bool): Determines whether to include the PdfLoader in the pipeline. If True, expects `file_path` input in `run` methods. @@ -85,6 +92,7 @@ def __init__( entities: Optional[Sequence[EntityInputType]] = None, relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, + schema: Optional[Union[SchemaConfig, dict[str, list]]] = None, enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, @@ -96,15 +104,65 @@ def __init__( lexical_graph_config: Optional[LexicalGraphConfig] = None, neo4j_database: Optional[str] = None, ): + # deprecation warnings for old parameters + if any([entities, relations, potential_schema]) and schema is not None: + logger.warning( + "Both 'schema' and individual schema components (entities, relations, potential_schema) " + "were provided. The 'schema' parameter takes precedence. In the future, individual " + "components will be removed. Please use only the 'schema' parameter." + ) + # emit a DeprecationWarning for tools that might be monitoring for it + warnings.warn( + "Both 'schema' and individual schema components are provided. Use only 'schema'.", + DeprecationWarning, + stacklevel=2, + ) + elif any([entities, relations, potential_schema]): + logger.warning( + "The 'entities', 'relations', and 'potential_schema' parameters are deprecated " + "and will be removed in a future version. " + "Please use the 'schema' parameter instead." 
+ ) + warnings.warn( + "The 'entities', 'relations', and 'potential_schema' parameters are deprecated.", + DeprecationWarning, + stacklevel=2, + ) + + # handle schema precedence over individual schema components + schema_entities = [] + schema_relations = [] + schema_potential = None + + if schema is not None: + # schema takes precedence over individual components + if isinstance(schema, SchemaConfig): + # use the SchemaConfig directly + pass + else: + # convert dictionary to entity/relation lists + schema_entities = schema.get("entities", []) + schema_relations = schema.get("relations", []) + schema_potential = schema.get("potential_schema") + else: + # Use the individual components if provided + schema_entities = entities or [] + schema_relations = relations or [] + schema_potential = potential_schema + + # determine if automatic schema extraction should be performed + has_schema = bool(schema_entities or schema_relations or schema_potential or isinstance(schema, SchemaConfig)) + auto_schema_extraction = not has_schema + try: config = SimpleKGPipelineConfig( # argument type are fixed in the Config object llm_config=llm, # type: ignore[arg-type] neo4j_config=driver, # type: ignore[arg-type] embedder_config=embedder, # type: ignore[arg-type] - entities=entities or [], - relations=relations or [], - potential_schema=potential_schema, + entities=schema_entities, + relations=schema_relations, + potential_schema=schema_potential, enforce_schema=SchemaEnforcementMode(enforce_schema), from_pdf=from_pdf, pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, @@ -115,6 +173,7 @@ def __init__( perform_entity_resolution=perform_entity_resolution, lexical_graph_config=lexical_graph_config, neo4j_database=neo4j_database, + auto_schema_extraction=auto_schema_extraction, ) except (ValidationError, ValueError) as e: raise PipelineDefinitionError() from e From 49452d4e6ff2379d3cb6efea1654fe5f683c502c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 24 Apr 2025 18:05:48 +0300 Subject: [PATCH 04/36] Save/Read inferred schema --- .../experimental/components/schema.py | 91 ++++++++++++++++++- 1 file changed, 89 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 8f9bc8944..a3cdf203a 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -15,7 +15,9 @@ from __future__ import annotations import json -from typing import Any, Dict, List, Literal, Optional, Tuple +import yaml +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from pathlib import Path from pydantic import BaseModel, ValidationError, model_validator, validate_call from requests.exceptions import InvalidJSONError @@ -127,6 +129,91 @@ def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]: return data + def store_as_json(self, file_path: str) -> None: + """ + Save the schema configuration to a JSON file. + + Args: + file_path (str): The path where the schema configuration will be saved. + """ + with open(file_path, 'w') as f: + json.dump(self.model_dump(), f, indent=2) + + def store_as_yaml(self, file_path: str) -> None: + """ + Save the schema configuration to a YAML file. + + Args: + file_path (str): The path where the schema configuration will be saved. 
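+
+        Example (usage sketch; assumes ``schema_config`` is any ``SchemaConfig``
+        instance)::
+
+            schema_config.store_as_yaml("schema.yaml")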
+ """ + with open(file_path, 'w') as f: + yaml.dump(self.model_dump(), f, default_flow_style=False, sort_keys=False) + + @classmethod + def from_file(cls, file_path: Union[str, Path]) -> Self: + """ + Load a schema configuration from a file (either JSON or YAML). + + The file format is automatically detected based on the file extension. + + Args: + file_path (Union[str, Path]): The path to the schema configuration file. + + Returns: + SchemaConfig: The loaded schema configuration. + """ + file_path = Path(file_path) + + if not file_path.exists(): + raise FileNotFoundError(f"Schema file not found: {file_path}") + + if file_path.suffix.lower() in ['.json']: + return cls.from_json(file_path) + elif file_path.suffix.lower() in ['.yaml', '.yml']: + return cls.from_yaml(file_path) + else: + raise ValueError(f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml") + + @classmethod + def from_json(cls, file_path: Union[str, Path]) -> Self: + """ + Load a schema configuration from a JSON file. + + Args: + file_path (Union[str, Path]): The path to the JSON schema configuration file. + + Returns: + SchemaConfig: The loaded schema configuration. + """ + with open(file_path, 'r') as f: + try: + data = json.load(f) + return cls.model_validate(data) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON file: {e}") + except ValidationError as e: + raise SchemaValidationError(f"Schema validation failed: {e}") + + @classmethod + def from_yaml(cls, file_path: Union[str, Path]) -> Self: + """ + Load a schema configuration from a YAML file. + + Args: + file_path (Union[str, Path]): The path to the YAML schema configuration file. + + Returns: + SchemaConfig: The loaded schema configuration. + """ + with open(file_path, 'r') as f: + try: + data = yaml.safe_load(f) + return cls.model_validate(data) + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML file: {e}") + except ValidationError as e: + raise SchemaValidationError(f"Schema validation failed: {e}") + class SchemaBuilder(Component): """ @@ -281,7 +368,7 @@ async def run(self, text: str, **kwargs: Any) -> SchemaConfig: try: extracted_schema: Dict[str, Any] = json.loads(content) except json.JSONDecodeError as exc: - raise InvalidJSONError( + raise ValueError( "LLM response is not valid JSON." ) from exc From fa8a6af0cbee667ad295063c4d684415671e7fa5 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 25 Apr 2025 18:21:19 +0300 Subject: [PATCH 05/36] Bug fixes --- .../experimental/components/schema.py | 13 +++++++++---- src/neo4j_graphrag/generation/prompts.py | 14 +++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index a3cdf203a..b96e100ee 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -146,8 +146,13 @@ def store_as_yaml(self, file_path: str) -> None: Args: file_path (str): The path where the schema configuration will be saved. 
""" + # create a copy of the data and convert tuples to lists for YAML compatibility + data = self.model_dump() + if data.get('potential_schema'): + data['potential_schema'] = [list(item) for item in data['potential_schema']] + with open(file_path, 'w') as f: - yaml.dump(self.model_dump(), f, default_flow_style=False, sort_keys=False) + yaml.dump(data, f, default_flow_style=False, sort_keys=False) @classmethod def from_file(cls, file_path: Union[str, Path]) -> Self: @@ -347,18 +352,18 @@ def __init__( self._llm_params: dict[str, Any] = llm_params or {} @validate_call - async def run(self, text: str, **kwargs: Any) -> SchemaConfig: + async def run(self, text: str, examples:str = "", **kwargs: Any) -> SchemaConfig: """ Asynchronously extracts the schema from text and returns a SchemaConfig object. Args: text (str): the text from which the schema will be inferred. - + examples (str): examples to guide schema extraction. Returns: SchemaConfig: A configured schema object, extracted automatically and constructed asynchronously. """ - prompt: str = self._prompt_template.format(text=text) + prompt: str = self._prompt_template.format(text=text, examples=examples) response = await self._llm.invoke(prompt, **self._llm_params) content: str = ( diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 991d20fa6..6ed72463b 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -219,10 +219,10 @@ class SchemaExtractionTemplate(PromptTemplate): For example, if the text says "Alice lives in London", the output JSON object should adhere to the following format: -{"entities": [{"label": "Person", "properties": [{"name": "name", "type": "STRING"}]}, -{"label": "City", "properties":[{"name": "name", "type": "STRING"}]}], -"relations": [{"label": "LIVES_IN"}], -"potential_schema":[[ "Person", "LIVES_IN", "City"]]} +{{"entities": [{{"label": "Person", "properties": [{{"name": "name", "type": "STRING"}}]}}, +{{"label": "City", "properties":[{{"name": "name", "type": "STRING"}}]}}], +"relations": [{{"label": "LIVES_IN"}}], +"potential_schema":[[ "Person", "LIVES_IN", "City"]]}} More examples: {examples} @@ -233,8 +233,8 @@ class SchemaExtractionTemplate(PromptTemplate): EXPECTED_INPUTS = ["text"] def format( - self, - examples: str, - text: str = "", + self, + text: str = "", + examples: str = "", ) -> str: return super().format(text=text, examples=examples) From b52bed4467d82dbd3fd9227da52ce7da645b1e0c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 25 Apr 2025 18:21:26 +0300 Subject: [PATCH 06/36] Add unit tests --- .../experimental/components/test_schema.py | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 6ff257a12..cda8d081f 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -14,6 +14,9 @@ # limitations under the License. 
from __future__ import annotations +import json +from unittest.mock import AsyncMock + import pytest from neo4j_graphrag.exceptions import SchemaValidationError from neo4j_graphrag.experimental.components.schema import ( @@ -21,8 +24,16 @@ SchemaEntity, SchemaProperty, SchemaRelation, + SchemaFromText, + SchemaConfig, ) from pydantic import ValidationError +import os +import tempfile +import yaml +from pathlib import Path + +from neo4j_graphrag.generation import PromptTemplate @pytest.fixture @@ -93,6 +104,18 @@ def schema_builder() -> SchemaBuilder: return SchemaBuilder() +@pytest.fixture +def schema_config( + schema_builder: SchemaBuilder, + valid_entities: list[SchemaEntity], + valid_relations: list[SchemaRelation], + potential_schema: list[tuple[str, str, str]], +): + return schema_builder.create_schema_model( + valid_entities, valid_relations, potential_schema + ) + + def test_create_schema_model_valid_data( schema_builder: SchemaBuilder, valid_entities: list[SchemaEntity], @@ -419,3 +442,224 @@ def test_create_schema_model_missing_relations( assert "Relations must also be provided when using a potential schema." in str( exc_info.value ), "Should fail due to missing relations" + + +@pytest.fixture +def mock_llm(): + mock = AsyncMock() + mock.invoke = AsyncMock() + return mock + + +@pytest.fixture +def valid_schema_json(): + return ''' + { + "entities": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relations": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "potential_schema": [ + ["Person", "WORKS_FOR", "Organization"] + ] + } + ''' + + +@pytest.fixture +def invalid_schema_json(): + return ''' + { + "entities": [ + { + "label": "Person", + }, + ], + invalid json content + } + ''' + + +@pytest.fixture +def schema_from_text(mock_llm): + return SchemaFromText(llm=mock_llm) + + +@pytest.mark.asyncio +async def test_schema_from_text_run_valid_response(schema_from_text, mock_llm, valid_schema_json): + # configure the mock LLM to return a valid schema JSON + mock_llm.invoke.return_value = valid_schema_json + + # run the schema extraction + schema_config = await schema_from_text.run(text="Sample text for extraction") + + # verify the LLM was called with a prompt + mock_llm.invoke.assert_called_once() + prompt_arg = mock_llm.invoke.call_args[0][0] + assert isinstance(prompt_arg, str) + assert "Sample text for extraction" in prompt_arg + + # verify the schema was correctly extracted + assert len(schema_config.entities) == 2 + assert "Person" in schema_config.entities + assert "Organization" in schema_config.entities + + assert schema_config.relations is not None + assert "WORKS_FOR" in schema_config.relations + + assert schema_config.potential_schema is not None + assert len(schema_config.potential_schema) == 1 + assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") + + +@pytest.mark.asyncio +async def test_schema_from_text_run_invalid_json(schema_from_text, mock_llm, invalid_schema_json): + # configure the mock LLM to return invalid JSON + mock_llm.invoke.return_value = invalid_schema_json + + # verify that running with invalid JSON raises a ValueError + with pytest.raises(ValueError) as exc_info: + await schema_from_text.run(text="Sample text for extraction") + + assert "not valid JSON" in str(exc_info.value) + + +@pytest.mark.asyncio +async def 
test_schema_from_text_custom_template(mock_llm, valid_schema_json): + # create a custom template + custom_prompt = "This is a custom prompt with text: {text}" + custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"]) + + # create SchemaFromText with the custom template + schema_from_text = SchemaFromText(llm=mock_llm, prompt_template=custom_template) + + # configure mock LLM to return valid JSON and capture the prompt that was sent to it + mock_llm.invoke.return_value = valid_schema_json + + # run the schema extraction + await schema_from_text.run(text="Sample text") + + # verify the custom prompt was passed to the LLM + prompt_sent_to_llm = mock_llm.invoke.call_args[0][0] + assert "This is a custom prompt with text" in prompt_sent_to_llm + + +@pytest.mark.asyncio +async def test_schema_from_text_llm_params(mock_llm, valid_schema_json): + # configure custom LLM parameters + llm_params = {"temperature": 0.1, "max_tokens": 500} + + # create SchemaFromText with custom LLM parameters + schema_from_text = SchemaFromText(llm=mock_llm, llm_params=llm_params) + + # configure the mock LLM to return a valid schema JSON + mock_llm.invoke.return_value = valid_schema_json + + # run the schema extraction + await schema_from_text.run(text="Sample text") + + # verify the LLM was called with the custom parameters + mock_llm.invoke.assert_called_once() + call_kwargs = mock_llm.invoke.call_args[1] + assert call_kwargs["temperature"] == 0.1 + assert call_kwargs["max_tokens"] == 500 + + +@pytest.mark.asyncio +async def test_schema_config_store_as_json(schema_config): + with tempfile.TemporaryDirectory() as temp_dir: + # create file path + json_path = os.path.join(temp_dir, "schema.json") + + # store the schema config + schema_config.store_as_json(json_path) + + # verify the file exists and has content + assert os.path.exists(json_path) + assert os.path.getsize(json_path) > 0 + + # verify the content is valid JSON and contains expected data + with open(json_path, 'r') as f: + data = json.load(f) + assert "entities" in data + assert "PERSON" in data["entities"] + assert "properties" in data["entities"]["PERSON"] + assert "description" in data["entities"]["PERSON"] + assert data["entities"]["PERSON"]["description"] == "An individual human being." + + +@pytest.mark.asyncio +async def test_schema_config_store_as_yaml(schema_config): + with tempfile.TemporaryDirectory() as temp_dir: + # Create file path + yaml_path = os.path.join(temp_dir, "schema.yaml") + + # Store the schema config + schema_config.store_as_yaml(yaml_path) + + # Verify the file exists and has content + assert os.path.exists(yaml_path) + assert os.path.getsize(yaml_path) > 0 + + # Verify the content is valid YAML and contains expected data + with open(yaml_path, 'r') as f: + data = yaml.safe_load(f) + assert "entities" in data + assert "PERSON" in data["entities"] + assert "properties" in data["entities"]["PERSON"] + assert "description" in data["entities"]["PERSON"] + assert data["entities"]["PERSON"]["description"] == "An individual human being." 
+ + +@pytest.mark.asyncio +async def test_schema_config_from_file(schema_config): + with tempfile.TemporaryDirectory() as temp_dir: + # create file paths with different extensions + json_path = os.path.join(temp_dir, "schema.json") + yaml_path = os.path.join(temp_dir, "schema.yaml") + yml_path = os.path.join(temp_dir, "schema.yml") + + # store the schema config in the different formats + schema_config.store_as_json(json_path) + schema_config.store_as_yaml(yaml_path) + schema_config.store_as_yaml(yml_path) + + # load using from_file which should detect the format based on extension + json_schema = SchemaConfig.from_file(json_path) + yaml_schema = SchemaConfig.from_file(yaml_path) + yml_schema = SchemaConfig.from_file(yml_path) + + # simple verification that the objects were loaded correctly + assert isinstance(json_schema, SchemaConfig) + assert isinstance(yaml_schema, SchemaConfig) + assert isinstance(yml_schema, SchemaConfig) + + # verify basic structure is intact + assert "entities" in json_schema.model_dump() + assert "entities" in yaml_schema.model_dump() + assert "entities" in yml_schema.model_dump() + + # verify an unsupported extension raises the correct error + txt_path = os.path.join(temp_dir, "schema.txt") + schema_config.store_as_json(txt_path) # Store as JSON but with .txt extension + + with pytest.raises(ValueError, match="Unsupported file format"): + SchemaConfig.from_file(txt_path) From 41d359de64f6fdbcdbbc525ea877952f561de2fc Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 28 Apr 2025 18:22:01 +0300 Subject: [PATCH 07/36] Allow schema parameter in SimpleKGBuilderConfig and refactor code --- .../template_pipeline/simple_kg_builder.py | 82 +++++++++++++++++-- .../experimental/pipeline/kg_builder.py | 62 ++------------ 2 files changed, 81 insertions(+), 63 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 12f1ab49b..c1b37b946 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, ClassVar, Literal, Optional, Sequence, Union +from typing import Any, ClassVar, Literal, Optional, Sequence, Union, TypeVar +import logging -from pydantic import ConfigDict +from pydantic import ConfigDict, model_validator, Field from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -33,6 +34,7 @@ SchemaEntity, SchemaRelation, SchemaFromText, + SchemaConfig, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -55,6 +57,9 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate +logger = logging.getLogger(__name__) + +T = TypeVar('T', bound='SimpleKGPipelineConfig') class SimpleKGPipelineConfig(TemplatePipelineConfig): COMPONENTS: ClassVar[list[str]] = [ @@ -75,6 +80,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None + schema: Optional[Union[SchemaConfig, dict[str, list]]] = None enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() @@ -88,10 +94,40 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): text_splitter: Optional[ComponentType] = None model_config = ConfigDict(arbitrary_types_allowed=True) + + @model_validator(mode='after') + def handle_schema_precedence(self) -> T: + """Handle schema precedence and warnings""" + self._process_schema_parameters() + return self + + def _process_schema_parameters(self) -> None: + """ + Process schema parameters and handle precedence between 'schema' parameter and individual components. + Also logs warnings for deprecated usage. + """ + # check if both schema and individual components are provided + has_individual_schema_components = any([self.entities, self.relations, self.potential_schema]) + + if has_individual_schema_components and self.schema is not None: + logger.warning( + "Both 'schema' and individual schema components (entities, relations, potential_schema) " + "were provided. The 'schema' parameter takes precedence. In the future, individual " + "components will be removed. Please use only the 'schema' parameter.", + stacklevel=2 + ) + + elif has_individual_schema_components: + logger.warning( + "The 'entities', 'relations', and 'potential_schema' parameters are deprecated " + "and will be removed in a future version. " + "Please use the 'schema' parameter instead.", + stacklevel=2 + ) def has_user_provided_schema(self) -> bool: """Check if the user has provided schema information""" - return bool(self.entities or self.relations or self.potential_schema) + return bool(self.entities or self.relations or self.potential_schema or self.schema is not None) def _get_pdf_loader(self) -> Optional[PdfLoader]: if not self.from_pdf: @@ -129,16 +165,48 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]: return SchemaFromText(llm=self.get_default_llm()) return SchemaBuilder() + def _process_schema_with_precedence(self) -> tuple[list[SchemaEntity], list[SchemaRelation], Optional[list[tuple[str, str, str]]]]: + """ + Process schema inputs according to precedence rules: + 1. If schema is provided as SchemaConfig object, use it + 2. If schema is provided as dictionary, extract from it + 3. 
Otherwise, use individual schema components + + Returns: + Tuple of (entities, relations, potential_schema) + """ + if self.schema is not None: + # schema takes precedence over individual components + if isinstance(self.schema, SchemaConfig): + # extract components from SchemaConfig + entities = list(self.schema.entities.values()) + relations = list(self.schema.relations.values()) + potential_schema = self.schema.potential_schema + else: + # extract from dictionary + entities = [SchemaEntity.from_text_or_dict(e) for e in self.schema.get("entities", [])] + relations = [SchemaRelation.from_text_or_dict(r) for r in self.schema.get("relations", [])] + potential_schema = self.schema.get("potential_schema") + else: + # use individual components + entities = [SchemaEntity.from_text_or_dict(e) for e in self.entities] if self.entities else [] + relations = [SchemaRelation.from_text_or_dict(r) for r in self.relations] if self.relations else [] + potential_schema = self.potential_schema + + return entities, relations, potential_schema + def _get_run_params_for_schema(self) -> dict[str, Any]: if self.auto_schema_extraction and not self.has_user_provided_schema(): # for automatic extraction, the text parameter is needed (will flow through the pipeline connections) return {} else: - # for manual schema, use the provided entities/relations/potential_schema + # process schema components according to precedence rules + entities, relations, potential_schema = self._process_schema_with_precedence() + return { - "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], - "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], - "potential_schema": self.potential_schema, + "entities": entities, + "relations": relations, + "potential_schema": potential_schema, } def _get_extractor(self) -> EntityRelationExtractor: diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 9ff222d32..49b9da3b2 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -17,7 +17,6 @@ from typing import List, Optional, Sequence, Union import logging -import warnings import neo4j from pydantic import ValidationError @@ -44,7 +43,7 @@ ) from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.experimental.components.schema import SchemaConfig, SchemaBuilder +from neo4j_graphrag.experimental.components.schema import SchemaConfig logger = logging.getLogger(__name__) @@ -104,65 +103,16 @@ def __init__( lexical_graph_config: Optional[LexicalGraphConfig] = None, neo4j_database: Optional[str] = None, ): - # deprecation warnings for old parameters - if any([entities, relations, potential_schema]) and schema is not None: - logger.warning( - "Both 'schema' and individual schema components (entities, relations, potential_schema) " - "were provided. The 'schema' parameter takes precedence. In the future, individual " - "components will be removed. Please use only the 'schema' parameter." - ) - # emit a DeprecationWarning for tools that might be monitoring for it - warnings.warn( - "Both 'schema' and individual schema components are provided. 
Use only 'schema'.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-        elif any([entities, relations, potential_schema]):
-            logger.warning(
-                "The 'entities', 'relations', and 'potential_schema' parameters are deprecated "
-                "and will be removed in a future version. "
-                "Please use the 'schema' parameter instead."
-            )
-            warnings.warn(
-                "The 'entities', 'relations', and 'potential_schema' parameters are deprecated.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-
-        # handle schema precedence over individual schema components
-        schema_entities = []
-        schema_relations = []
-        schema_potential = None
-
-        if schema is not None:
-            # schema takes precedence over individual components
-            if isinstance(schema, SchemaConfig):
-                # use the SchemaConfig directly
-                pass
-            else:
-                # convert dictionary to entity/relation lists
-                schema_entities = schema.get("entities", [])
-                schema_relations = schema.get("relations", [])
-                schema_potential = schema.get("potential_schema")
-        else:
-            # Use the individual components if provided
-            schema_entities = entities or []
-            schema_relations = relations or []
-            schema_potential = potential_schema
-
-        # determine if automatic schema extraction should be performed
-        has_schema = bool(schema_entities or schema_relations or schema_potential or isinstance(schema, SchemaConfig))
-        auto_schema_extraction = not has_schema
-
         try:
             config = SimpleKGPipelineConfig(
                 # argument type are fixed in the Config object
                 llm_config=llm,  # type: ignore[arg-type]
                 neo4j_config=driver,  # type: ignore[arg-type]
                 embedder_config=embedder,  # type: ignore[arg-type]
-                entities=schema_entities,
-                relations=schema_relations,
-                potential_schema=schema_potential,
+                entities=entities or [],
+                relations=relations or [],
+                potential_schema=potential_schema,
+                schema=schema,
                 enforce_schema=SchemaEnforcementMode(enforce_schema),
                 from_pdf=from_pdf,
                 pdf_loader=ComponentType(pdf_loader) if pdf_loader else None,
@@ -173,7 +123,7 @@ def __init__(
                 perform_entity_resolution=perform_entity_resolution,
                 lexical_graph_config=lexical_graph_config,
                 neo4j_database=neo4j_database,
-                auto_schema_extraction=auto_schema_extraction,
+                auto_schema_extraction=not bool(schema or entities or relations or potential_schema),
             )
         except (ValidationError, ValueError) as e:
             raise PipelineDefinitionError() from e

From 511bc3ed9cade8a1cf90197d84c2d0aea5a33016 Mon Sep 17 00:00:00 2001
From: nathaliecharbel
Date: Mon, 28 Apr 2025 18:46:12 +0300
Subject: [PATCH 08/36] Update changelog and api rst

---
 CHANGELOG.md        |  1 +
 docs/source/api.rst | 13 +++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8a0533cfa..f44ea792b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@
 - Added a `Pipeline.stream` method to stream pipeline progress.
 - Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
 - Added a new fuzzy match resolver to the KG Builder for entity resolution based on RapidFuzz string fuzzy matching.
+- Added support for automatic schema extraction from text using LLMs.
 
 ### Changed

diff --git a/docs/source/api.rst b/docs/source/api.rst
index e895cd5dd..2567411af 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -77,6 +77,12 @@ SchemaBuilder
 .. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaBuilder
     :members: run
 
+SchemaFromText
+==============
+
+.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaFromText
+   :members: run
+
 EntityRelationExtractor
 =======================
 
@@ -362,6 +368,13 @@ ERExtractionTemplate
    :members:
    :exclude-members: format
 
+SchemaExtractionTemplate
+------------------------
+
+.. autoclass:: neo4j_graphrag.generation.prompts.SchemaExtractionTemplate
+   :members:
+   :exclude-members: format
+
 Text2CypherTemplate
 --------------------

From 212ae0b04b17c480d0faeaebb2874625cf3982d5 Mon Sep 17 00:00:00 2001
From: nathaliecharbel
Date: Mon, 28 Apr 2025 19:34:37 +0300
Subject: [PATCH 09/36] Update documentation

---
 docs/source/user_guide_kg_builder.rst | 201 +++++++++++++++++---------
 1 file changed, 136 insertions(+), 65 deletions(-)

diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst
index 11a6d1741..3c4716176 100644
--- a/docs/source/user_guide_kg_builder.rst
+++ b/docs/source/user_guide_kg_builder.rst
@@ -21,7 +21,7 @@ A Knowledge Graph (KG) construction pipeline requires a few components (some of
 - **Data loader**: extract text from files (PDFs, ...).
 - **Text splitter**: split the text into smaller pieces of text (chunks), manageable by the LLM context window (token limit).
 - **Chunk embedder** (optional): compute the chunk embeddings.
-- **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG.
+- **Schema builder**: provide a schema to ground the LLM extracted entities and relations and obtain an easily navigable KG. Schema can be provided manually or extracted automatically using LLMs.
 - **Lexical graph builder**: build the lexical graph (Document, Chunk and their relationships) (optional).
 - **Entity and relation extractor**: extract relevant entities and relations from the text.
 - **Knowledge Graph writer**: save the identified entities and relations.
@@ -75,10 +75,11 @@ Graph Schema
 
 It is possible to guide the LLM by supplying a list of entities, relationships,
 and instructions on how to connect them. However, note that the extracted graph
-may not fully adhere to these guidelines. Entities and relationships can be
-represented as either simple strings (for their labels) or dictionaries. If using
-a dictionary, it must include a label key and can optionally include description
-and properties keys, as shown below:
+may not fully adhere to these guidelines unless schema enforcement is enabled
+(see :ref:`schema-enforcement-behaviour`). Entities and relationships can be represented
+as either simple strings (for their labels) or dictionaries. If using a dictionary,
+it must include a label key and can optionally include description and properties keys,
+as shown below:
 
 .. code:: python
 
@@ -117,6 +118,18 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated
 
 .. code:: python
 
+    # Using the schema parameter (recommended approach)
+    kg_builder = SimpleKGPipeline(
+        # ...
+        schema={
+            "entities": ENTITIES,
+            "relations": RELATIONS,
+            "potential_schema": POTENTIAL_SCHEMA
+        },
+        # ...
+    )
+
+    # Using individual schema parameters (deprecated approach)
     kg_builder = SimpleKGPipeline(
         # ...
         entities=ENTITIES,
@@ -125,6 +138,9 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated
         # ...
     )
 
+.. note::
+    By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (see the :ref:`automatic-schema-extraction` section).
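+
+    For example, the following minimal call (a sketch that assumes ``llm``, ``driver``
+    and ``embedder`` are already configured) triggers automatic schema extraction
+    because no schema information is passed:
+
+    .. code:: python
+
+        kg_builder = SimpleKGPipeline(
+            llm=llm,
+            driver=driver,
+            embedder=embedder,
+            # no schema, entities, relations or potential_schema provided,
+            # so the schema is inferred from the input text by the LLM
+        )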
+
 Extra configurations
 --------------------
 
@@ -412,41 +428,46 @@ within the configuration file.
             "neo4j_database": "myDb",
             "on_error": "IGNORE",
             "prompt_template": "...",
-            "entities": [
-                "Person",
-                {
-                    "label": "House",
-                    "description": "Family the person belongs to",
-                    "properties": [
-                        {"name": "name", "type": "STRING"}
-                    ]
-                },
-                {
-                    "label": "Planet",
-                    "properties": [
-                        {"name": "name", "type": "STRING"},
-                        {"name": "weather", "type": "STRING"}
-                    ]
-                }
-            ],
-            "relations": [
-                "PARENT_OF",
-                {
-                    "label": "HEIR_OF",
-                    "description": "Used for inheritor relationship between father and sons"
-                },
-                {
-                    "label": "RULES",
-                    "properties": [
-                        {"name": "fromYear", "type": "INTEGER"}
-                    ]
-                }
-            ],
-            "potential_schema": [
-                ["Person", "PARENT_OF", "Person"],
-                ["Person", "HEIR_OF", "House"],
-                ["House", "RULES", "Planet"]
-            ],
+
+            "schema": {
+                "entities": [
+                    "Person",
+                    {
+                        "label": "House",
+                        "description": "Family the person belongs to",
+                        "properties": [
+                            {"name": "name", "type": "STRING"}
+                        ]
+                    },
+                    {
+                        "label": "Planet",
+                        "properties": [
+                            {"name": "name", "type": "STRING"},
+                            {"name": "weather", "type": "STRING"}
+                        ]
+                    }
+                ],
+                "relations": [
+                    "PARENT_OF",
+                    {
+                        "label": "HEIR_OF",
+                        "description": "Used for inheritor relationship between father and sons"
+                    },
+                    {
+                        "label": "RULES",
+                        "properties": [
+                            {"name": "fromYear", "type": "INTEGER"}
+                        ]
+                    }
+                ],
+                "potential_schema": [
+                    ["Person", "PARENT_OF", "Person"],
+                    ["Person", "HEIR_OF", "House"],
+                    ["House", "RULES", "Planet"]
+                ]
+            },
+            "auto_schema_extraction": false,
             "lexical_graph_config": {
                 "chunk_node_label": "TextPart"
            }
 
or in YAML:
 
     neo4j_database: myDb
     on_error: IGNORE
     prompt_template: ...
-    entities:
-        - label: Person
-        - label: House
-          description: Family the person belongs to
-          properties:
-            - name: name
-              type: STRING
-        - label: Planet
-          properties:
-            - name: name
-              type: STRING
-            - name: weather
-              type: STRING
-    relations:
-        - label: PARENT_OF
-        - label: HEIR_OF
-          description: Used for inheritor relationship between father and sons
-        - label: RULES
-          properties:
-            - name: fromYear
-              type: INTEGER
-    potential_schema:
-        - ["Person", "PARENT_OF", "Person"]
-        - ["Person", "HEIR_OF", "House"]
-        - ["House", "RULES", "Planet"]
+
+    # Using the schema parameter (recommended approach)
+    schema:
+      entities:
+        - Person
+        - label: House
+          description: Family the person belongs to
+          properties:
+            - name: name
+              type: STRING
+        - label: Planet
+          properties:
+            - name: name
+              type: STRING
+            - name: weather
+              type: STRING
+      relations:
+        - PARENT_OF
+        - label: HEIR_OF
+          description: Used for inheritor relationship between father and sons
+        - label: RULES
+          properties:
+            - name: fromYear
+              type: INTEGER
+      potential_schema:
+        - ["Person", "PARENT_OF", "Person"]
+        - ["Person", "HEIR_OF", "House"]
+        - ["House", "RULES", "Planet"]
+
+    # Control automatic schema extraction
+    auto_schema_extraction: false
     lexical_graph_config:
         chunk_node_label: TextPart
 
@@ -791,6 +817,49 @@ Here is a code block illustrating these concepts:
 
 After validation, this schema is saved in a `SchemaConfig` object, whose dict representation is passed to the LLM.
 
+.. _automatic-schema-extraction:
+
+Automatic Schema Extraction with SchemaFromText
+-----------------------------------------------
+
+Instead of manually defining the schema, you can use the `SchemaFromText` component to automatically extract a schema from your text using an LLM:
+
+.. code:: python
+
+    from neo4j_graphrag.experimental.components.schema import SchemaFromText
+    from neo4j_graphrag.llm import OpenAILLM
+
+    # Create the automatic schema extractor
+    schema_extractor = SchemaFromText(
+        llm=OpenAILLM(
+            model_name="gpt-4o",
+            model_params={
+                "max_tokens": 2000,
+                "response_format": {"type": "json_object"},
+            },
+        )
+    )
+
+    # Extract schema from text
+    schema_config = await schema_extractor.run(text="Your document text here...")
+
+    # Use the extracted schema with other components
+    extractor = LLMEntityRelationExtractor(llm=llm)
+    result = await extractor.run(chunks=chunks, schema=schema_config)
+
+The `SchemaFromText` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `SchemaConfig` object that can be used in the same way as a manually defined schema.
+
+You can also save and reload the extracted schema:
+
+.. code:: python
+
+    # Save the schema to JSON or YAML files
+    schema_config.store_as_json("my_schema.json")
+    schema_config.store_as_yaml("my_schema.yaml")
+
+    # Later, reload the schema from file
+    from neo4j_graphrag.experimental.components.schema import SchemaConfig
+    restored_schema = SchemaConfig.from_file("my_schema.json")  # or my_schema.yaml
+
 Entity and Relation Extractor
 =============================
 
@@ -832,6 +901,8 @@ The LLM to use can be customized, the only constraint is that it obeys the :ref:
+.. _schema-enforcement-behaviour:
+
 Schema Enforcement Behaviour
 ----------------------------
 By default, even if a schema is provided to guide the LLM in the entity and relation extraction,
 the LLM response is not validated against that schema.
 This behaviour can be changed by using the `enforce_schema` flag in the `LLMEntityRelationExtractor` constructor:

From 30c273d0587fb5465976dc058380e3de1cd4c081 Mon Sep 17 00:00:00 2001
From: nathaliecharbel
Date: Tue, 29 Apr 2025 09:00:28 +0300
Subject: [PATCH 10/36] Fix Changelog after rebase

---
 CHANGELOG.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f44ea792b..a6d17340a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## Next
 
+### Added
+
+- Added support for automatic schema extraction from text using LLMs.
+
 ## 1.7.0
 
 ### Added
@@ -11,7 +15,6 @@
 - Added a `Pipeline.stream` method to stream pipeline progress.
 - Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
 - Added a new fuzzy match resolver to the KG Builder for entity resolution based on RapidFuzz string fuzzy matching.
-- Added support for automatic schema extraction from text using LLMs.
### Changed From 52a2686155ab08c80856c3d64cf2355464354653 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 29 Apr 2025 09:02:55 +0300 Subject: [PATCH 11/36] Ruff --- .../experimental/components/schema.py | 72 ++++++++++-------- .../template_pipeline/simple_kg_builder.py | 72 ++++++++++++------ .../experimental/pipeline/kg_builder.py | 5 +- src/neo4j_graphrag/generation/__init__.py | 7 +- .../experimental/components/test_schema.py | 76 +++++++++++-------- 5 files changed, 138 insertions(+), 94 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index b96e100ee..2366df4ea 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -136,61 +136,63 @@ def store_as_json(self, file_path: str) -> None: Args: file_path (str): The path where the schema configuration will be saved. """ - with open(file_path, 'w') as f: + with open(file_path, "w") as f: json.dump(self.model_dump(), f, indent=2) - + def store_as_yaml(self, file_path: str) -> None: """ Save the schema configuration to a YAML file. Args: file_path (str): The path where the schema configuration will be saved. - """ + """ # create a copy of the data and convert tuples to lists for YAML compatibility data = self.model_dump() - if data.get('potential_schema'): - data['potential_schema'] = [list(item) for item in data['potential_schema']] - - with open(file_path, 'w') as f: + if data.get("potential_schema"): + data["potential_schema"] = [list(item) for item in data["potential_schema"]] + + with open(file_path, "w") as f: yaml.dump(data, f, default_flow_style=False, sort_keys=False) - + @classmethod def from_file(cls, file_path: Union[str, Path]) -> Self: """ Load a schema configuration from a file (either JSON or YAML). - + The file format is automatically detected based on the file extension. - + Args: file_path (Union[str, Path]): The path to the schema configuration file. - + Returns: SchemaConfig: The loaded schema configuration. """ file_path = Path(file_path) - + if not file_path.exists(): raise FileNotFoundError(f"Schema file not found: {file_path}") - - if file_path.suffix.lower() in ['.json']: + + if file_path.suffix.lower() in [".json"]: return cls.from_json(file_path) - elif file_path.suffix.lower() in ['.yaml', '.yml']: + elif file_path.suffix.lower() in [".yaml", ".yml"]: return cls.from_yaml(file_path) else: - raise ValueError(f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml") - + raise ValueError( + f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml" + ) + @classmethod def from_json(cls, file_path: Union[str, Path]) -> Self: """ Load a schema configuration from a JSON file. - + Args: file_path (Union[str, Path]): The path to the JSON schema configuration file. - + Returns: SchemaConfig: The loaded schema configuration. """ - with open(file_path, 'r') as f: + with open(file_path, "r") as f: try: data = json.load(f) return cls.model_validate(data) @@ -198,19 +200,19 @@ def from_json(cls, file_path: Union[str, Path]) -> Self: raise ValueError(f"Invalid JSON file: {e}") except ValidationError as e: raise SchemaValidationError(f"Schema validation failed: {e}") - + @classmethod def from_yaml(cls, file_path: Union[str, Path]) -> Self: """ Load a schema configuration from a YAML file. - + Args: file_path (Union[str, Path]): The path to the YAML schema configuration file. - + Returns: SchemaConfig: The loaded schema configuration. 
""" - with open(file_path, 'r') as f: + with open(file_path, "r") as f: try: data = yaml.safe_load(f) return cls.model_validate(data) @@ -348,11 +350,13 @@ def __init__( ) -> None: super().__init__() self._llm: LLMInterface = llm - self._prompt_template: PromptTemplate = prompt_template or SchemaExtractionTemplate() + self._prompt_template: PromptTemplate = ( + prompt_template or SchemaExtractionTemplate() + ) self._llm_params: dict[str, Any] = llm_params or {} @validate_call - async def run(self, text: str, examples:str = "", **kwargs: Any) -> SchemaConfig: + async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfig: """ Asynchronously extracts the schema from text and returns a SchemaConfig object. @@ -367,23 +371,27 @@ async def run(self, text: str, examples:str = "", **kwargs: Any) -> SchemaConfig response = await self._llm.invoke(prompt, **self._llm_params) content: str = ( - response if isinstance(response, str) else getattr(response, "content", str(response)) + response + if isinstance(response, str) + else getattr(response, "content", str(response)) ) try: extracted_schema: Dict[str, Any] = json.loads(content) except json.JSONDecodeError as exc: - raise ValueError( - "LLM response is not valid JSON." - ) from exc + raise ValueError("LLM response is not valid JSON.") from exc extracted_entities: List[dict] = extracted_schema.get("entities", []) extracted_relations: Optional[List[dict]] = extracted_schema.get("relations") - potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get("potential_schema") + potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( + "potential_schema" + ) entities: List[SchemaEntity] = [SchemaEntity(**e) for e in extracted_entities] relations: Optional[List[SchemaRelation]] = ( - [SchemaRelation(**r) for r in extracted_relations] if extracted_relations else None + [SchemaRelation(**r) for r in extracted_relations] + if extracted_relations + else None ) return await super().run( diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index c1b37b946..47c3c180f 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -59,7 +59,8 @@ logger = logging.getLogger(__name__) -T = TypeVar('T', bound='SimpleKGPipelineConfig') +T = TypeVar("T", bound="SimpleKGPipelineConfig") + class SimpleKGPipelineConfig(TemplatePipelineConfig): COMPONENTS: ClassVar[list[str]] = [ @@ -94,40 +95,47 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): text_splitter: Optional[ComponentType] = None model_config = ConfigDict(arbitrary_types_allowed=True) - - @model_validator(mode='after') + + @model_validator(mode="after") def handle_schema_precedence(self) -> T: """Handle schema precedence and warnings""" self._process_schema_parameters() return self - + def _process_schema_parameters(self) -> None: """ Process schema parameters and handle precedence between 'schema' parameter and individual components. Also logs warnings for deprecated usage. 
""" # check if both schema and individual components are provided - has_individual_schema_components = any([self.entities, self.relations, self.potential_schema]) - + has_individual_schema_components = any( + [self.entities, self.relations, self.potential_schema] + ) + if has_individual_schema_components and self.schema is not None: logger.warning( "Both 'schema' and individual schema components (entities, relations, potential_schema) " "were provided. The 'schema' parameter takes precedence. In the future, individual " "components will be removed. Please use only the 'schema' parameter.", - stacklevel=2 + stacklevel=2, ) - + elif has_individual_schema_components: logger.warning( "The 'entities', 'relations', and 'potential_schema' parameters are deprecated " "and will be removed in a future version. " "Please use the 'schema' parameter instead.", - stacklevel=2 + stacklevel=2, ) def has_user_provided_schema(self) -> bool: """Check if the user has provided schema information""" - return bool(self.entities or self.relations or self.potential_schema or self.schema is not None) + return bool( + self.entities + or self.relations + or self.potential_schema + or self.schema is not None + ) def _get_pdf_loader(self) -> Optional[PdfLoader]: if not self.from_pdf: @@ -165,13 +173,17 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]: return SchemaFromText(llm=self.get_default_llm()) return SchemaBuilder() - def _process_schema_with_precedence(self) -> tuple[list[SchemaEntity], list[SchemaRelation], Optional[list[tuple[str, str, str]]]]: + def _process_schema_with_precedence( + self, + ) -> tuple[ + list[SchemaEntity], list[SchemaRelation], Optional[list[tuple[str, str, str]]] + ]: """ Process schema inputs according to precedence rules: 1. If schema is provided as SchemaConfig object, use it 2. If schema is provided as dictionary, extract from it 3. 
Otherwise, use individual schema components - + Returns: Tuple of (entities, relations, potential_schema) """ @@ -184,15 +196,29 @@ def _process_schema_with_precedence(self) -> tuple[list[SchemaEntity], list[Sche potential_schema = self.schema.potential_schema else: # extract from dictionary - entities = [SchemaEntity.from_text_or_dict(e) for e in self.schema.get("entities", [])] - relations = [SchemaRelation.from_text_or_dict(r) for r in self.schema.get("relations", [])] + entities = [ + SchemaEntity.from_text_or_dict(e) + for e in self.schema.get("entities", []) + ] + relations = [ + SchemaRelation.from_text_or_dict(r) + for r in self.schema.get("relations", []) + ] potential_schema = self.schema.get("potential_schema") else: # use individual components - entities = [SchemaEntity.from_text_or_dict(e) for e in self.entities] if self.entities else [] - relations = [SchemaRelation.from_text_or_dict(r) for r in self.relations] if self.relations else [] + entities = ( + [SchemaEntity.from_text_or_dict(e) for e in self.entities] + if self.entities + else [] + ) + relations = ( + [SchemaRelation.from_text_or_dict(r) for r in self.relations] + if self.relations + else [] + ) potential_schema = self.potential_schema - + return entities, relations, potential_schema def _get_run_params_for_schema(self) -> dict[str, Any]: @@ -201,8 +227,10 @@ def _get_run_params_for_schema(self) -> dict[str, Any]: return {} else: # process schema components according to precedence rules - entities, relations, potential_schema = self._process_schema_with_precedence() - + entities, relations, potential_schema = ( + self._process_schema_with_precedence() + ) + return { "entities": entities, "relations": relations, @@ -248,7 +276,7 @@ def _get_connections(self) -> list[ConnectionDefinition]: input_config={"text": "pdf_loader.text"}, ) ) - + # handle automatic schema extraction if self.auto_schema_extraction and not self.has_user_provided_schema(): connections.append( @@ -258,7 +286,7 @@ def _get_connections(self) -> list[ConnectionDefinition]: input_config={"text": "pdf_loader.text"}, ) ) - + connections.append( ConnectionDefinition( start="schema", @@ -279,7 +307,7 @@ def _get_connections(self) -> list[ConnectionDefinition]: input_config={"text": "text"}, # use the original text input ) ) - + connections.append( ConnectionDefinition( start="schema", diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 49b9da3b2..a1583e316 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -47,6 +47,7 @@ logger = logging.getLogger(__name__) + class SimpleKGPipeline: """ A class to simplify the process of building a knowledge graph from text documents. 
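Taken together, the precedence logic above gives `SimpleKGPipeline` three construction paths: automatic extraction when no schema input is supplied, the `schema` parameter when present, and the deprecated individual components otherwise. A minimal sketch of the first two paths — the connection URI, credentials, and model name below are placeholders, and a running Neo4j instance plus an OpenAI API key are assumed:

```python
import neo4j

from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.llm import OpenAILLM

driver = neo4j.GraphDatabase.driver(
    "neo4j://localhost:7687", auth=("neo4j", "password")
)
llm = OpenAILLM(
    model_name="gpt-4o",
    model_params={"response_format": {"type": "json_object"}, "temperature": 0},
)
embedder = OpenAIEmbeddings()

# No schema, entities, relations, or potential_schema supplied:
# auto_schema_extraction resolves to True and the schema component becomes
# SchemaFromText instead of the manual SchemaBuilder.
auto_pipeline = SimpleKGPipeline(
    llm=llm, driver=driver, embedder=embedder, from_pdf=False
)

# Both 'schema' and a deprecated individual argument supplied: 'schema' wins
# and a deprecation warning is logged for the individual components.
manual_pipeline = SimpleKGPipeline(
    llm=llm,
    driver=driver,
    embedder=embedder,
    from_pdf=False,
    schema={
        "entities": ["Person", "Organization"],
        "relations": ["WORKS_FOR"],
        "potential_schema": [("Person", "WORKS_FOR", "Organization")],
    },
    entities=["Person"],  # ignored: 'schema' takes precedence
)
```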
@@ -123,7 +124,9 @@ def __init__( perform_entity_resolution=perform_entity_resolution, lexical_graph_config=lexical_graph_config, neo4j_database=neo4j_database, - auto_schema_extraction=not bool(schema or entities or relations or potential_schema), + auto_schema_extraction=not bool( + schema or entities or relations or potential_schema + ), ) except (ValidationError, ValueError) as e: raise PipelineDefinitionError() from e diff --git a/src/neo4j_graphrag/generation/__init__.py b/src/neo4j_graphrag/generation/__init__.py index 359d57f96..816fe327a 100644 --- a/src/neo4j_graphrag/generation/__init__.py +++ b/src/neo4j_graphrag/generation/__init__.py @@ -1,9 +1,4 @@ from .graphrag import GraphRAG from .prompts import PromptTemplate, RagTemplate, SchemaExtractionTemplate -__all__ = [ - "GraphRAG", - "PromptTemplate", - "RagTemplate", - "SchemaExtractionTemplate" -] +__all__ = ["GraphRAG", "PromptTemplate", "RagTemplate", "SchemaExtractionTemplate"] diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index cda8d081f..81bf1f565 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -453,7 +453,7 @@ def mock_llm(): @pytest.fixture def valid_schema_json(): - return ''' + return """ { "entities": [ { @@ -481,12 +481,12 @@ def valid_schema_json(): ["Person", "WORKS_FOR", "Organization"] ] } - ''' + """ @pytest.fixture def invalid_schema_json(): - return ''' + return """ { "entities": [ { @@ -495,7 +495,7 @@ def invalid_schema_json(): ], invalid json content } - ''' + """ @pytest.fixture @@ -504,53 +504,57 @@ def schema_from_text(mock_llm): @pytest.mark.asyncio -async def test_schema_from_text_run_valid_response(schema_from_text, mock_llm, valid_schema_json): +async def test_schema_from_text_run_valid_response( + schema_from_text, mock_llm, valid_schema_json +): # configure the mock LLM to return a valid schema JSON mock_llm.invoke.return_value = valid_schema_json - + # run the schema extraction schema_config = await schema_from_text.run(text="Sample text for extraction") - + # verify the LLM was called with a prompt mock_llm.invoke.assert_called_once() prompt_arg = mock_llm.invoke.call_args[0][0] assert isinstance(prompt_arg, str) assert "Sample text for extraction" in prompt_arg - + # verify the schema was correctly extracted assert len(schema_config.entities) == 2 assert "Person" in schema_config.entities assert "Organization" in schema_config.entities - + assert schema_config.relations is not None assert "WORKS_FOR" in schema_config.relations - + assert schema_config.potential_schema is not None assert len(schema_config.potential_schema) == 1 assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") @pytest.mark.asyncio -async def test_schema_from_text_run_invalid_json(schema_from_text, mock_llm, invalid_schema_json): +async def test_schema_from_text_run_invalid_json( + schema_from_text, mock_llm, invalid_schema_json +): # configure the mock LLM to return invalid JSON mock_llm.invoke.return_value = invalid_schema_json - + # verify that running with invalid JSON raises a ValueError with pytest.raises(ValueError) as exc_info: await schema_from_text.run(text="Sample text for extraction") - + assert "not valid JSON" in str(exc_info.value) @pytest.mark.asyncio async def test_schema_from_text_custom_template(mock_llm, valid_schema_json): - # create a custom template + # create a custom template custom_prompt = "This is a custom prompt with text: 
{text}" custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"]) - + # create SchemaFromText with the custom template schema_from_text = SchemaFromText(llm=mock_llm, prompt_template=custom_template) - + # configure mock LLM to return valid JSON and capture the prompt that was sent to it mock_llm.invoke.return_value = valid_schema_json @@ -575,7 +579,7 @@ async def test_schema_from_text_llm_params(mock_llm, valid_schema_json): # run the schema extraction await schema_from_text.run(text="Sample text") - + # verify the LLM was called with the custom parameters mock_llm.invoke.assert_called_once() call_kwargs = mock_llm.invoke.call_args[1] @@ -588,22 +592,25 @@ async def test_schema_config_store_as_json(schema_config): with tempfile.TemporaryDirectory() as temp_dir: # create file path json_path = os.path.join(temp_dir, "schema.json") - + # store the schema config schema_config.store_as_json(json_path) - + # verify the file exists and has content assert os.path.exists(json_path) assert os.path.getsize(json_path) > 0 - + # verify the content is valid JSON and contains expected data - with open(json_path, 'r') as f: + with open(json_path, "r") as f: data = json.load(f) assert "entities" in data assert "PERSON" in data["entities"] assert "properties" in data["entities"]["PERSON"] assert "description" in data["entities"]["PERSON"] - assert data["entities"]["PERSON"]["description"] == "An individual human being." + assert ( + data["entities"]["PERSON"]["description"] + == "An individual human being." + ) @pytest.mark.asyncio @@ -611,22 +618,25 @@ async def test_schema_config_store_as_yaml(schema_config): with tempfile.TemporaryDirectory() as temp_dir: # Create file path yaml_path = os.path.join(temp_dir, "schema.yaml") - + # Store the schema config schema_config.store_as_yaml(yaml_path) - + # Verify the file exists and has content assert os.path.exists(yaml_path) assert os.path.getsize(yaml_path) > 0 - + # Verify the content is valid YAML and contains expected data - with open(yaml_path, 'r') as f: + with open(yaml_path, "r") as f: data = yaml.safe_load(f) assert "entities" in data assert "PERSON" in data["entities"] assert "properties" in data["entities"]["PERSON"] assert "description" in data["entities"]["PERSON"] - assert data["entities"]["PERSON"]["description"] == "An individual human being." + assert ( + data["entities"]["PERSON"]["description"] + == "An individual human being." 
+ ) @pytest.mark.asyncio @@ -636,30 +646,30 @@ async def test_schema_config_from_file(schema_config): json_path = os.path.join(temp_dir, "schema.json") yaml_path = os.path.join(temp_dir, "schema.yaml") yml_path = os.path.join(temp_dir, "schema.yml") - + # store the schema config in the different formats schema_config.store_as_json(json_path) schema_config.store_as_yaml(yaml_path) schema_config.store_as_yaml(yml_path) - + # load using from_file which should detect the format based on extension json_schema = SchemaConfig.from_file(json_path) yaml_schema = SchemaConfig.from_file(yaml_path) yml_schema = SchemaConfig.from_file(yml_path) - + # simple verification that the objects were loaded correctly assert isinstance(json_schema, SchemaConfig) assert isinstance(yaml_schema, SchemaConfig) assert isinstance(yml_schema, SchemaConfig) - + # verify basic structure is intact assert "entities" in json_schema.model_dump() assert "entities" in yaml_schema.model_dump() assert "entities" in yml_schema.model_dump() - + # verify an unsupported extension raises the correct error txt_path = os.path.join(temp_dir, "schema.txt") schema_config.store_as_json(txt_path) # Store as JSON but with .txt extension - + with pytest.raises(ValueError, match="Unsupported file format"): SchemaConfig.from_file(txt_path) From b19e57cc6594f8c78992ee25640d98f9eecf182b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 29 Apr 2025 10:54:23 +0300 Subject: [PATCH 12/36] Fix mypy issues --- .../experimental/components/schema.py | 20 ++++++------ .../template_pipeline/simple_kg_builder.py | 2 +- .../experimental/pipeline/kg_builder.py | 4 +-- .../experimental/components/test_schema.py | 32 +++++++++++-------- 4 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 2366df4ea..c39b21a05 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -20,7 +20,6 @@ from pathlib import Path from pydantic import BaseModel, ValidationError, model_validator, validate_call -from requests.exceptions import InvalidJSONError from typing_extensions import Self from neo4j_graphrag.exceptions import SchemaValidationError @@ -336,10 +335,10 @@ async def run( return self.create_schema_model(entities, relations, potential_schema) -class SchemaFromText(SchemaBuilder): +class SchemaFromText(Component): """ - A builder class for constructing SchemaConfig objects from the output of an LLM after - automatic schema extraction from text. + A component for constructing SchemaConfig objects from the output of an LLM after + automatic schema extraction from text. 
""" def __init__( @@ -348,7 +347,6 @@ def __init__( prompt_template: Optional[PromptTemplate] = None, llm_params: Optional[Dict[str, Any]] = None, ) -> None: - super().__init__() self._llm: LLMInterface = llm self._prompt_template: PromptTemplate = ( prompt_template or SchemaExtractionTemplate() @@ -369,7 +367,7 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi """ prompt: str = self._prompt_template.format(text=text, examples=examples) - response = await self._llm.invoke(prompt, **self._llm_params) + response = await self._llm.ainvoke(prompt, **self._llm_params) content: str = ( response if isinstance(response, str) @@ -381,8 +379,12 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi except json.JSONDecodeError as exc: raise ValueError("LLM response is not valid JSON.") from exc - extracted_entities: List[dict] = extracted_schema.get("entities", []) - extracted_relations: Optional[List[dict]] = extracted_schema.get("relations") + extracted_entities: List[Dict[str, Any]] = ( + extracted_schema.get("entities") or [] + ) + extracted_relations: Optional[List[Dict[str, Any]]] = extracted_schema.get( + "relations" + ) potential_schema: Optional[List[Tuple[str, str, str]]] = extracted_schema.get( "potential_schema" ) @@ -394,7 +396,7 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi else None ) - return await super().run( + return SchemaBuilder.create_schema_model( entities=entities, relations=relations, potential_schema=potential_schema, diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 47c3c180f..7ebbace26 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -81,7 +81,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None - schema: Optional[Union[SchemaConfig, dict[str, list]]] = None + schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index a1583e316..c3159a681 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, Any import logging import neo4j @@ -92,7 +92,7 @@ def __init__( entities: Optional[Sequence[EntityInputType]] = None, relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, - schema: Optional[Union[SchemaConfig, dict[str, list]]] = None, + schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None, enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 81bf1f565..50e47d9de 
100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -110,7 +110,7 @@ def schema_config( valid_entities: list[SchemaEntity], valid_relations: list[SchemaRelation], potential_schema: list[tuple[str, str, str]], -): +) -> SchemaConfig: return schema_builder.create_schema_model( valid_entities, valid_relations, potential_schema ) @@ -445,14 +445,14 @@ def test_create_schema_model_missing_relations( @pytest.fixture -def mock_llm(): +def mock_llm() -> AsyncMock: mock = AsyncMock() mock.invoke = AsyncMock() return mock @pytest.fixture -def valid_schema_json(): +def valid_schema_json() -> str: return """ { "entities": [ @@ -485,7 +485,7 @@ def valid_schema_json(): @pytest.fixture -def invalid_schema_json(): +def invalid_schema_json() -> str: return """ { "entities": [ @@ -499,14 +499,14 @@ def invalid_schema_json(): @pytest.fixture -def schema_from_text(mock_llm): +def schema_from_text(mock_llm: AsyncMock) -> SchemaFromText: return SchemaFromText(llm=mock_llm) @pytest.mark.asyncio async def test_schema_from_text_run_valid_response( - schema_from_text, mock_llm, valid_schema_json -): + schema_from_text: SchemaFromText, mock_llm: AsyncMock, valid_schema_json: str +) -> None: # configure the mock LLM to return a valid schema JSON mock_llm.invoke.return_value = valid_schema_json @@ -534,8 +534,8 @@ async def test_schema_from_text_run_valid_response( @pytest.mark.asyncio async def test_schema_from_text_run_invalid_json( - schema_from_text, mock_llm, invalid_schema_json -): + schema_from_text: SchemaFromText, mock_llm: AsyncMock, invalid_schema_json: str +) -> None: # configure the mock LLM to return invalid JSON mock_llm.invoke.return_value = invalid_schema_json @@ -547,7 +547,9 @@ async def test_schema_from_text_run_invalid_json( @pytest.mark.asyncio -async def test_schema_from_text_custom_template(mock_llm, valid_schema_json): +async def test_schema_from_text_custom_template( + mock_llm: AsyncMock, valid_schema_json: str +) -> None: # create a custom template custom_prompt = "This is a custom prompt with text: {text}" custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"]) @@ -567,7 +569,9 @@ async def test_schema_from_text_custom_template(mock_llm, valid_schema_json): @pytest.mark.asyncio -async def test_schema_from_text_llm_params(mock_llm, valid_schema_json): +async def test_schema_from_text_llm_params( + mock_llm: AsyncMock, valid_schema_json: str +) -> None: # configure custom LLM parameters llm_params = {"temperature": 0.1, "max_tokens": 500} @@ -588,7 +592,7 @@ async def test_schema_from_text_llm_params(mock_llm, valid_schema_json): @pytest.mark.asyncio -async def test_schema_config_store_as_json(schema_config): +async def test_schema_config_store_as_json(schema_config: SchemaConfig) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file path json_path = os.path.join(temp_dir, "schema.json") @@ -614,7 +618,7 @@ async def test_schema_config_store_as_json(schema_config): @pytest.mark.asyncio -async def test_schema_config_store_as_yaml(schema_config): +async def test_schema_config_store_as_yaml(schema_config: SchemaConfig) -> None: with tempfile.TemporaryDirectory() as temp_dir: # Create file path yaml_path = os.path.join(temp_dir, "schema.yaml") @@ -640,7 +644,7 @@ async def test_schema_config_store_as_yaml(schema_config): @pytest.mark.asyncio -async def test_schema_config_from_file(schema_config): +async def test_schema_config_from_file(schema_config: SchemaConfig) -> None: with 
tempfile.TemporaryDirectory() as temp_dir: # create file paths with different extensions json_path = os.path.join(temp_dir, "schema.json") From 4eebee563c510e755f2646aa5b4f03c61af1f38b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 29 Apr 2025 11:02:03 +0300 Subject: [PATCH 13/36] Ignore remaining mypy issues (temp) --- .../config/template_pipeline/simple_kg_builder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 7ebbace26..1a46829f2 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -81,7 +81,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None - schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None + schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None # type: ignore enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() @@ -97,10 +97,10 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): model_config = ConfigDict(arbitrary_types_allowed=True) @model_validator(mode="after") - def handle_schema_precedence(self) -> T: + def handle_schema_precedence(self) -> T: # type: ignore """Handle schema precedence and warnings""" self._process_schema_parameters() - return self + return self # type: ignore def _process_schema_parameters(self) -> None: """ @@ -192,12 +192,12 @@ def _process_schema_with_precedence( if isinstance(self.schema, SchemaConfig): # extract components from SchemaConfig entities = list(self.schema.entities.values()) - relations = list(self.schema.relations.values()) + relations = list(self.schema.relations.values()) # type: ignore potential_schema = self.schema.potential_schema else: # extract from dictionary entities = [ - SchemaEntity.from_text_or_dict(e) + SchemaEntity.from_text_or_dict(e) # type: ignore for e in self.schema.get("entities", []) ] relations = [ @@ -208,7 +208,7 @@ def _process_schema_with_precedence( else: # use individual components entities = ( - [SchemaEntity.from_text_or_dict(e) for e in self.entities] + [SchemaEntity.from_text_or_dict(e) for e in self.entities] # type: ignore if self.entities else [] ) @@ -219,7 +219,7 @@ def _process_schema_with_precedence( ) potential_schema = self.potential_schema - return entities, relations, potential_schema + return entities, relations, potential_schema # type: ignore def _get_run_params_for_schema(self) -> dict[str, Any]: if self.auto_schema_extraction and not self.has_user_provided_schema(): From 7088286df7b23804e0ca4a8eefabeccd5ab43d20 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 29 Apr 2025 11:42:24 +0300 Subject: [PATCH 14/36] Remove unused imports --- .../pipeline/config/template_pipeline/simple_kg_builder.py | 2 +- tests/unit/experimental/components/test_schema.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 1a46829f2..1638633a9 
100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -15,7 +15,7 @@ from typing import Any, ClassVar, Literal, Optional, Sequence, Union, TypeVar import logging -from pydantic import ConfigDict, model_validator, Field +from pydantic import ConfigDict, model_validator from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 50e47d9de..1ea9311ab 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -31,7 +31,6 @@ import os import tempfile import yaml -from pathlib import Path from neo4j_graphrag.generation import PromptTemplate From 9d05c76161cf59eecb8997a3383b4c821e9c8282 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 29 Apr 2025 12:22:50 +0300 Subject: [PATCH 15/36] Fix unit tests --- .../experimental/components/test_schema.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 1ea9311ab..7fdf97274 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -446,7 +446,7 @@ def test_create_schema_model_missing_relations( @pytest.fixture def mock_llm() -> AsyncMock: mock = AsyncMock() - mock.invoke = AsyncMock() + mock.ainvoke = AsyncMock() return mock @@ -507,14 +507,14 @@ async def test_schema_from_text_run_valid_response( schema_from_text: SchemaFromText, mock_llm: AsyncMock, valid_schema_json: str ) -> None: # configure the mock LLM to return a valid schema JSON - mock_llm.invoke.return_value = valid_schema_json + mock_llm.ainvoke.return_value = valid_schema_json # run the schema extraction schema_config = await schema_from_text.run(text="Sample text for extraction") # verify the LLM was called with a prompt - mock_llm.invoke.assert_called_once() - prompt_arg = mock_llm.invoke.call_args[0][0] + mock_llm.ainvoke.assert_called_once() + prompt_arg = mock_llm.ainvoke.call_args[0][0] assert isinstance(prompt_arg, str) assert "Sample text for extraction" in prompt_arg @@ -536,7 +536,7 @@ async def test_schema_from_text_run_invalid_json( schema_from_text: SchemaFromText, mock_llm: AsyncMock, invalid_schema_json: str ) -> None: # configure the mock LLM to return invalid JSON - mock_llm.invoke.return_value = invalid_schema_json + mock_llm.ainvoke.return_value = invalid_schema_json # verify that running with invalid JSON raises a ValueError with pytest.raises(ValueError) as exc_info: @@ -557,13 +557,13 @@ async def test_schema_from_text_custom_template( schema_from_text = SchemaFromText(llm=mock_llm, prompt_template=custom_template) # configure mock LLM to return valid JSON and capture the prompt that was sent to it - mock_llm.invoke.return_value = valid_schema_json + mock_llm.ainvoke.return_value = valid_schema_json # run the schema extraction await schema_from_text.run(text="Sample text") # verify the custom prompt was passed to the LLM - prompt_sent_to_llm = mock_llm.invoke.call_args[0][0] + prompt_sent_to_llm = mock_llm.ainvoke.call_args[0][0] assert "This is a custom prompt with text" in prompt_sent_to_llm @@ -578,14 +578,14 @@ async def 
test_schema_from_text_llm_params( schema_from_text = SchemaFromText(llm=mock_llm, llm_params=llm_params) # configure the mock LLM to return a valid schema JSON - mock_llm.invoke.return_value = valid_schema_json + mock_llm.ainvoke.return_value = valid_schema_json # run the schema extraction await schema_from_text.run(text="Sample text") # verify the LLM was called with the custom parameters - mock_llm.invoke.assert_called_once() - call_kwargs = mock_llm.invoke.call_args[1] + mock_llm.ainvoke.assert_called_once() + call_kwargs = mock_llm.ainvoke.call_args[1] assert call_kwargs["temperature"] == 0.1 assert call_kwargs["max_tokens"] == 500 From f9a7c8cf18267e7eb17133861acb5e628ebd2079 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 29 Apr 2025 13:04:06 +0300 Subject: [PATCH 16/36] Fix component connections --- .../config/template_pipeline/simple_kg_builder.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 1638633a9..228cce8ef 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -298,16 +298,6 @@ def _get_connections(self) -> list[ConnectionDefinition]: ) ) else: - # handle automatic schema extraction for direct text input: ensure schema extraction uses the complete text - if self.auto_schema_extraction and not self.has_user_provided_schema(): - connections.append( - ConnectionDefinition( - start="__input__", # connection to pipeline input - end="schema", - input_config={"text": "text"}, # use the original text input - ) - ) - connections.append( ConnectionDefinition( start="schema", @@ -379,4 +369,7 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: "Expected 'text' argument when 'from_pdf' is False." ) run_params["splitter"] = {"text": text} + # Add full text to schema component for automatic schema extraction + if self.auto_schema_extraction and not self.has_user_provided_schema(): + run_params["schema"] = {"text": text} return run_params From 8458b75b877d61c796795e2f51fc45010e5c360b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 29 Apr 2025 17:23:14 +0300 Subject: [PATCH 17/36] Improve default schema extraction prompt and add examples --- examples/README.md | 2 + .../schema_from_text.py | 130 ++++++++++++++++++ src/neo4j_graphrag/generation/prompts.py | 58 +++++--- 3 files changed, 173 insertions(+), 17 deletions(-) create mode 100644 examples/automatic_schema_extraction/schema_from_text.py diff --git a/examples/README.md b/examples/README.md index 7feb71f3a..fa8bb945e 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,6 +3,7 @@ This folder contains examples usage for the different features supported by the `neo4j-graphrag` package: +- [Automatic Schema Extraction](#schema-extraction) from PDF or text - [Build Knowledge Graph](#build-knowledge-graph) from PDF or text - [Retrieve](#retrieve) information from the graph - [Question Answering](#answer-graphrag) (Q&A) @@ -122,6 +123,7 @@ are listed in [the last section of this file](#customize). 
- [Chunk embedder]() - Schema Builder: - [User-defined](./customize/build_graph/components/schema_builders/schema.py) + - [Automatic schema extraction](./automatic_schema_extraction/schema_from_text.py) - Entity Relation Extractor: - [LLM-based](./customize/build_graph/components/extractors/llm_entity_relation_extractor.py) - [LLM-based with custom prompt](./customize/build_graph/components/extractors/llm_entity_relation_extractor_with_custom_prompt.py) diff --git a/examples/automatic_schema_extraction/schema_from_text.py b/examples/automatic_schema_extraction/schema_from_text.py new file mode 100644 index 000000000..65979aa5e --- /dev/null +++ b/examples/automatic_schema_extraction/schema_from_text.py @@ -0,0 +1,130 @@ +"""This example demonstrates how to use the SchemaFromText component +to automatically extract a schema from text and save it to JSON and YAML files. + +The SchemaFromText component uses an LLM to analyze the text and identify entities, +relations, and their properties. + +Note: This example requires an OpenAI API key to be set in the .env file. +""" + +import asyncio +import logging +import os +from dotenv import load_dotenv + +from neo4j_graphrag.experimental.components.schema import SchemaFromText, SchemaConfig +from neo4j_graphrag.llm import OpenAILLM + +# Load environment variables from .env file +load_dotenv() + +# Configure logging +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.INFO) + +# Verify OpenAI API key is available +if not os.getenv("OPENAI_API_KEY"): + raise ValueError( + "OPENAI_API_KEY environment variable not found. " + "Please set it in the .env file in the root directory." + ) + +# Sample text to extract schema from - it's about a company and its employees +TEXT = """ +Acme Corporation was founded in 1985 by John Smith in New York City. +The company specializes in manufacturing high-quality widgets and gadgets +for the consumer electronics industry. + +Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to +Engineering Director in 2015. She oversees a team of 12 engineers working on +next-generation products. Sarah holds a PhD in Electrical Engineering from MIT +and has filed 5 patents during her time at Acme. + +The company expanded to international markets in 2012, opening offices in London, +Tokyo, and Berlin. Each office is managed by a regional director who reports +directly to the CEO, Michael Brown, who took over leadership in 2008. + +Acme's most successful product, the SuperWidget X1, was launched in 2018 and +has sold over 2 million units worldwide. The product was developed by a team led +by Robert Chen, who joined the company in 2016 after working at TechGiant for 8 years. + +The company currently employs 250 people across its 4 locations and had a revenue +of $75 million in the last fiscal year. Acme is planning to go public in 2024 +with an estimated valuation of $500 million. 
+""" + +# Define the file paths for saving the schema +OUTPUT_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data") +JSON_FILE_PATH = os.path.join(OUTPUT_DIR, "extracted_schema.json") +YAML_FILE_PATH = os.path.join(OUTPUT_DIR, "extracted_schema.yaml") + + +async def extract_and_save_schema() -> SchemaConfig: + """Extract schema from text and save it to JSON and YAML files.""" + + # Define LLM parameters + llm_model_params = { + "max_tokens": 2000, + "response_format": {"type": "json_object"}, + "temperature": 0, # Lower temperature for more consistent output + } + + # Create the LLM instance + llm = OpenAILLM( + model_name="gpt-4o", + model_params=llm_model_params, + ) + + try: + # Create a SchemaFromText component with the default template + schema_extractor = SchemaFromText(llm=llm) + + print("Extracting schema from text...") + # Extract schema from text + inferred_schema = await schema_extractor.run(text=TEXT) + + # Ensure the output directory exists + os.makedirs(OUTPUT_DIR, exist_ok=True) + + print(f"Saving schema to JSON file: {JSON_FILE_PATH}") + # Save the schema to JSON file + inferred_schema.store_as_json(JSON_FILE_PATH) + + print(f"Saving schema to YAML file: {YAML_FILE_PATH}") + # Save the schema to YAML file + inferred_schema.store_as_yaml(YAML_FILE_PATH) + + print("\nExtracted Schema Summary:") + print(f"Entities: {list(inferred_schema.entities.keys())}") + print(f"Relations: {list(inferred_schema.relations.keys() if inferred_schema.relations else [])}") + + if inferred_schema.potential_schema: + print("\nPotential Schema:") + for entity1, relation, entity2 in inferred_schema.potential_schema: + print(f" {entity1} --[{relation}]--> {entity2}") + + return inferred_schema + + finally: + # Close the LLM client + await llm.async_client.close() + + +async def main() -> None: + """Run the example.""" + + # Extract schema and save to files + schema_config = await extract_and_save_schema() + + print(f"\nSchema files have been saved to:") + print(f" - JSON: {JSON_FILE_PATH}") + print(f" - YAML: {YAML_FILE_PATH}") + + print("\nExample of how to load the schema from files:") + print(" from neo4j_graphrag.experimental.components.schema import SchemaConfig") + print(f" schema_from_json = SchemaConfig.from_file('{JSON_FILE_PATH}')") + print(f" schema_from_yaml = SchemaConfig.from_file('{YAML_FILE_PATH}')") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 6ed72463b..96bcaf8de 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -207,24 +207,48 @@ class SchemaExtractionTemplate(PromptTemplate): You are a top-tier algorithm designed for extracting a labeled property graph schema in structured formats. -Generate the generalized graph schema based on input text. Identify key entity types, -their relationship types, and property types whenever it is possible. Return only -abstract schema information, no concrete instances. Use singular PascalCase labels for -entity types and UPPER_SNAKE_CASE for relationship types. Include property definitions -only when the type can be confidently inferred, otherwise omit the properties. +Generate a generalized graph schema based on the input text. Identify key entity types, +their relationship types, and property types. + +IMPORTANT RULES: +1. Return only abstract schema information, not concrete instances. +2. 
Use singular PascalCase labels for entity types (e.g., Person, Company, Product). +3. Use UPPER_SNAKE_CASE for relationship types (e.g., WORKS_FOR, MANAGES). +4. Include property definitions only when the type can be confidently inferred, otherwise omit them. +5. When defining potential_schema, ensure that every entity and relation mentioned exists in your entities and relations lists. +6. Do not create entity types that aren't clearly mentioned in the text. +7. Keep your schema minimal and focused on clearly identifiable patterns in the text. + Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST, -LOCAL DATETIME, LOCAL TIME, POINT, STRING, ZONED DATETIME, ZONED TIME. -Do not add extra keys or explanatory text. Return a valid JSON object without -back‑ticks, markdown, or comments. - -For example, if the text says "Alice lives in London", the output JSON object should -adhere to the following format: -{{"entities": [{{"label": "Person", "properties": [{{"name": "name", "type": "STRING"}}]}}, -{{"label": "City", "properties":[{{"name": "name", "type": "STRING"}}]}}], -"relations": [{{"label": "LIVES_IN"}}], -"potential_schema":[[ "Person", "LIVES_IN", "City"]]}} - -More examples: +LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME. + +Return a valid JSON object that follows this precise structure: +{{ + "entities": [ + {{ + "label": "Person", + "properties": [ + {{ + "name": "name", + "type": "STRING" + }} + ] + }}, + ... + ], + "relations": [ + {{ + "label": "WORKS_FOR" + }}, + ... + ], + "potential_schema": [ + ["Person", "WORKS_FOR", "Company"], + ... + ] +}} + +Examples: {examples} Input text: From 7558b562c4c3f337a3a9eb8b718783c076c9ec98 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 30 Apr 2025 15:21:14 +0300 Subject: [PATCH 18/36] Rename schema from text component --- docs/source/api.rst | 6 +-- docs/source/user_guide_kg_builder.rst | 22 +++----- .../schema_from_text.py | 51 +++++++++++-------- .../experimental/components/schema.py | 2 +- .../template_pipeline/simple_kg_builder.py | 10 ++-- .../experimental/components/test_schema.py | 24 +++++---- 6 files changed, 60 insertions(+), 55 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 2567411af..55a5d1cc4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -77,10 +77,10 @@ SchemaBuilder .. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaBuilder :members: run -SchemaFromText -============= +SchemaFromTextExtractor +----------------------- -.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaFromText +.. autoclass:: neo4j_graphrag.experimental.components.schema.SchemaFromTextExtractor :members: run EntityRelationExtractor diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 3c4716176..a191600b8 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -139,7 +139,7 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated ) .. note:: - By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (See the :ref:`Automatic Schema Extraction with SchemaFromText` section. + By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (See the :ref:`Automatic Schema Extraction with SchemaFromTextExtractor`). 
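The JSON contract defined by the refreshed prompt maps directly onto a `SchemaConfig`. Below is a sketch of that conversion, mirroring what `SchemaFromTextExtractor.run` does after `json.loads`; the raw payload is hand-written for illustration, and `create_schema_model` is invoked at class level exactly as the component's own code does:

```python
import json

from neo4j_graphrag.experimental.components.schema import (
    SchemaBuilder,
    SchemaEntity,
    SchemaRelation,
)

# A response in the exact shape the extraction template requests.
raw_llm_output = (
    '{"entities": [{"label": "Person", "properties": '
    '[{"name": "name", "type": "STRING"}]}, '
    '{"label": "Company", "properties": []}], '
    '"relations": [{"label": "WORKS_FOR"}], '
    '"potential_schema": [["Person", "WORKS_FOR", "Company"]]}'
)

data = json.loads(raw_llm_output)
schema_config = SchemaBuilder.create_schema_model(
    entities=[SchemaEntity(**e) for e in data.get("entities") or []],
    relations=[SchemaRelation(**r) for r in data.get("relations") or []],
    # JSON arrays arrive as lists; SchemaConfig stores triples as tuples.
    potential_schema=[tuple(t) for t in data.get("potential_schema") or []],
)
print(list(schema_config.entities))  # ['Person', 'Company']
```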
Extra configurations -------------------- @@ -817,19 +817,18 @@ Here is a code block illustrating these concepts: After validation, this schema is saved in a `SchemaConfig` object, whose dict representation is passed to the LLM. -Automatic Schema Extraction with SchemaFromText ----------------------------------------------- -.. _automatic-schema-extraction: +Automatic Schema Extraction +--------------------------- -Instead of manually defining the schema, you can use the `SchemaFromText` component to automatically extract a schema from your text using an LLM: +Instead of manually defining the schema, you can use the `SchemaFromTextExtractor` component to automatically extract a schema from your text using an LLM: .. code:: python - from neo4j_graphrag.experimental.components.schema import SchemaFromText + from neo4j_graphrag.experimental.components.schema import SchemaFromTextExtractor from neo4j_graphrag.llm import OpenAILLM # Create the automatic schema extractor - schema_extractor = SchemaFromText( + schema_extractor = SchemaFromTextExtractor( llm=OpenAILLM( model_name="gpt-4o", model_params={ @@ -839,14 +838,7 @@ Instead of manually defining the schema, you can use the `SchemaFromText` compon ) ) - # Extract schema from text - schema_config = await schema_extractor.run(text="Your document text here...") - - # Use the extracted schema with other components - extractor = LLMEntityRelationExtractor(llm=llm) - result = await extractor.run(chunks=chunks, schema=schema_config) - -The `SchemaFromText` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `SchemaConfig` object that can be used in the same way as a manually defined schema. +The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `SchemaConfig` object that can be used in the same way as a manually defined schema. You can also save and reload the extracted schema: diff --git a/examples/automatic_schema_extraction/schema_from_text.py b/examples/automatic_schema_extraction/schema_from_text.py index 65979aa5e..2fe0e8e7d 100644 --- a/examples/automatic_schema_extraction/schema_from_text.py +++ b/examples/automatic_schema_extraction/schema_from_text.py @@ -1,7 +1,7 @@ -"""This example demonstrates how to use the SchemaFromText component +"""This example demonstrates how to use the SchemaFromTextExtractor component to automatically extract a schema from text and save it to JSON and YAML files. -The SchemaFromText component uses an LLM to analyze the text and identify entities, +The SchemaFromTextExtractor component uses an LLM to analyze the text and identify entities, relations, and their properties. Note: This example requires an OpenAI API key to be set in the .env file. 
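The save-and-reload flow referenced in the guide is a plain round trip through `store_as_json`, `store_as_yaml`, and `SchemaConfig.from_file`. A short sketch — the one-entity schema below is a stand-in for output normally produced by `SchemaFromTextExtractor`:

```python
from neo4j_graphrag.experimental.components.schema import (
    SchemaBuilder,
    SchemaConfig,
    SchemaEntity,
)

# A tiny stand-in for a schema produced by automatic extraction.
schema_config = SchemaBuilder.create_schema_model(
    entities=[SchemaEntity(label="Person")],
    relations=[],
    potential_schema=None,
)

# Persist the schema in both supported formats.
schema_config.store_as_json("extracted_schema.json")
schema_config.store_as_yaml("extracted_schema.yaml")

# from_file dispatches on the file extension (.json, .yaml, .yml);
# any other suffix raises ValueError("Unsupported file format: ...").
restored = SchemaConfig.from_file("extracted_schema.yaml")
assert restored.entities == schema_config.entities
```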
@@ -12,7 +12,10 @@ import os from dotenv import load_dotenv -from neo4j_graphrag.experimental.components.schema import SchemaFromText, SchemaConfig +from neo4j_graphrag.experimental.components.schema import ( + SchemaFromTextExtractor, + SchemaConfig, +) from neo4j_graphrag.llm import OpenAILLM # Load environment variables from .env file @@ -54,57 +57,61 @@ """ # Define the file paths for saving the schema -OUTPUT_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data") +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data" +) JSON_FILE_PATH = os.path.join(OUTPUT_DIR, "extracted_schema.json") YAML_FILE_PATH = os.path.join(OUTPUT_DIR, "extracted_schema.yaml") async def extract_and_save_schema() -> SchemaConfig: """Extract schema from text and save it to JSON and YAML files.""" - + # Define LLM parameters llm_model_params = { "max_tokens": 2000, "response_format": {"type": "json_object"}, "temperature": 0, # Lower temperature for more consistent output } - + # Create the LLM instance llm = OpenAILLM( model_name="gpt-4o", model_params=llm_model_params, ) - + try: - # Create a SchemaFromText component with the default template - schema_extractor = SchemaFromText(llm=llm) - + # Create a SchemaFromTextExtractor component with the default template + schema_extractor = SchemaFromTextExtractor(llm=llm) + print("Extracting schema from text...") # Extract schema from text inferred_schema = await schema_extractor.run(text=TEXT) - + # Ensure the output directory exists os.makedirs(OUTPUT_DIR, exist_ok=True) - + print(f"Saving schema to JSON file: {JSON_FILE_PATH}") # Save the schema to JSON file inferred_schema.store_as_json(JSON_FILE_PATH) - + print(f"Saving schema to YAML file: {YAML_FILE_PATH}") # Save the schema to YAML file inferred_schema.store_as_yaml(YAML_FILE_PATH) - + print("\nExtracted Schema Summary:") print(f"Entities: {list(inferred_schema.entities.keys())}") - print(f"Relations: {list(inferred_schema.relations.keys() if inferred_schema.relations else [])}") - + print( + f"Relations: {list(inferred_schema.relations.keys() if inferred_schema.relations else [])}" + ) + if inferred_schema.potential_schema: print("\nPotential Schema:") for entity1, relation, entity2 in inferred_schema.potential_schema: print(f" {entity1} --[{relation}]--> {entity2}") - + return inferred_schema - + finally: # Close the LLM client await llm.async_client.close() @@ -112,14 +119,14 @@ async def extract_and_save_schema() -> SchemaConfig: async def main() -> None: """Run the example.""" - + # Extract schema and save to files schema_config = await extract_and_save_schema() - + print(f"\nSchema files have been saved to:") print(f" - JSON: {JSON_FILE_PATH}") print(f" - YAML: {YAML_FILE_PATH}") - + print("\nExample of how to load the schema from files:") print(" from neo4j_graphrag.experimental.components.schema import SchemaConfig") print(f" schema_from_json = SchemaConfig.from_file('{JSON_FILE_PATH}')") @@ -127,4 +134,4 @@ async def main() -> None: if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index c39b21a05..5b789564b 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -335,7 +335,7 @@ async def run( return self.create_schema_model(entities, relations, potential_schema) -class 
SchemaFromText(Component): +class SchemaFromTextExtractor(Component): """ A component for constructing SchemaConfig objects from the output of an LLM after automatic schema extraction from text. diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 228cce8ef..a065bb444 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -31,10 +31,10 @@ ) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, + SchemaConfig, SchemaEntity, SchemaRelation, - SchemaFromText, - SchemaConfig, + SchemaFromTextExtractor, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -164,13 +164,13 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]: def _get_chunk_embedder(self) -> TextChunkEmbedder: return TextChunkEmbedder(embedder=self.get_default_embedder()) - def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]: + def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: """ Get the appropriate schema component based on configuration. - Return SchemaFromText for automatic extraction or SchemaBuilder for manual schema. + Return SchemaFromTextExtractor for automatic extraction or SchemaBuilder for manual schema. """ if self.auto_schema_extraction and not self.has_user_provided_schema(): - return SchemaFromText(llm=self.get_default_llm()) + return SchemaFromTextExtractor(llm=self.get_default_llm()) return SchemaBuilder() def _process_schema_with_precedence( diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 7fdf97274..1ac91f154 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -24,7 +24,7 @@ SchemaEntity, SchemaProperty, SchemaRelation, - SchemaFromText, + SchemaFromTextExtractor, SchemaConfig, ) from pydantic import ValidationError @@ -498,13 +498,15 @@ def invalid_schema_json() -> str: @pytest.fixture -def schema_from_text(mock_llm: AsyncMock) -> SchemaFromText: - return SchemaFromText(llm=mock_llm) +def schema_from_text(mock_llm: AsyncMock) -> SchemaFromTextExtractor: + return SchemaFromTextExtractor(llm=mock_llm) @pytest.mark.asyncio async def test_schema_from_text_run_valid_response( - schema_from_text: SchemaFromText, mock_llm: AsyncMock, valid_schema_json: str + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + valid_schema_json: str, ) -> None: # configure the mock LLM to return a valid schema JSON mock_llm.ainvoke.return_value = valid_schema_json @@ -533,7 +535,9 @@ async def test_schema_from_text_run_valid_response( @pytest.mark.asyncio async def test_schema_from_text_run_invalid_json( - schema_from_text: SchemaFromText, mock_llm: AsyncMock, invalid_schema_json: str + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + invalid_schema_json: str, ) -> None: # configure the mock LLM to return invalid JSON mock_llm.ainvoke.return_value = invalid_schema_json @@ -553,8 +557,10 @@ async def test_schema_from_text_custom_template( custom_prompt = "This is a custom prompt with text: {text}" custom_template = PromptTemplate(template=custom_prompt, expected_inputs=["text"]) - # create 
SchemaFromText with the custom template - schema_from_text = SchemaFromText(llm=mock_llm, prompt_template=custom_template) + # create SchemaFromTextExtractor with the custom template + schema_from_text = SchemaFromTextExtractor( + llm=mock_llm, prompt_template=custom_template + ) # configure mock LLM to return valid JSON and capture the prompt that was sent to it mock_llm.ainvoke.return_value = valid_schema_json @@ -574,8 +580,8 @@ async def test_schema_from_text_llm_params( # configure custom LLM parameters llm_params = {"temperature": 0.1, "max_tokens": 500} - # create SchemaFromText with custom LLM parameters - schema_from_text = SchemaFromText(llm=mock_llm, llm_params=llm_params) + # create SchemaFromTextExtractor with custom LLM parameters + schema_from_text = SchemaFromTextExtractor(llm=mock_llm, llm_params=llm_params) # configure the mock LLM to return a valid schema JSON mock_llm.ainvoke.return_value = valid_schema_json From 8885e2ccbd2d0faf207d85a8d566f2cea16eeb1c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 5 May 2025 12:57:28 +0200 Subject: [PATCH 19/36] Fix remaining mypy errors --- .../template_pipeline/simple_kg_builder.py | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index a065bb444..348ad7d0e 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
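All of the renamed tests above stub the LLM at the `ainvoke` boundary rather than patching HTTP calls. Condensed into a single standalone test — the JSON payload is hand-written for illustration:

```python
from unittest.mock import AsyncMock

import pytest

from neo4j_graphrag.experimental.components.schema import SchemaFromTextExtractor


@pytest.mark.asyncio
async def test_schema_extraction_with_stubbed_llm() -> None:
    llm = AsyncMock()
    # The component awaits LLMInterface.ainvoke, so that is what gets stubbed.
    llm.ainvoke = AsyncMock(
        return_value='{"entities": [{"label": "Person", "properties": []}], '
        '"relations": [{"label": "KNOWS"}], '
        '"potential_schema": [["Person", "KNOWS", "Person"]]}'
    )

    schema = await SchemaFromTextExtractor(llm=llm).run(text="Alice knows Bob.")

    assert "Person" in schema.entities
    assert schema.relations is not None and "KNOWS" in schema.relations
```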
-from typing import Any, ClassVar, Literal, Optional, Sequence, Union, TypeVar +from typing import Any, ClassVar, Literal, Optional, Sequence, Union, List, Tuple import logging -from pydantic import ConfigDict, model_validator +from pydantic import ConfigDict, Field, model_validator +from typing_extensions import Self from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -59,8 +60,6 @@ logger = logging.getLogger(__name__) -T = TypeVar("T", bound="SimpleKGPipelineConfig") - class SimpleKGPipelineConfig(TemplatePipelineConfig): COMPONENTS: ClassVar[list[str]] = [ @@ -81,7 +80,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None - schema: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = None # type: ignore + schema_: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = Field(default=None, alias="schema") enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() @@ -97,10 +96,10 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): model_config = ConfigDict(arbitrary_types_allowed=True) @model_validator(mode="after") - def handle_schema_precedence(self) -> T: # type: ignore + def handle_schema_precedence(self) -> Self: """Handle schema precedence and warnings""" self._process_schema_parameters() - return self # type: ignore + return self def _process_schema_parameters(self) -> None: """ @@ -112,7 +111,7 @@ def _process_schema_parameters(self) -> None: [self.entities, self.relations, self.potential_schema] ) - if has_individual_schema_components and self.schema is not None: + if has_individual_schema_components and self.schema_ is not None: logger.warning( "Both 'schema' and individual schema components (entities, relations, potential_schema) " "were provided. The 'schema' parameter takes precedence. 
In the future, individual " @@ -134,7 +133,7 @@ def has_user_provided_schema(self) -> bool: self.entities or self.relations or self.potential_schema - or self.schema is not None + or self.schema_ is not None ) def _get_pdf_loader(self) -> Optional[PdfLoader]: @@ -175,8 +174,8 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: def _process_schema_with_precedence( self, - ) -> tuple[ - list[SchemaEntity], list[SchemaRelation], Optional[list[tuple[str, str, str]]] + ) -> Tuple[ + List[SchemaEntity], List[SchemaRelation], Optional[List[Tuple[str, str, str]]] ]: """ Process schema inputs according to precedence rules: @@ -187,28 +186,37 @@ def _process_schema_with_precedence( Returns: Tuple of (entities, relations, potential_schema) """ - if self.schema is not None: + if self.schema_ is not None: # schema takes precedence over individual components - if isinstance(self.schema, SchemaConfig): + if isinstance(self.schema_, SchemaConfig): # extract components from SchemaConfig - entities = list(self.schema.entities.values()) - relations = list(self.schema.relations.values()) # type: ignore - potential_schema = self.schema.potential_schema + entity_dicts = list(self.schema_.entities.values()) + # convert dict values to SchemaEntity objects + entities = [SchemaEntity.model_validate(e) for e in entity_dicts] + + # handle case where relations could be None + if self.schema_.relations is not None: + relation_dicts = list(self.schema_.relations.values()) + relations = [SchemaRelation.model_validate(r) for r in relation_dicts] + else: + relations = [] + + potential_schema = self.schema_.potential_schema else: # extract from dictionary entities = [ - SchemaEntity.from_text_or_dict(e) # type: ignore - for e in self.schema.get("entities", []) + SchemaEntity.from_text_or_dict(e) + for e in self.schema_.get("entities", []) ] relations = [ SchemaRelation.from_text_or_dict(r) - for r in self.schema.get("relations", []) + for r in self.schema_.get("relations", []) ] - potential_schema = self.schema.get("potential_schema") + potential_schema = self.schema_.get("potential_schema") else: # use individual components entities = ( - [SchemaEntity.from_text_or_dict(e) for e in self.entities] # type: ignore + [SchemaEntity.from_text_or_dict(e) for e in self.entities] if self.entities else [] ) @@ -219,7 +227,7 @@ def _process_schema_with_precedence( ) potential_schema = self.potential_schema - return entities, relations, potential_schema # type: ignore + return entities, relations, potential_schema def _get_run_params_for_schema(self) -> dict[str, Any]: if self.auto_schema_extraction and not self.has_user_provided_schema(): From 78633c6678e60462a3f12ec3270b871902c0fcf2 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 5 May 2025 13:24:48 +0200 Subject: [PATCH 20/36] Improve schema from text example --- .../schema_from_text.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/automatic_schema_extraction/schema_from_text.py b/examples/automatic_schema_extraction/schema_from_text.py index 2fe0e8e7d..f08e17c34 100644 --- a/examples/automatic_schema_extraction/schema_from_text.py +++ b/examples/automatic_schema_extraction/schema_from_text.py @@ -64,7 +64,7 @@ YAML_FILE_PATH = os.path.join(OUTPUT_DIR, "extracted_schema.yaml") -async def extract_and_save_schema() -> SchemaConfig: +async def extract_and_save_schema() -> None: """Extract schema from text and save it to JSON and YAML files.""" # Define LLM parameters @@ -110,8 +110,6 @@ async def 
extract_and_save_schema() -> SchemaConfig: for entity1, relation, entity2 in inferred_schema.potential_schema: print(f" {entity1} --[{relation}]--> {entity2}") - return inferred_schema - finally: # Close the LLM client await llm.async_client.close() @@ -120,17 +118,20 @@ async def extract_and_save_schema() -> SchemaConfig: async def main() -> None: """Run the example.""" - # Extract schema and save to files - schema_config = await extract_and_save_schema() + # extract schema and save to files + await extract_and_save_schema() - print(f"\nSchema files have been saved to:") + print("\nSchema files have been saved to:") print(f" - JSON: {JSON_FILE_PATH}") print(f" - YAML: {YAML_FILE_PATH}") - print("\nExample of how to load the schema from files:") - print(" from neo4j_graphrag.experimental.components.schema import SchemaConfig") - print(f" schema_from_json = SchemaConfig.from_file('{JSON_FILE_PATH}')") - print(f" schema_from_yaml = SchemaConfig.from_file('{YAML_FILE_PATH}')") + # load schema from files + print("\nLoading schemas from saved files:") + schema_from_json = SchemaConfig.from_file(JSON_FILE_PATH) + schema_from_yaml = SchemaConfig.from_file(YAML_FILE_PATH) + + print(f"Entities in JSON schema: {list(schema_from_json.entities.keys())}") + print(f"Entities in YAML schema: {list(schema_from_yaml.entities.keys())}") if __name__ == "__main__": From fef2e495b88e1d42bf778519ab817ffdaee9abdf Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 5 May 2025 15:07:59 +0200 Subject: [PATCH 21/36] Ruff --- .../automatic_schema_extraction/schema_from_text.py | 2 +- .../config/template_pipeline/simple_kg_builder.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/automatic_schema_extraction/schema_from_text.py b/examples/automatic_schema_extraction/schema_from_text.py index f08e17c34..b20431fd2 100644 --- a/examples/automatic_schema_extraction/schema_from_text.py +++ b/examples/automatic_schema_extraction/schema_from_text.py @@ -129,7 +129,7 @@ async def main() -> None: print("\nLoading schemas from saved files:") schema_from_json = SchemaConfig.from_file(JSON_FILE_PATH) schema_from_yaml = SchemaConfig.from_file(YAML_FILE_PATH) - + print(f"Entities in JSON schema: {list(schema_from_json.entities.keys())}") print(f"Entities in YAML schema: {list(schema_from_yaml.entities.keys())}") diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 348ad7d0e..065814ebb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -80,7 +80,9 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None - schema_: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = Field(default=None, alias="schema") + schema_: Optional[Union[SchemaConfig, dict[str, list[Any]]]] = Field( + default=None, alias="schema" + ) enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() @@ -193,14 +195,16 @@ def _process_schema_with_precedence( entity_dicts = list(self.schema_.entities.values()) # convert dict values to SchemaEntity objects entities = 
[SchemaEntity.model_validate(e) for e in entity_dicts] - + # handle case where relations could be None if self.schema_.relations is not None: relation_dicts = list(self.schema_.relations.values()) - relations = [SchemaRelation.model_validate(r) for r in relation_dicts] + relations = [ + SchemaRelation.model_validate(r) for r in relation_dicts + ] else: relations = [] - + potential_schema = self.schema_.potential_schema else: # extract from dictionary From b412a05bb32685df34db958897d7aa5abdcdadb8 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 5 May 2025 15:48:19 +0200 Subject: [PATCH 22/36] Remove flag for automatic schema extraction --- CHANGELOG.md | 3 +-- docs/source/user_guide_kg_builder.rst | 4 ---- .../config/template_pipeline/simple_kg_builder.py | 9 ++++----- src/neo4j_graphrag/experimental/pipeline/kg_builder.py | 3 --- 4 files changed, 5 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6d17340a..b10242e87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,8 +4,7 @@ ### Added -- Added support for automatic schema extraction from text using LLMs. - +- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default. ## 1.7.0 ### Added diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index a191600b8..c263d5083 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -466,8 +466,6 @@ within the configuration file. ["House", "RULES", "Planet"] ] }, - /* Control automatic schema extraction */ - "auto_schema_extraction": false, "lexical_graph_config": { "chunk_node_label": "TextPart" } @@ -511,8 +509,6 @@ or in YAML: - ["Person", "PARENT_OF", "Person"] - ["Person", "HEIR_OF", "House"] - ["House", "RULES", "Planet"] - # Control automatic schema extraction - auto_schema_extraction: false lexical_graph_config: chunk_node_label: TextPart diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 065814ebb..a5006d89d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -89,7 +89,6 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): perform_entity_resolution: bool = True lexical_graph_config: Optional[LexicalGraphConfig] = None neo4j_database: Optional[str] = None - auto_schema_extraction: bool = False pdf_loader: Optional[ComponentType] = None kg_writer: Optional[ComponentType] = None @@ -170,7 +169,7 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]: Get the appropriate schema component based on configuration. Return SchemaFromTextExtractor for automatic extraction or SchemaBuilder for manual schema. 
""" - if self.auto_schema_extraction and not self.has_user_provided_schema(): + if not self.has_user_provided_schema(): return SchemaFromTextExtractor(llm=self.get_default_llm()) return SchemaBuilder() @@ -234,7 +233,7 @@ def _process_schema_with_precedence( return entities, relations, potential_schema def _get_run_params_for_schema(self) -> dict[str, Any]: - if self.auto_schema_extraction and not self.has_user_provided_schema(): + if not self.has_user_provided_schema(): # for automatic extraction, the text parameter is needed (will flow through the pipeline connections) return {} else: @@ -290,7 +289,7 @@ def _get_connections(self) -> list[ConnectionDefinition]: ) # handle automatic schema extraction - if self.auto_schema_extraction and not self.has_user_provided_schema(): + if not self.has_user_provided_schema(): connections.append( ConnectionDefinition( start="pdf_loader", @@ -382,6 +381,6 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: ) run_params["splitter"] = {"text": text} # Add full text to schema component for automatic schema extraction - if self.auto_schema_extraction and not self.has_user_provided_schema(): + if not self.has_user_provided_schema(): run_params["schema"] = {"text": text} return run_params diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index c3159a681..c586a7fad 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -124,9 +124,6 @@ def __init__( perform_entity_resolution=perform_entity_resolution, lexical_graph_config=lexical_graph_config, neo4j_database=neo4j_database, - auto_schema_extraction=not bool( - schema or entities or relations or potential_schema - ), ) except (ValidationError, ValueError) as e: raise PipelineDefinitionError() from e From 518343966cdac7d45bfed0066adbd111844aff6b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 6 May 2025 11:26:24 +0200 Subject: [PATCH 23/36] Fix unit tests --- .../test_simple_kg_builder.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index ef0365849..8aa318cd3 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -28,6 +28,7 @@ SchemaBuilder, SchemaEntity, SchemaRelation, + SchemaFromTextExtractor, ) from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, @@ -116,8 +117,21 @@ def test_simple_kg_pipeline_config_chunk_embedder( assert chunk_embedder._embedder == embedder -def test_simple_kg_pipeline_config_schema() -> None: +@patch( + "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm" +) +def test_simple_kg_pipeline_config_automatic_schema( + mock_llm: Mock, llm: LLMInterface +) -> None: + mock_llm.return_value = llm config = SimpleKGPipelineConfig() + schema = config._get_schema() + assert isinstance(schema, SchemaFromTextExtractor) + assert schema._llm == llm + + +def test_simple_kg_pipeline_config_manual_schema() -> None: + config = SimpleKGPipelineConfig(entities=["Person"]) assert isinstance(config._get_schema(), SchemaBuilder) @@ -205,9 +219,10 @@ def 
test_simple_kg_pipeline_config_connections_from_pdf() -> None: perform_entity_resolution=False, ) connections = config._get_connections() - assert len(connections) == 5 + assert len(connections) == 6 expected_connections = [ ("pdf_loader", "splitter"), + ("pdf_loader", "schema"), ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), @@ -240,9 +255,10 @@ def test_simple_kg_pipeline_config_connections_with_er() -> None: perform_entity_resolution=True, ) connections = config._get_connections() - assert len(connections) == 6 + assert len(connections) == 7 expected_connections = [ ("pdf_loader", "splitter"), + ("pdf_loader", "schema"), ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), @@ -263,7 +279,8 @@ def test_simple_kg_pipeline_config_run_params_from_pdf_file_path() -> None: def test_simple_kg_pipeline_config_run_params_from_text_text() -> None: config = SimpleKGPipelineConfig(from_pdf=False) assert config.get_run_params({"text": "my text"}) == { - "splitter": {"text": "my text"} + "splitter": {"text": "my text"}, + "schema": {"text": "my text"}, } From d6b3491d2259cd2830a0d674cf0f71dc39e0dd77 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 6 May 2025 12:02:17 +0200 Subject: [PATCH 24/36] Handle cases where LLM outputs a valid JSON array --- .../experimental/components/schema.py | 5 ++ .../experimental/components/test_schema.py | 60 +++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 5b789564b..be8b61859 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -376,6 +376,11 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi try: extracted_schema: Dict[str, Any] = json.loads(content) + # handle the case where the LLM outputs a valid JSON array + if isinstance(extracted_schema, list) and len(extracted_schema) > 0: + # if the first item is a dict with expected schema components, use it + if isinstance(extracted_schema[0], dict): + extracted_schema = extracted_schema[0] except json.JSONDecodeError as exc: raise ValueError("LLM response is not valid JSON.") from exc diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 1ac91f154..c3077b50e 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -682,3 +682,63 @@ async def test_schema_config_from_file(schema_config: SchemaConfig) -> None: with pytest.raises(ValueError, match="Unsupported file format"): SchemaConfig.from_file(txt_path) + + +@pytest.fixture +def valid_schema_json_array() -> str: + return """ + [ + { + "entities": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"} + ] + }, + { + "label": "Organization", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relations": [ + { + "label": "WORKS_FOR", + "properties": [ + {"name": "since", "type": "DATE"} + ] + } + ], + "potential_schema": [ + ["Person", "WORKS_FOR", "Organization"] + ] + } + ] + """ + + +@pytest.mark.asyncio +async def test_schema_from_text_run_valid_json_array( + schema_from_text: SchemaFromTextExtractor, + mock_llm: AsyncMock, + valid_schema_json_array: str, +) -> None: + # configure the mock LLM to return a valid JSON array + mock_llm.ainvoke.return_value = 
valid_schema_json_array + + # run the schema extraction + schema_config = await schema_from_text.run(text="Sample text for extraction") + + # verify the schema was correctly extracted from the array + assert len(schema_config.entities) == 2 + assert "Person" in schema_config.entities + assert "Organization" in schema_config.entities + + assert schema_config.relations is not None + assert "WORKS_FOR" in schema_config.relations + + assert schema_config.potential_schema is not None + assert len(schema_config.potential_schema) == 1 + assert schema_config.potential_schema[0] == ("Person", "WORKS_FOR", "Organization") From 3edf0d0618db00b2b513c47f1f641ea91121da2f Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 6 May 2025 14:45:29 +0200 Subject: [PATCH 25/36] Fix e2e tests --- .../experimental/test_simplekgpipeline_e2e.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/tests/e2e/experimental/test_simplekgpipeline_e2e.py b/tests/e2e/experimental/test_simplekgpipeline_e2e.py index d30ec3a66..3bd72dd41 100644 --- a/tests/e2e/experimental/test_simplekgpipeline_e2e.py +++ b/tests/e2e/experimental/test_simplekgpipeline_e2e.py @@ -181,6 +181,8 @@ async def test_pipeline_builder_two_documents( driver=driver, embedder=embedder, from_pdf=False, + # provide minimal schema to bypass automatic schema extraction + entities=["Person"], # in order to have 2 chunks: text_splitter=FixedSizeSplitter(chunk_size=400, chunk_overlap=5), ) @@ -261,6 +263,8 @@ async def test_pipeline_builder_same_document_two_runs( driver=driver, embedder=embedder, from_pdf=False, + # provide minimal schema to bypass automatic schema extraction + entities=["Person"], # in order to have 2 chunks: text_splitter=FixedSizeSplitter(chunk_size=400, chunk_overlap=5), ) @@ -280,3 +284,120 @@ async def test_pipeline_builder_same_document_two_runs( "MATCH (chunk:Chunk)<-[rel:FROM_CHUNK]-(entity:__Entity__) RETURN chunk, rel, entity" ) assert len(records) == 2 # two entities according to mocked LLMResponse + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_neo4j_for_kg_construction") +async def test_pipeline_builder_with_automatic_schema_extraction( + harry_potter_text_part1: str, + llm: MagicMock, + embedder: MagicMock, + driver: neo4j.Driver, +) -> None: + """Test pipeline with automatic schema extraction (no schema provided). + This test verifies that the pipeline correctly handles automatic schema extraction. 
+ """ + driver.execute_query("MATCH (n) DETACH DELETE n") + embedder.embed_query.return_value = [1, 2, 3] + + # set up mock LLM responses for both schema extraction and entity extraction + llm.ainvoke.side_effect = [ + # first call - schema extraction response + LLMResponse( + content="""{ + "entities": [ + { + "label": "Person", + "description": "A character in the story", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "age", "type": "INTEGER"} + ] + }, + { + "label": "Location", + "description": "A place in the story", + "properties": [ + {"name": "name", "type": "STRING"} + ] + } + ], + "relations": [ + { + "label": "LOCATED_AT", + "description": "Indicates where a person is located", + "properties": [] + } + ], + "potential_schema": [ + ["Person", "LOCATED_AT", "Location"] + ] + }""" + ), + # second call - entity extraction for first chunk + LLMResponse( + content="""{ + "nodes": [ + { + "id": "0", + "label": "Person", + "properties": { + "name": "Harry Potter" + } + }, + { + "id": "1", + "label": "Location", + "properties": { + "name": "Hogwarts" + } + } + ], + "relationships": [ + { + "type": "LOCATED_AT", + "start_node_id": "0", + "end_node_id": "1" + } + ] + }""" + ), + # third call - entity extraction for second chunk (if text is split) + LLMResponse(content='{"nodes": [], "relationships": []}'), + ] + + # create an instance of the SimpleKGPipeline with NO schema provided + kg_builder_text = SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + from_pdf=False, + # use smaller chunk size to ensure we have at least 2 chunks + text_splitter=FixedSizeSplitter(chunk_size=400, chunk_overlap=5), + ) + + # run the knowledge graph building process with text input + await kg_builder_text.run_async(text=harry_potter_text_part1) + + # verify LLM was called for schema extraction + assert llm.ainvoke.call_count >= 2 + + # verify entities were created + records, _, _ = driver.execute_query("MATCH (n:Person) RETURN n") + assert len(records) == 1 + + # verify locations were created + records, _, _ = driver.execute_query("MATCH (n:Location) RETURN n") + assert len(records) == 1 + + # verify relationships were created + records, _, _ = driver.execute_query( + "MATCH (p:Person)-[r:LOCATED_AT]->(l:Location) RETURN p, r, l" + ) + assert len(records) == 1 + + # verify chunks and relationships to entities + records, _, _ = driver.execute_query( + "MATCH (c:Chunk)<-[:FROM_CHUNK]-(e) RETURN c, e" + ) + assert len(records) >= 1 From 49c399c9daa658397951a268cb4993adbef0a13c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 6 May 2025 20:52:09 +0200 Subject: [PATCH 26/36] Address PR comments --- docs/source/user_guide_kg_builder.rst | 16 ++++------------ .../schema_from_text.py | 7 ------- .../experimental/components/schema.py | 13 +++++++++---- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index c263d5083..30d478667 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -128,15 +128,6 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated }, # ... ) - - # Using individual schema parameters (deprecated approach) - kg_builder = SimpleKGPipeline( - # ... - entities=ENTITIES, - relations=RELATIONS, - potential_schema=POTENTIAL_SCHEMA, - # ... - ) .. 
note:: By default, if no schema is provided to the SimpleKGPipeline, automatic schema extraction will be performed using the LLM (See the :ref:`Automatic Schema Extraction with SchemaFromTextExtractor`). @@ -481,8 +472,6 @@ or in YAML: neo4j_database: myDb on_error: IGNORE prompt_template: ... - - # Using the schema parameter (recommended approach) schema: entities: - Person @@ -823,7 +812,7 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto from neo4j_graphrag.experimental.components.schema import SchemaFromTextExtractor from neo4j_graphrag.llm import OpenAILLM - # Create the automatic schema extractor + # Instantiate the automatic schema extractor component schema_extractor = SchemaFromTextExtractor( llm=OpenAILLM( model_name="gpt-4o", @@ -834,6 +823,9 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto ) ) + # Extract the schema from the text + extracted_schema = await schema_extractor.run(text="Some text") + The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `SchemaConfig` object that can be used in the same way as a manually defined schema. You can also save and reload the extracted schema: diff --git a/examples/automatic_schema_extraction/schema_from_text.py b/examples/automatic_schema_extraction/schema_from_text.py index b20431fd2..e27a52224 100644 --- a/examples/automatic_schema_extraction/schema_from_text.py +++ b/examples/automatic_schema_extraction/schema_from_text.py @@ -25,13 +25,6 @@ logging.basicConfig() logging.getLogger("neo4j_graphrag").setLevel(logging.INFO) -# Verify OpenAI API key is available -if not os.getenv("OPENAI_API_KEY"): - raise ValueError( - "OPENAI_API_KEY environment variable not found. " - "Please set it in the .env file in the root directory." - ) - # Sample text to extract schema from - it's about a company and its employees TEXT = """ Acme Corporation was founded in 1985 by John Smith in New York City. 
diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index be8b61859..b2ce0c856 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -376,11 +376,16 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi try: extracted_schema: Dict[str, Any] = json.loads(content) - # handle the case where the LLM outputs a valid JSON array - if isinstance(extracted_schema, list) and len(extracted_schema) > 0: - # if the first item is a dict with expected schema components, use it - if isinstance(extracted_schema[0], dict): + if not isinstance(extracted_schema, dict): + if ( + isinstance(extracted_schema, list) + and len(extracted_schema) > 0 + and isinstance(extracted_schema[0], dict) + ): extracted_schema = extracted_schema[0] + else: + # fallback to empty dict for any other case (e.g., empty list) + extracted_schema = {} except json.JSONDecodeError as exc: raise ValueError("LLM response is not valid JSON.") from exc From bf2fb9639492d34e0e143d957e888ead9ed3f333 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 7 May 2025 10:52:46 +0200 Subject: [PATCH 27/36] Add examples running SimpleKGPipeline --- .../simple_kg_pipeline_schema_from_pdf.py | 94 ++++++++++++++++ .../simple_kg_pipeline_schema_from_text.py | 102 ++++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py create mode 100644 examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_text.py diff --git a/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py b/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py new file mode 100644 index 000000000..448808e8e --- /dev/null +++ b/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py @@ -0,0 +1,94 @@ +"""This example demonstrates how to use SimpleKGPipeline with automatic schema extraction +from a PDF file. When no schema is provided to SimpleKGPipeline, automatic schema extraction +is performed using the LLM. + +Note: This example requires an OpenAI API key to be set in the .env file. 
+""" + +import asyncio +import logging +import os +from dotenv import load_dotenv +import neo4j + +from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.embeddings import OpenAIEmbeddings + +# Load environment variables from .env file +load_dotenv() + +# Configure logging +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.INFO) + +# PDF file path - replace with your own PDF file +DATA_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data" +) +PDF_FILE = os.path.join(DATA_DIR, "Harry Potter and the Death Hallows Summary.pdf") + + +async def run_kg_pipeline_with_auto_schema() -> None: + """Run the SimpleKGPipeline with automatic schema extraction from a PDF file.""" + + # Define Neo4j connection + uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687") + user = os.getenv("NEO4J_USER", "neo4j") + password = os.getenv("NEO4J_PASSWORD", "password") + + # Define LLM parameters + llm_model_params = { + "max_tokens": 2000, + "response_format": {"type": "json_object"}, + "temperature": 0, # Lower temperature for more consistent output + } + + # Initialize the Neo4j driver + driver = neo4j.GraphDatabase.driver(uri, auth=(user, password)) + + # Create the LLM instance + llm = OpenAILLM( + model_name="gpt-4o", + model_params=llm_model_params, + ) + + # Create the embedder instance + embedder = OpenAIEmbeddings() + + try: + # Create a SimpleKGPipeline instance without providing a schema + # This will trigger automatic schema extraction + kg_builder = SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + from_pdf=True, + ) + + print(f"Processing PDF file: {PDF_FILE}") + # Run the pipeline on the PDF file + await kg_builder.run_async(file_path=PDF_FILE) + + finally: + # Close connections + await llm.async_client.close() + driver.close() + + +async def main() -> None: + """Run the example.""" + os.makedirs(DATA_DIR, exist_ok=True) + + # Check if the PDF file exists + if not os.path.exists(PDF_FILE): + print(f"Warning: PDF file not found at {PDF_FILE}") + print("Please replace with a valid PDF file path.") + return + + # Run the pipeline + await run_kg_pipeline_with_auto_schema() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_text.py b/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_text.py new file mode 100644 index 000000000..75b076306 --- /dev/null +++ b/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_text.py @@ -0,0 +1,102 @@ +"""This example demonstrates how to use SimpleKGPipeline with automatic schema extraction +from a text input. When no schema is provided to SimpleKGPipeline, automatic schema extraction +is performed using the LLM. + +Note: This example requires an OpenAI API key to be set in the .env file. +""" + +import asyncio +import logging +import os +from dotenv import load_dotenv +import neo4j + +from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.embeddings import OpenAIEmbeddings + +# Load environment variables from .env file +load_dotenv() + +# Configure logging +logging.basicConfig() +logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG) + +# Sample text to build a knowledge graph from +TEXT = """ +Acme Corporation was founded in 1985 by John Smith in New York City. 
+The company specializes in manufacturing high-quality widgets and gadgets +for the consumer electronics industry. + +Sarah Johnson joined Acme in 2010 as a Senior Engineer and was promoted to +Engineering Director in 2015. She oversees a team of 12 engineers working on +next-generation products. Sarah holds a PhD in Electrical Engineering from MIT +and has filed 5 patents during her time at Acme. + +The company expanded to international markets in 2012, opening offices in London, +Tokyo, and Berlin. Each office is managed by a regional director who reports +directly to the CEO, Michael Brown, who took over leadership in 2008. + +Acme's most successful product, the SuperWidget X1, was launched in 2018 and +has sold over 2 million units worldwide. The product was developed by a team led +by Robert Chen, who joined the company in 2016 after working at TechGiant for 8 years. + +The company currently employs 250 people across its 4 locations and had a revenue +of $75 million in the last fiscal year. Acme is planning to go public in 2024 +with an estimated valuation of $500 million. +""" + + +async def run_kg_pipeline_with_auto_schema() -> None: + """Run the SimpleKGPipeline with automatic schema extraction from text input.""" + + # Define Neo4j connection + uri = os.getenv("NEO4J_URI", "neo4j://localhost:7687") + user = os.getenv("NEO4J_USER", "neo4j") + password = os.getenv("NEO4J_PASSWORD", "password") + + # Define LLM parameters + llm_model_params = { + "max_tokens": 2000, + "response_format": {"type": "json_object"}, + "temperature": 0, # Lower temperature for more consistent output + } + + # Initialize the Neo4j driver + driver = neo4j.GraphDatabase.driver(uri, auth=(user, password)) + + # Create the LLM instance + llm = OpenAILLM( + model_name="gpt-4o", + model_params=llm_model_params, + ) + + # Create the embedder instance + embedder = OpenAIEmbeddings() + + try: + # Create a SimpleKGPipeline instance without providing a schema + # This will trigger automatic schema extraction + kg_builder = SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + from_pdf=False, # Using raw text input, not PDF + ) + + # Run the pipeline on the text + await kg_builder.run_async(text=TEXT) + + finally: + # Close connections + await llm.async_client.close() + driver.close() + + +async def main() -> None: + """Run the example.""" + await run_kg_pipeline_with_auto_schema() + + +if __name__ == "__main__": + asyncio.run(main()) From ffea761ef57bdcc4bdc3f47184e361765f351273 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 12 May 2025 10:59:04 +0200 Subject: [PATCH 28/36] Add inferred schema json and yaml files example --- examples/data/extracted_schema.json | 137 ++++++++++++++++++++++++++++ examples/data/extracted_schema.yaml | 87 ++++++++++++++++++ 2 files changed, 224 insertions(+) create mode 100644 examples/data/extracted_schema.json create mode 100644 examples/data/extracted_schema.yaml diff --git a/examples/data/extracted_schema.json b/examples/data/extracted_schema.json new file mode 100644 index 000000000..0ec66639b --- /dev/null +++ b/examples/data/extracted_schema.json @@ -0,0 +1,137 @@ +{ + "entities": { + "Company": { + "label": "Company", + "description": "", + "properties": [ + { + "name": "name", + "type": "STRING", + "description": "" + }, + { + "name": "foundedYear", + "type": "INTEGER", + "description": "" + }, + { + "name": "revenue", + "type": "FLOAT", + "description": "" + }, + { + "name": "valuation", + "type": "FLOAT", + "description": "" + } + ] + }, + "Person": { 
+ "label": "Person", + "description": "", + "properties": [ + { + "name": "name", + "type": "STRING", + "description": "" + }, + { + "name": "position", + "type": "STRING", + "description": "" + }, + { + "name": "yearJoined", + "type": "INTEGER", + "description": "" + } + ] + }, + "Product": { + "label": "Product", + "description": "", + "properties": [ + { + "name": "name", + "type": "STRING", + "description": "" + }, + { + "name": "launchYear", + "type": "INTEGER", + "description": "" + }, + { + "name": "unitsSold", + "type": "INTEGER", + "description": "" + } + ] + }, + "Office": { + "label": "Office", + "description": "", + "properties": [ + { + "name": "location", + "type": "STRING", + "description": "" + } + ] + } + }, + "relations": { + "FOUNDED_BY": { + "label": "FOUNDED_BY", + "description": "", + "properties": [] + }, + "WORKS_FOR": { + "label": "WORKS_FOR", + "description": "", + "properties": [] + }, + "MANAGES": { + "label": "MANAGES", + "description": "", + "properties": [] + }, + "DEVELOPED_BY": { + "label": "DEVELOPED_BY", + "description": "", + "properties": [] + }, + "LOCATED_IN": { + "label": "LOCATED_IN", + "description": "", + "properties": [] + } + }, + "potential_schema": [ + [ + "Company", + "FOUNDED_BY", + "Person" + ], + [ + "Person", + "WORKS_FOR", + "Company" + ], + [ + "Person", + "MANAGES", + "Office" + ], + [ + "Product", + "DEVELOPED_BY", + "Person" + ], + [ + "Company", + "LOCATED_IN", + "Office" + ] + ] +} \ No newline at end of file diff --git a/examples/data/extracted_schema.yaml b/examples/data/extracted_schema.yaml new file mode 100644 index 000000000..f2500799f --- /dev/null +++ b/examples/data/extracted_schema.yaml @@ -0,0 +1,87 @@ +entities: + Company: + label: Company + description: '' + properties: + - name: name + type: STRING + description: '' + - name: foundedYear + type: INTEGER + description: '' + - name: revenue + type: FLOAT + description: '' + - name: valuation + type: FLOAT + description: '' + Person: + label: Person + description: '' + properties: + - name: name + type: STRING + description: '' + - name: position + type: STRING + description: '' + - name: yearJoined + type: INTEGER + description: '' + Product: + label: Product + description: '' + properties: + - name: name + type: STRING + description: '' + - name: launchYear + type: INTEGER + description: '' + - name: unitsSold + type: INTEGER + description: '' + Office: + label: Office + description: '' + properties: + - name: location + type: STRING + description: '' +relations: + FOUNDED_BY: + label: FOUNDED_BY + description: '' + properties: [] + WORKS_FOR: + label: WORKS_FOR + description: '' + properties: [] + MANAGES: + label: MANAGES + description: '' + properties: [] + DEVELOPED_BY: + label: DEVELOPED_BY + description: '' + properties: [] + LOCATED_IN: + label: LOCATED_IN + description: '' + properties: [] +potential_schema: +- - Company + - FOUNDED_BY + - Person +- - Person + - WORKS_FOR + - Company +- - Person + - MANAGES + - Office +- - Product + - DEVELOPED_BY + - Person +- - Company + - LOCATED_IN + - Office From 2ce0ff93bc5cae70326b64f4cde571e3d0e3aaa1 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 12 May 2025 11:54:10 +0200 Subject: [PATCH 29/36] Improve handling LLM response --- .../experimental/components/schema.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index b2ce0c856..8dfad5a2c 100644 --- 
a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -22,7 +22,7 @@ from pydantic import BaseModel, ValidationError, model_validator, validate_call from typing_extensions import Self -from neo4j_graphrag.exceptions import SchemaValidationError +from neo4j_graphrag.exceptions import SchemaValidationError, LLMGenerationError from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.schema import ( EntityInputType, @@ -367,12 +367,12 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi """ prompt: str = self._prompt_template.format(text=text, examples=examples) - response = await self._llm.ainvoke(prompt, **self._llm_params) - content: str = ( - response - if isinstance(response, str) - else getattr(response, "content", str(response)) - ) + try: + response = await self._llm.ainvoke(prompt, **self._llm_params) + content: str = response.content + except LLMGenerationError as e: + # Re-raise the LLMGenerationError + raise LLMGenerationError("Failed to generate schema from text") from e try: extracted_schema: Dict[str, Any] = json.loads(content) From f69eaceb565e682e2d1037a9493254def6fbc9ba Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 12 May 2025 12:13:33 +0200 Subject: [PATCH 30/36] Improve handling errors for extracted schema --- .../experimental/components/schema.py | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 8dfad5a2c..b3bf2c424 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -16,6 +16,7 @@ import json import yaml +import logging from typing import Any, Dict, List, Literal, Optional, Tuple, Union from pathlib import Path @@ -376,16 +377,28 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi try: extracted_schema: Dict[str, Any] = json.loads(content) - if not isinstance(extracted_schema, dict): - if ( - isinstance(extracted_schema, list) - and len(extracted_schema) > 0 - and isinstance(extracted_schema[0], dict) - ): + + # handle dictionary + if isinstance(extracted_schema, dict): + pass # Keep as is + # handle list + elif isinstance(extracted_schema, list): + if len(extracted_schema) > 0 and isinstance(extracted_schema[0], dict): extracted_schema = extracted_schema[0] - else: - # fallback to empty dict for any other case (e.g., empty list) + elif len(extracted_schema) == 0: + logging.warning( + "LLM returned an empty list for schema. Falling back to empty schema." + ) extracted_schema = {} + else: + raise ValueError( + f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}" + ) + # any other types + else: + raise ValueError( + f"Unexpected schema format returned from LLM: {type(extracted_schema)}. Expected a dictionary or list of dictionaries." 
+ ) except json.JSONDecodeError as exc: raise ValueError("LLM response is not valid JSON.") from exc From 89b3d1b521be4bcf27f2834f00520859241143de Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 12 May 2025 12:23:40 +0200 Subject: [PATCH 31/36] Replace warning logs with real deprecation warnings --- .../pipeline/config/template_pipeline/simple_kg_builder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index a5006d89d..8b6e752e6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -14,6 +14,7 @@ # limitations under the License. from typing import Any, ClassVar, Literal, Optional, Sequence, Union, List, Tuple import logging +import warnings from pydantic import ConfigDict, Field, model_validator from typing_extensions import Self @@ -113,18 +114,20 @@ def _process_schema_parameters(self) -> None: ) if has_individual_schema_components and self.schema_ is not None: - logger.warning( + warnings.warn( "Both 'schema' and individual schema components (entities, relations, potential_schema) " "were provided. The 'schema' parameter takes precedence. In the future, individual " "components will be removed. Please use only the 'schema' parameter.", + DeprecationWarning, stacklevel=2, ) elif has_individual_schema_components: - logger.warning( + warnings.warn( "The 'entities', 'relations', and 'potential_schema' parameters are deprecated " "and will be removed in a future version. " "Please use the 'schema' parameter instead.", + DeprecationWarning, stacklevel=2, ) From 83d90fbb7478fd637ad87138398b285e82aa49cc Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 12 May 2025 12:34:26 +0200 Subject: [PATCH 32/36] Fix schema unit tests --- tests/unit/experimental/components/test_schema.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index c3077b50e..f091e4fb0 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -33,6 +33,7 @@ import yaml from neo4j_graphrag.generation import PromptTemplate +from neo4j_graphrag.llm.types import LLMResponse @pytest.fixture @@ -509,7 +510,7 @@ async def test_schema_from_text_run_valid_response( valid_schema_json: str, ) -> None: # configure the mock LLM to return a valid schema JSON - mock_llm.ainvoke.return_value = valid_schema_json + mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json) # run the schema extraction schema_config = await schema_from_text.run(text="Sample text for extraction") @@ -540,7 +541,7 @@ async def test_schema_from_text_run_invalid_json( invalid_schema_json: str, ) -> None: # configure the mock LLM to return invalid JSON - mock_llm.ainvoke.return_value = invalid_schema_json + mock_llm.ainvoke.return_value = LLMResponse(content=invalid_schema_json) # verify that running with invalid JSON raises a ValueError with pytest.raises(ValueError) as exc_info: @@ -563,7 +564,7 @@ async def test_schema_from_text_custom_template( ) # configure mock LLM to return valid JSON and capture the prompt that was sent to it - mock_llm.ainvoke.return_value = valid_schema_json + 
mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json) # run the schema extraction await schema_from_text.run(text="Sample text") @@ -584,7 +585,7 @@ async def test_schema_from_text_llm_params( schema_from_text = SchemaFromTextExtractor(llm=mock_llm, llm_params=llm_params) # configure the mock LLM to return a valid schema JSON - mock_llm.ainvoke.return_value = valid_schema_json + mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json) # run the schema extraction await schema_from_text.run(text="Sample text") @@ -726,7 +727,7 @@ async def test_schema_from_text_run_valid_json_array( valid_schema_json_array: str, ) -> None: # configure the mock LLM to return a valid JSON array - mock_llm.ainvoke.return_value = valid_schema_json_array + mock_llm.ainvoke.return_value = LLMResponse(content=valid_schema_json_array) # run the schema extraction schema_config = await schema_from_text.run(text="Sample text for extraction") From 29aec54b460e5fb9b1471dd25c2285506d80d751 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 13 May 2025 10:01:45 +0200 Subject: [PATCH 33/36] Ensure proper handling of schema when provided as dict --- .../experimental/components/schema.py | 28 +++++++++++++++++ .../template_pipeline/simple_kg_builder.py | 31 ++++++++++++++++--- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index b3bf2c424..e2115f563 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -424,3 +424,31 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi relations=relations, potential_schema=potential_schema, ) + + +def normalize_schema_dict(schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Normalize a user-provided schema dictionary to the canonical format expected by the pipeline. + + - Converts 'entities' and 'relations' from lists (of strings, dicts, or model objects) to dicts keyed by label. + - Ensures required keys ('entities', 'relations', 'potential_schema') are present. + - Does not mutate the input; returns a new dict. + + Args: + schema_dict (dict): The user-provided schema dictionary, possibly with lists or missing keys. + + Returns: + dict: A normalized schema dictionary with the correct structure for pipeline and Pydantic validation. + """ + norm_schema_dict = dict(schema_dict) + for key, cls in [("entities", SchemaEntity), ("relations", SchemaRelation)]: + if key in norm_schema_dict and isinstance(norm_schema_dict[key], list): + norm_schema_dict[key] = { + cls.from_text_or_dict(e).label: cls.from_text_or_dict(e).model_dump() # type: ignore[attr-defined] + for e in norm_schema_dict[key] + } + if "relations" not in norm_schema_dict: + norm_schema_dict["relations"] = {} + if "potential_schema" not in norm_schema_dict: + norm_schema_dict["potential_schema"] = None + return norm_schema_dict diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 8b6e752e6..0df0f61e6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -12,7 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ClassVar, Literal, Optional, Sequence, Union, List, Tuple +from typing import ( + Any, + ClassVar, + Literal, + Optional, + Sequence, + Union, + List, + Tuple, + Dict, + cast, +) import logging import warnings @@ -37,6 +48,7 @@ SchemaEntity, SchemaRelation, SchemaFromTextExtractor, + normalize_schema_dict, ) from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( @@ -97,6 +109,14 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): model_config = ConfigDict(arbitrary_types_allowed=True) + @model_validator(mode="before") + def normalize_schema_field(cls, data: Dict[str, Any]) -> Dict[str, Any]: + # Normalize the 'schema' field if it is a dict + schema = data.get("schema") + if isinstance(schema, dict): + data["schema"] = normalize_schema_dict(schema) + return data + @model_validator(mode="after") def handle_schema_precedence(self) -> Self: """Handle schema precedence and warnings""" @@ -209,14 +229,17 @@ def _process_schema_with_precedence( potential_schema = self.schema_.potential_schema else: - # extract from dictionary entities = [ SchemaEntity.from_text_or_dict(e) - for e in self.schema_.get("entities", []) + for e in cast( + Dict[str, Any], self.schema_.get("entities", {}) + ).values() ] relations = [ SchemaRelation.from_text_or_dict(r) - for r in self.schema_.get("relations", []) + for r in cast( + Dict[str, Any], self.schema_.get("relations", {}) + ).values() ] potential_schema = self.schema_.get("potential_schema") else: From 4e6d53aa4170519f54998edd105a13c08f19889c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 13 May 2025 12:42:16 +0200 Subject: [PATCH 34/36] Move example files to the right directories --- .../simple_kg_builder_schema_from_pdf.py} | 15 +++++++++------ .../simple_kg_builder_schema_from_text.py} | 0 .../schema_builders}/schema_from_text.py | 13 ++++++------- 3 files changed, 15 insertions(+), 13 deletions(-) rename examples/{automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py => build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_pdf.py} (88%) rename examples/{automatic_schema_extraction/simple_kg_pipeline_schema_from_text.py => build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_text.py} (100%) rename examples/{automatic_schema_extraction => customize/build_graph/components/schema_builders}/schema_from_text.py (93%) diff --git a/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py b/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_pdf.py similarity index 88% rename from examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py rename to examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_pdf.py index 448808e8e..639f0b93d 100644 --- a/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_pdf.py +++ b/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_pdf.py @@ -8,6 +8,7 @@ import asyncio import logging import os +from pathlib import Path from dotenv import load_dotenv import neo4j @@ -22,11 +23,11 @@ logging.basicConfig() logging.getLogger("neo4j_graphrag").setLevel(logging.INFO) -# PDF file path - replace with your own PDF file -DATA_DIR = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data" +# 
PDF file path +root_dir = Path(__file__).parents[2] +PDF_FILE = str( + root_dir / "data" / "Harry Potter and the Chamber of Secrets Summary.pdf" ) -PDF_FILE = os.path.join(DATA_DIR, "Harry Potter and the Death Hallows Summary.pdf") async def run_kg_pipeline_with_auto_schema() -> None: @@ -78,10 +79,12 @@ async def run_kg_pipeline_with_auto_schema() -> None: async def main() -> None: """Run the example.""" - os.makedirs(DATA_DIR, exist_ok=True) + # Create data directory if it doesn't exist + data_dir = root_dir / "data" + data_dir.mkdir(exist_ok=True) # Check if the PDF file exists - if not os.path.exists(PDF_FILE): + if not Path(PDF_FILE).exists(): print(f"Warning: PDF file not found at {PDF_FILE}") print("Please replace with a valid PDF file path.") return diff --git a/examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_text.py b/examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_text.py similarity index 100% rename from examples/automatic_schema_extraction/simple_kg_pipeline_schema_from_text.py rename to examples/build_graph/automatic_schema_extraction/simple_kg_builder_schema_from_text.py diff --git a/examples/automatic_schema_extraction/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py similarity index 93% rename from examples/automatic_schema_extraction/schema_from_text.py rename to examples/customize/build_graph/components/schema_builders/schema_from_text.py index e27a52224..4396f3fb6 100644 --- a/examples/automatic_schema_extraction/schema_from_text.py +++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py @@ -9,7 +9,7 @@ import asyncio import logging -import os +from pathlib import Path from dotenv import load_dotenv from neo4j_graphrag.experimental.components.schema import ( @@ -50,11 +50,10 @@ """ # Define the file paths for saving the schema -OUTPUT_DIR = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data" -) -JSON_FILE_PATH = os.path.join(OUTPUT_DIR, "extracted_schema.json") -YAML_FILE_PATH = os.path.join(OUTPUT_DIR, "extracted_schema.yaml") +root_dir = Path(__file__).parents[4] +OUTPUT_DIR = str(root_dir / "data") +JSON_FILE_PATH = str(root_dir / "data" / "extracted_schema.json") +YAML_FILE_PATH = str(root_dir / "data" / "extracted_schema.yaml") async def extract_and_save_schema() -> None: @@ -82,7 +81,7 @@ async def extract_and_save_schema() -> None: inferred_schema = await schema_extractor.run(text=TEXT) # Ensure the output directory exists - os.makedirs(OUTPUT_DIR, exist_ok=True) + Path(OUTPUT_DIR).mkdir(exist_ok=True) print(f"Saving schema to JSON file: {JSON_FILE_PATH}") # Save the schema to JSON file From 48ec9b71c8b04a5db10278b55dfbc151907b092b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 13 May 2025 12:59:28 +0200 Subject: [PATCH 35/36] Add custom schema extraction error --- src/neo4j_graphrag/exceptions.py | 6 ++++++ src/neo4j_graphrag/experimental/components/schema.py | 12 ++++++++---- tests/unit/experimental/components/test_schema.py | 4 ++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/neo4j_graphrag/exceptions.py b/src/neo4j_graphrag/exceptions.py index 3c0fdc0b3..681b20eec 100644 --- a/src/neo4j_graphrag/exceptions.py +++ b/src/neo4j_graphrag/exceptions.py @@ -116,6 +116,12 @@ class SchemaValidationError(Neo4jGraphRagError): pass +class SchemaExtractionError(Neo4jGraphRagError): + """Exception raised for errors in automatic schema extraction.""" + + pass + + class 
PdfLoaderError(Neo4jGraphRagError): """Custom exception for errors in PDF loader.""" diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index e2115f563..d9705b921 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -23,7 +23,11 @@ from pydantic import BaseModel, ValidationError, model_validator, validate_call from typing_extensions import Self -from neo4j_graphrag.exceptions import SchemaValidationError, LLMGenerationError +from neo4j_graphrag.exceptions import ( + SchemaValidationError, + LLMGenerationError, + SchemaExtractionError, +) from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.schema import ( EntityInputType, @@ -391,16 +395,16 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi ) extracted_schema = {} else: - raise ValueError( + raise SchemaExtractionError( f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}" ) # any other types else: - raise ValueError( + raise SchemaExtractionError( f"Unexpected schema format returned from LLM: {type(extracted_schema)}. Expected a dictionary or list of dictionaries." ) except json.JSONDecodeError as exc: - raise ValueError("LLM response is not valid JSON.") from exc + raise SchemaExtractionError("LLM response is not valid JSON.") from exc extracted_entities: List[Dict[str, Any]] = ( extracted_schema.get("entities") or [] diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index f091e4fb0..be7bbd958 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -18,7 +18,7 @@ from unittest.mock import AsyncMock import pytest -from neo4j_graphrag.exceptions import SchemaValidationError +from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, SchemaEntity, @@ -544,7 +544,7 @@ async def test_schema_from_text_run_invalid_json( mock_llm.ainvoke.return_value = LLMResponse(content=invalid_schema_json) # verify that running with invalid JSON raises a ValueError - with pytest.raises(ValueError) as exc_info: + with pytest.raises(SchemaExtractionError) as exc_info: await schema_from_text.run(text="Sample text for extraction") assert "not valid JSON" in str(exc_info.value) From 44e76def7a71bbe4be07fbb3d21ab30b7f5dae09 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 13 May 2025 16:15:08 +0200 Subject: [PATCH 36/36] Handle invalid format for extracted schema --- .../experimental/components/schema.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index d9705b921..a58b0b105 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -416,12 +416,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> SchemaConfi "potential_schema" ) - entities: List[SchemaEntity] = [SchemaEntity(**e) for e in extracted_entities] - relations: Optional[List[SchemaRelation]] = ( - [SchemaRelation(**r) for r in extracted_relations] - if extracted_relations - else None - ) + try: + entities: List[SchemaEntity] = [ + 
SchemaEntity(**e) for e in extracted_entities
+            ]
+            relations: Optional[List[SchemaRelation]] = (
+                [SchemaRelation(**r) for r in extracted_relations]
+                if extracted_relations
+                else None
+            )
+        except ValidationError as exc:
+            raise SchemaValidationError(
+                f"Invalid schema format returned from LLM: {exc}"
+            ) from exc
 
         return SchemaBuilder.create_schema_model(
             entities=entities,
             relations=relations,
             potential_schema=potential_schema,
         )
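
Taken together, the error-handling patches above (29, 30, 35 and 36) give
`SchemaFromTextExtractor.run` a well-defined failure surface:
`LLMGenerationError` when the LLM call itself fails, `SchemaExtractionError`
when the response cannot be parsed into a schema dictionary, and
`SchemaValidationError` when the parsed JSON does not validate as
entities/relations. The following is a minimal sketch of how a caller might
distinguish them; it assumes `OpenAILLM` as the `LLMInterface` implementation
and reuses the model parameters from the bundled examples, with an
illustrative one-line input text:

import asyncio

from neo4j_graphrag.exceptions import (
    LLMGenerationError,
    SchemaExtractionError,
    SchemaValidationError,
)
from neo4j_graphrag.experimental.components.schema import SchemaFromTextExtractor
from neo4j_graphrag.llm import OpenAILLM


async def main() -> None:
    llm = OpenAILLM(
        model_name="gpt-4o",
        model_params={"response_format": {"type": "json_object"}, "temperature": 0},
    )
    extractor = SchemaFromTextExtractor(llm=llm)
    try:
        schema = await extractor.run(text="Alice lives in London.")
        print(f"Extracted entity types: {list(schema.entities.keys())}")
    except LLMGenerationError as exc:
        # the LLM invocation itself failed (wrapped and re-raised, PATCH 29)
        print(f"LLM invocation failed: {exc}")
    except SchemaExtractionError as exc:
        # response was not valid JSON, or not a dict / list of dicts
        # (PATCH 30, typed as SchemaExtractionError in PATCH 35)
        print(f"Could not extract a schema from the response: {exc}")
    except SchemaValidationError as exc:
        # JSON parsed, but entities/relations failed validation (PATCH 36)
        print(f"Extracted schema is malformed: {exc}")
    finally:
        await llm.async_client.close()


if __name__ == "__main__":
    asyncio.run(main())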
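
The normalization added in PATCH 33 is what lets the `schema` dict accept the
same shorthand as the deprecated entities/relations arguments. Below is a
small sketch of the behavior described in its docstring; the schema values
are made up for illustration:

from neo4j_graphrag.experimental.components.schema import normalize_schema_dict

# shorthand schema: entities and relations given as lists of strings or dicts
user_schema = {
    "entities": [
        "Person",
        {"label": "City", "properties": [{"name": "name", "type": "STRING"}]},
    ],
    "relations": ["LIVES_IN"],
}

normalized = normalize_schema_dict(user_schema)

# lists are converted to dicts keyed by label, and the missing
# potential_schema key is filled with None
assert set(normalized["entities"]) == {"Person", "City"}
assert set(normalized["relations"]) == {"LIVES_IN"}
assert normalized["potential_schema"] is None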
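
And since PATCH 31 replaced the warning logs with real `DeprecationWarning`
instances, the deprecation of the individual schema arguments is now
observable programmatically. A sketch, reusing the same minimal constructor
call that appears in the updated unit tests:

import warnings

from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import (
    SimpleKGPipelineConfig,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # legacy individual arguments now emit a DeprecationWarning
    SimpleKGPipelineConfig(entities=["Person"])

assert any(issubclass(w.category, DeprecationWarning) for w in caught)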