Skip to content

Commit 2b14541

Browse files
Update SimpleKGPipeline for automatic schema extraction
1 parent baf9302 commit 2b14541

File tree

2 files changed

+110
-15
lines changed

2 files changed

+110
-15
lines changed

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
SchemaBuilder,
3333
SchemaEntity,
3434
SchemaRelation,
35+
SchemaFromText,
3536
)
3637
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
3738
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
@@ -80,13 +81,18 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
8081
perform_entity_resolution: bool = True
8182
lexical_graph_config: Optional[LexicalGraphConfig] = None
8283
neo4j_database: Optional[str] = None
84+
auto_schema_extraction: bool = False
8385

8486
pdf_loader: Optional[ComponentType] = None
8587
kg_writer: Optional[ComponentType] = None
8688
text_splitter: Optional[ComponentType] = None
8789

8890
model_config = ConfigDict(arbitrary_types_allowed=True)
8991

92+
def has_user_provided_schema(self) -> bool:
    """Tell whether any manual schema input was supplied.

    Returns:
        bool: True when at least one of ``entities``, ``relations`` or
        ``potential_schema`` is truthy (non-empty / not None), else False.
    """
    manual_inputs = (self.entities, self.relations, self.potential_schema)
    return any(manual_inputs)
95+
9096
def _get_pdf_loader(self) -> Optional[PdfLoader]:
9197
if not self.from_pdf:
9298
return None
@@ -114,15 +120,26 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]:
114120
def _get_chunk_embedder(self) -> TextChunkEmbedder:
    """Create the chunk embedder component around the default embedder."""
    default_embedder = self.get_default_embedder()
    return TextChunkEmbedder(embedder=default_embedder)
116122

117-
def _get_schema(self) -> SchemaBuilder:
123+
def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]:
    """Pick the schema component that matches this configuration.

    Returns:
        SchemaFromText built with the default LLM when automatic schema
        extraction is enabled and the user supplied no schema of their own;
        a plain SchemaBuilder in every other case.
    """
    # A user-provided schema always wins over automatic extraction.
    if not self.auto_schema_extraction or self.has_user_provided_schema():
        return SchemaBuilder()
    default_llm = self.get_default_llm()
    return SchemaFromText(llm=default_llm)
119131

120132
def _get_run_params_for_schema(self) -> dict[str, Any]:
121-
return {
122-
"entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities],
123-
"relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations],
124-
"potential_schema": self.potential_schema,
125-
}
133+
if self.auto_schema_extraction and not self.has_user_provided_schema():
134+
# for automatic extraction, the text parameter is needed (will flow through the pipeline connections)
135+
return {}
136+
else:
137+
# for manual schema, use the provided entities/relations/potential_schema
138+
return {
139+
"entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities],
140+
"relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations],
141+
"potential_schema": self.potential_schema,
142+
}
126143

127144
def _get_extractor(self) -> EntityRelationExtractor:
128145
return LLMEntityRelationExtractor(
@@ -163,6 +180,17 @@ def _get_connections(self) -> list[ConnectionDefinition]:
163180
input_config={"text": "pdf_loader.text"},
164181
)
165182
)
183+
184+
# handle automatic schema extraction
185+
if self.auto_schema_extraction and not self.has_user_provided_schema():
186+
connections.append(
187+
ConnectionDefinition(
188+
start="pdf_loader",
189+
end="schema",
190+
input_config={"text": "pdf_loader.text"},
191+
)
192+
)
193+
166194
connections.append(
167195
ConnectionDefinition(
168196
start="schema",
@@ -174,13 +202,21 @@ def _get_connections(self) -> list[ConnectionDefinition]:
174202
)
175203
)
176204
else:
205+
# handle automatic schema extraction for direct text input: ensure schema extraction uses the complete text
206+
if self.auto_schema_extraction and not self.has_user_provided_schema():
207+
connections.append(
208+
ConnectionDefinition(
209+
start="__input__", # connection to pipeline input
210+
end="schema",
211+
input_config={"text": "text"}, # use the original text input
212+
)
213+
)
214+
177215
connections.append(
178216
ConnectionDefinition(
179217
start="schema",
180218
end="extractor",
181-
input_config={
182-
"schema": "schema",
183-
},
219+
input_config={"schema": "schema"},
184220
)
185221
)
186222
connections.append(

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from __future__ import annotations
1717

1818
from typing import List, Optional, Sequence, Union
19+
import logging
20+
import warnings
1921

2022
import neo4j
2123
from pydantic import ValidationError
@@ -42,7 +44,9 @@
4244
)
4345
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
4446
from neo4j_graphrag.llm.base import LLMInterface
47+
from neo4j_graphrag.experimental.components.schema import SchemaConfig, SchemaBuilder
4548

49+
logger = logging.getLogger(__name__)
4650

4751
class SimpleKGPipeline:
4852
"""
@@ -53,17 +57,20 @@ class SimpleKGPipeline:
5357
llm (LLMInterface): An instance of an LLM to use for entity and relation extraction.
5458
driver (neo4j.Driver): A Neo4j driver instance for database connection.
5559
embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks.
56-
entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): A list of either:
60+
schema (Optional[Union[SchemaConfig, dict[str, list]]]): A schema configuration defining entities,
61+
relations, and potential schema relationships.
62+
This is the recommended way to provide schema information.
63+
entities (Optional[List[Union[str, dict[str, str], SchemaEntity]]]): DEPRECATED. A list of either:
5764
5865
- str: entity labels
5966
- dict: following the SchemaEntity schema, ie with label, description and properties keys
6067
61-
relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): A list of either:
68+
relations (Optional[List[Union[str, dict[str, str], SchemaRelation]]]): DEPRECATED. A list of either:
6269
6370
- str: relation label
6471
- dict: following the SchemaRelation schema, ie with label, description and properties keys
6572
66-
potential_schema (Optional[List[tuple]]): A list of potential schema relationships.
73+
potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships.
6774
enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "NONE" or "STRICT".
6875
from_pdf (bool): Determines whether to include the PdfLoader in the pipeline.
6976
If True, expects `file_path` input in `run` methods.
@@ -85,6 +92,7 @@ def __init__(
8592
entities: Optional[Sequence[EntityInputType]] = None,
8693
relations: Optional[Sequence[RelationInputType]] = None,
8794
potential_schema: Optional[List[tuple[str, str, str]]] = None,
95+
schema: Optional[Union[SchemaConfig, dict[str, list]]] = None,
8896
enforce_schema: str = "NONE",
8997
from_pdf: bool = True,
9098
text_splitter: Optional[TextSplitter] = None,
@@ -96,15 +104,65 @@ def __init__(
96104
lexical_graph_config: Optional[LexicalGraphConfig] = None,
97105
neo4j_database: Optional[str] = None,
98106
):
107+
# deprecation warnings for old parameters
108+
if any([entities, relations, potential_schema]) and schema is not None:
109+
logger.warning(
110+
"Both 'schema' and individual schema components (entities, relations, potential_schema) "
111+
"were provided. The 'schema' parameter takes precedence. In the future, individual "
112+
"components will be removed. Please use only the 'schema' parameter."
113+
)
114+
# emit a DeprecationWarning for tools that might be monitoring for it
115+
warnings.warn(
116+
"Both 'schema' and individual schema components are provided. Use only 'schema'.",
117+
DeprecationWarning,
118+
stacklevel=2,
119+
)
120+
elif any([entities, relations, potential_schema]):
121+
logger.warning(
122+
"The 'entities', 'relations', and 'potential_schema' parameters are deprecated "
123+
"and will be removed in a future version. "
124+
"Please use the 'schema' parameter instead."
125+
)
126+
warnings.warn(
127+
"The 'entities', 'relations', and 'potential_schema' parameters are deprecated.",
128+
DeprecationWarning,
129+
stacklevel=2,
130+
)
131+
132+
# handle schema precedence over individual schema components
133+
schema_entities = []
134+
schema_relations = []
135+
schema_potential = None
136+
137+
if schema is not None:
138+
# schema takes precedence over individual components
139+
if isinstance(schema, SchemaConfig):
140+
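# NOTE(review): nothing is copied out of the SchemaConfig in this branch, so entities/relations/potential_schema stay empty; confirm the SchemaConfig contents actually reach the pipeline config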
141+
pass
142+
else:
143+
# convert dictionary to entity/relation lists
144+
schema_entities = schema.get("entities", [])
145+
schema_relations = schema.get("relations", [])
146+
schema_potential = schema.get("potential_schema")
147+
else:
148+
# Use the individual components if provided
149+
schema_entities = entities or []
150+
schema_relations = relations or []
151+
schema_potential = potential_schema
152+
153+
# determine if automatic schema extraction should be performed
154+
has_schema = bool(schema_entities or schema_relations or schema_potential or isinstance(schema, SchemaConfig))
155+
auto_schema_extraction = not has_schema
156+
99157
try:
100158
config = SimpleKGPipelineConfig(
101159
# argument type are fixed in the Config object
102160
llm_config=llm, # type: ignore[arg-type]
103161
neo4j_config=driver, # type: ignore[arg-type]
104162
embedder_config=embedder, # type: ignore[arg-type]
105-
entities=entities or [],
106-
relations=relations or [],
107-
potential_schema=potential_schema,
163+
entities=schema_entities,
164+
relations=schema_relations,
165+
potential_schema=schema_potential,
108166
enforce_schema=SchemaEnforcementMode(enforce_schema),
109167
from_pdf=from_pdf,
110168
pdf_loader=ComponentType(pdf_loader) if pdf_loader else None,
@@ -115,6 +173,7 @@ def __init__(
115173
perform_entity_resolution=perform_entity_resolution,
116174
lexical_graph_config=lexical_graph_config,
117175
neo4j_database=neo4j_database,
176+
auto_schema_extraction=auto_schema_extraction,
118177
)
119178
except (ValidationError, ValueError) as e:
120179
raise PipelineDefinitionError() from e

0 commit comments

Comments
 (0)