From 4a3d8aa62e5e25f8f43d26a7c463c03290c66d46 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 12:52:15 +0200 Subject: [PATCH 1/7] Fix schema validation --- .../components/entity_relation_extractor.py | 6 ++- .../experimental/components/schema.py | 22 +++++----- .../template_pipeline/simple_kg_builder.py | 41 ++++++------------- src/neo4j_graphrag/generation/prompts.py | 2 +- .../test_simple_kg_builder.py | 6 +-- 5 files changed, 31 insertions(+), 46 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 1d29fefe2..32d65276e 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -213,7 +213,9 @@ async def extract_for_chunk( ) -> Neo4jGraph: """Run entity extraction for a given text chunk.""" prompt = self.prompt_template.format( - text=chunk.text, schema=schema.model_dump(), examples=examples + text=chunk.text, + schema=schema.model_dump(exclude_none=True), + examples=examples, ) llm_result = await self.llm.ainvoke(prompt) try: @@ -326,7 +328,7 @@ async def run( elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) schema = schema or GraphSchema( - node_types=(), relationship_types=(), patterns=() + node_types=(), relationship_types=None, patterns=None ) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 1136c6db3..374e6785f 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -81,13 +81,12 @@ class NodeType(BaseModel): description: str = "" properties: list[PropertyType] = [] + @model_validator(mode="before") @classmethod - def from_text_or_dict(cls, input: EntityInputType) -> Self: - if isinstance(input, NodeType): - return input - if isinstance(input, str): - return cls(label=input) - return cls.model_validate(input) + def validate(cls, data: EntityInputType) -> Self: + if isinstance(data, str): + return {"label": data} + return data class RelationshipType(BaseModel): @@ -99,13 +98,12 @@ class RelationshipType(BaseModel): description: str = "" properties: list[PropertyType] = [] + @model_validator(mode="before") @classmethod - def from_text_or_dict(cls, input: RelationInputType) -> Self: - if isinstance(input, RelationshipType): - return input - if isinstance(input, str): - return cls(label=input) - return cls.model_validate(input) + def validate(cls, data: EntityInputType) -> Self: + if isinstance(data, str): + return {"label": data} + return data class GraphSchema(DataModel): diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index d8ea45fe4..933231e7d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -89,9 +89,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None - schema_: Optional[Union[GraphSchema, dict[str, list[Any]]]] = Field( - default=None, alias="schema" - ) + schema_: Optional[GraphSchema] = Field(default=None, alias="schema") enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() @@ -202,39 +200,26 @@ def _process_schema_with_precedence( """ if self.schema_ is not None: # schema takes precedence over individual components - if isinstance(self.schema_, GraphSchema): - # extract components from GraphSchema - node_types = self.schema_.node_types - - # handle case where relations could be None - if self.schema_.relationship_types is not None: - relationship_types = self.schema_.relationship_types - else: - relationship_types = () + node_types = self.schema_.node_types - patterns = self.schema_.patterns + # handle case where relations could be None + if self.schema_.relationship_types is not None: + relationship_types = self.schema_.relationship_types else: - node_types = tuple( - NodeType.from_text_or_dict(e) - for e in self.schema_.get("node_types", ()) - ) - relationship_types = tuple( - RelationshipType.from_text_or_dict(r) - for r in self.schema_.get("relationship_types", ()) - ) - ps = self.schema_.get("patterns") - patterns = tuple(ps) if ps else None + relationship_types = None + + patterns = self.schema_.patterns else: # use individual components node_types = tuple( - [NodeType.from_text_or_dict(e) for e in self.entities] + [NodeType.model_validate(e) for e in self.entities] if self.entities else [] ) - relationship_types = tuple( - [RelationshipType.from_text_or_dict(r) for r in self.relations] - if self.relations - else [] + relationship_types = ( + tuple([RelationshipType.model_validate(r) for r in self.relations]) + if self.relations is not None + else None ) patterns = tuple(self.potential_schema) if self.potential_schema else None diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 24de870fb..8c2bc470e 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -171,7 +171,7 @@ class ERExtractionTemplate(PromptTemplate): {{"nodes": [ {{"id": "0", "label": "Person", "properties": {{"name": "John"}} }}], "relationships": [{{"type": "KNOWS", "start_node_id": "0", "end_node_id": "1", "properties": {{"since": "2024-08-01"}} }}] }} -Use only the following nodes and relationships (if provided): +Use only the following node and relationship types (if provided): {schema} Assign a unique ID (string) to each node, and reuse it to define relationships. diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 8bcdd1f37..34157f494 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -408,7 +408,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_dict() ("Person", "CREATED", "Organization"), ] config = SimpleKGPipelineConfig( - schema={ + schema={ # type: ignore "node_types": entities, "relationship_types": relations, "patterns": potential_schema, @@ -433,7 +433,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_object( None ): entities = [ - {"label": "Person"}, + "Person", { "label": "Organization", "description": "A group of persons", @@ -446,7 +446,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_object( }, ] relations = [ - {"label": "WORKS_FOR"}, + "WORKS_FOR", { "label": "CREATED", "description": "A person created an organization", From 8d98a566937d7b1e0bc15935b73ca2713a1fe4ac Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 12:56:40 +0200 Subject: [PATCH 2/7] Ruff --- src/neo4j_graphrag/experimental/components/schema.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 374e6785f..63ba3fb24 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -37,8 +37,7 @@ ) from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.schema import ( - EntityInputType, - RelationInputType, + EntityInputType, RelationInputType, ) from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate from neo4j_graphrag.llm import LLMInterface @@ -100,7 +99,7 @@ class RelationshipType(BaseModel): @model_validator(mode="before") @classmethod - def validate(cls, data: EntityInputType) -> Self: + def validate(cls, data: RelationInputType) -> Self: if isinstance(data, str): return {"label": data} return data From 26506af8254fc9c0baa528f191095adb25e5cdb8 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 12:59:14 +0200 Subject: [PATCH 3/7] Ruff 2 --- src/neo4j_graphrag/experimental/components/schema.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 63ba3fb24..e678c32a4 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -37,7 +37,8 @@ ) from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.schema import ( - EntityInputType, RelationInputType, + EntityInputType, + RelationInputType, ) from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate from neo4j_graphrag.llm import LLMInterface From c331d13054506d4095cc437656f6db3262bbba57 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 13:16:41 +0200 Subject: [PATCH 4/7] Mypy --- .../experimental/components/schema.py | 4 +- .../template_pipeline/simple_kg_builder.py | 4 +- .../experimental/pipeline/kg_builder.py | 41 ++++++++++--------- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index e678c32a4..a4af30e5d 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -83,7 +83,7 @@ class NodeType(BaseModel): @model_validator(mode="before") @classmethod - def validate(cls, data: EntityInputType) -> Self: + def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType: if isinstance(data, str): return {"label": data} return data @@ -100,7 +100,7 @@ class RelationshipType(BaseModel): @model_validator(mode="before") @classmethod - def validate(cls, data: RelationInputType) -> Self: + def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType: if isinstance(data, str): return {"label": data} return data diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 933231e7d..030ec13fb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -186,8 +186,8 @@ def _process_schema_with_precedence( self, ) -> Tuple[ Tuple[NodeType, ...], - Tuple[RelationshipType, ...], - Optional[Tuple[Tuple[str, str, str], ...]], + Tuple[RelationshipType, ...] | None, + Optional[Tuple[Tuple[str, str, str], ...]] | None, ]: """ Process schema inputs according to precedence rules: diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index d46ddc046..b0231e50f 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -105,25 +105,28 @@ def __init__( neo4j_database: Optional[str] = None, ): try: - config = SimpleKGPipelineConfig( - # argument type are fixed in the Config object - llm_config=llm, # type: ignore[arg-type] - neo4j_config=driver, # type: ignore[arg-type] - embedder_config=embedder, # type: ignore[arg-type] - entities=entities or [], - relations=relations or [], - potential_schema=potential_schema, - schema=schema, - enforce_schema=SchemaEnforcementMode(enforce_schema), - from_pdf=from_pdf, - pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, - kg_writer=ComponentType(kg_writer) if kg_writer else None, - text_splitter=ComponentType(text_splitter) if text_splitter else None, - on_error=OnError(on_error), - prompt_template=prompt_template, - perform_entity_resolution=perform_entity_resolution, - lexical_graph_config=lexical_graph_config, - neo4j_database=neo4j_database, + config = SimpleKGPipelineConfig.model_validate( + dict( + llm_config=llm, + neo4j_config=driver, + embedder_config=embedder, + entities=entities or [], + relations=relations or [], + potential_schema=potential_schema, + schema=schema, + enforce_schema=SchemaEnforcementMode(enforce_schema), + from_pdf=from_pdf, + pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, + kg_writer=ComponentType(kg_writer) if kg_writer else None, + text_splitter=ComponentType(text_splitter) + if text_splitter + else None, + on_error=OnError(on_error), + prompt_template=prompt_template, + perform_entity_resolution=perform_entity_resolution, + lexical_graph_config=lexical_graph_config, + neo4j_database=neo4j_database, + ) ) except (ValidationError, ValueError) as e: raise PipelineDefinitionError() from e From 8337dc4b88cc8ec16cc50db16c318e1084802abb Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 13:36:17 +0200 Subject: [PATCH 5/7] Mypy 2 --- .../config/template_pipeline/test_simple_kg_builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 34157f494..73bb50eeb 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -365,6 +365,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_legacy() -> No assert len(node_types[0].properties) == 0 assert node_types[1].label == "Organization" assert len(node_types[1].properties) == 1 + assert relationship_types is not None assert len(relationship_types) == 2 assert relationship_types[0].label == "WORKS_FOR" assert len(relationship_types[0].properties) == 0 @@ -420,6 +421,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_dict() assert len(node_types[0].properties) == 0 assert node_types[1].label == "Organization" assert len(node_types[1].properties) == 1 + assert relationship_types is not None assert len(relationship_types) == 2 assert relationship_types[0].label == "WORKS_FOR" assert len(relationship_types[0].properties) == 0 @@ -479,6 +481,7 @@ def test_simple_kg_pipeline_config_process_schema_with_precedence_schema_object( assert len(node_types[0].properties) == 0 assert node_types[1].label == "Organization" assert len(node_types[1].properties) == 1 + assert relationship_types is not None assert len(relationship_types) == 2 assert relationship_types[0].label == "WORKS_FOR" assert len(relationship_types[0].properties) == 0 From e4d4a31a2cf5b8da25dda939dca7c1bc446853c1 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 13:46:33 +0200 Subject: [PATCH 6/7] Test fix --- .../pipeline/config/template_pipeline/simple_kg_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 030ec13fb..4aee855bb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import ( Any, ClassVar, From 94634b80e1af84f04a973c9f281951001a541171 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 21 May 2025 14:07:10 +0200 Subject: [PATCH 7/7] Update the config --- .../simple_kg_pipeline_config.json | 112 +++++++++--------- .../simple_kg_pipeline_config.yaml | 51 ++++---- .../simple_kg_pipeline_config_url.json | 112 +++++++++--------- 3 files changed, 140 insertions(+), 135 deletions(-) diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_config.json b/examples/build_graph/from_config_files/simple_kg_pipeline_config.json index ef2516245..b22119a21 100644 --- a/examples/build_graph/from_config_files/simple_kg_pipeline_config.json +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_config.json @@ -42,65 +42,67 @@ } }, "from_pdf": false, - "entities": [ - "Person", - { - "label": "House", - "description": "Family the person belongs to", - "properties": [ - { - "name": "name", - "type": "STRING" - } - ] - }, - { - "label": "Planet", - "properties": [ - { - "name": "name", - "type": "STRING" - }, - { - "name": "weather", - "type": "STRING" - } - ] - } - ], - "relations": [ - "PARENT_OF", - { - "label": "HEIR_OF", - "description": "Used for inheritor relationship between father and sons" - }, - { - "label": "RULES", - "properties": [ - { - "name": "fromYear", - "type": "INTEGER" - } - ] - } - ], - "potential_schema": [ - [ + "schema": { + "node_types": [ "Person", - "PARENT_OF", - "Person" + { + "label": "House", + "description": "Family the person belongs to", + "properties": [ + { + "name": "name", + "type": "STRING" + } + ] + }, + { + "label": "Planet", + "properties": [ + { + "name": "name", + "type": "STRING" + }, + { + "name": "weather", + "type": "STRING" + } + ] + } ], - [ - "Person", - "HEIR_OF", - "House" + "relationship_types": [ + "PARENT_OF", + { + "label": "HEIR_OF", + "description": "Used for inheritor relationship between father and sons" + }, + { + "label": "RULES", + "properties": [ + { + "name": "fromYear", + "type": "INTEGER" + } + ] + } ], - [ - "House", - "RULES", - "Planet" + "patterns": [ + [ + "Person", + "PARENT_OF", + "Person" + ], + [ + "Person", + "HEIR_OF", + "House" + ], + [ + "House", + "RULES", + "Planet" + ] ] - ], + }, "text_splitter": { "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter", "params_": { diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml b/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml index 8917e8ca3..051c47d74 100644 --- a/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_config.yaml @@ -30,31 +30,32 @@ embedder_config: resolver_: ENV var_: OPENAI_API_KEY from_pdf: false -entities: - - label: Person - - label: House - description: Family the person belongs to - properties: - - name: name - type: STRING - - label: Planet - properties: - - name: name - type: STRING - - name: weather - type: STRING -relations: - - label: PARENT_OF - - label: HEIR_OF - description: Used for inheritor relationship between father and sons - - label: RULES - properties: - - name: fromYear - type: INTEGER -potential_schema: - - ["Person", "PARENT_OF", "Person"] - - ["Person", "HEIR_OF", "House"] - - ["House", "RULES", "Planet"] +schema: + node_types: + - label: Person + - label: House + description: Family the person belongs to + properties: + - name: name + type: STRING + - label: Planet + properties: + - name: name + type: STRING + - name: weather + type: STRING + relationship_types: + - label: PARENT_OF + - label: HEIR_OF + description: Used for inheritor relationship between father and sons + - label: RULES + properties: + - name: fromYear + type: INTEGER + patterns: + - ["Person", "PARENT_OF", "Person"] + - ["Person", "HEIR_OF", "House"] + - ["House", "RULES", "Planet"] text_splitter: class_: text_splitters.fixed_size_splitter.FixedSizeSplitter params_: diff --git a/examples/build_graph/from_config_files/simple_kg_pipeline_config_url.json b/examples/build_graph/from_config_files/simple_kg_pipeline_config_url.json index 12cb22379..524f50955 100644 --- a/examples/build_graph/from_config_files/simple_kg_pipeline_config_url.json +++ b/examples/build_graph/from_config_files/simple_kg_pipeline_config_url.json @@ -42,65 +42,67 @@ } }, "from_pdf": true, - "entities": [ - "Person", - { - "label": "House", - "description": "Family the person belongs to", - "properties": [ - { - "name": "name", - "type": "STRING" - } - ] - }, - { - "label": "Planet", - "properties": [ - { - "name": "name", - "type": "STRING" - }, - { - "name": "weather", - "type": "STRING" - } - ] - } - ], - "relations": [ - "PARENT_OF", - { - "label": "HEIR_OF", - "description": "Used for inheritor relationship between father and sons" - }, - { - "label": "RULES", - "properties": [ - { - "name": "fromYear", - "type": "INTEGER" - } - ] - } - ], - "potential_schema": [ - [ + "schema": { + "node_types": [ "Person", - "PARENT_OF", - "Person" + { + "label": "House", + "description": "Family the person belongs to", + "properties": [ + { + "name": "name", + "type": "STRING" + } + ] + }, + { + "label": "Planet", + "properties": [ + { + "name": "name", + "type": "STRING" + }, + { + "name": "weather", + "type": "STRING" + } + ] + } ], - [ - "Person", - "HEIR_OF", - "House" + "relationship_types": [ + "PARENT_OF", + { + "label": "HEIR_OF", + "description": "Used for inheritor relationship between father and sons" + }, + { + "label": "RULES", + "properties": [ + { + "name": "fromYear", + "type": "INTEGER" + } + ] + } ], - [ - "House", - "RULES", - "Planet" + "patterns": [ + [ + "Person", + "PARENT_OF", + "Person" + ], + [ + "Person", + "HEIR_OF", + "House" + ], + [ + "House", + "RULES", + "Planet" + ] ] - ], + }, "text_splitter": { "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter", "params_": {