Skip to content

Commit 41d359d

Browse files
Allow schema parameter in SimpleKGBuilderConfig and refactor code
1 parent b52bed4 commit 41d359d

File tree

2 files changed

+81
-63
lines changed

2 files changed

+81
-63
lines changed

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 75 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,10 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from typing import Any, ClassVar, Literal, Optional, Sequence, Union
15+
from typing import Any, ClassVar, Literal, Optional, Sequence, Union, TypeVar
16+
import logging
1617

17-
from pydantic import ConfigDict
18+
from pydantic import ConfigDict, model_validator, Field
1819

1920
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
2021
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
@@ -33,6 +34,7 @@
3334
SchemaEntity,
3435
SchemaRelation,
3536
SchemaFromText,
37+
SchemaConfig,
3638
)
3739
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
3840
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
@@ -55,6 +57,9 @@
5557
)
5658
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
5759

60+
# Module-level logger; the config class uses it to emit schema precedence/deprecation warnings.
logger = logging.getLogger(__name__)
61+
62+
# TypeVar bound to SimpleKGPipelineConfig; used as the return annotation of the model validator.
T = TypeVar('T', bound='SimpleKGPipelineConfig')
5863

5964
class SimpleKGPipelineConfig(TemplatePipelineConfig):
6065
COMPONENTS: ClassVar[list[str]] = [
@@ -75,6 +80,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
7580
entities: Sequence[EntityInputType] = []
7681
relations: Sequence[RelationInputType] = []
7782
potential_schema: Optional[list[tuple[str, str, str]]] = None
83+
schema: Optional[Union[SchemaConfig, dict[str, list]]] = None
7884
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE
7985
on_error: OnError = OnError.IGNORE
8086
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
@@ -88,10 +94,40 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
8894
text_splitter: Optional[ComponentType] = None
8995

9096
model_config = ConfigDict(arbitrary_types_allowed=True)
97+
98+
@model_validator(mode='after')
def handle_schema_precedence(self: T) -> T:
    """Run the schema precedence checks after the model's fields are validated.

    Delegates to ``_process_schema_parameters``, which logs a warning when the
    individual schema fields (``entities``, ``relations``,
    ``potential_schema``) are used, alone or together with ``schema``.

    ``self`` is annotated with ``T`` so the TypeVar in the return annotation
    is bound by an argument; a TypeVar that appears only in the return type
    is rejected by type checkers.

    Returns:
        The same instance; no field is modified here.
    """
    self._process_schema_parameters()
    return self
103+
104+
def _process_schema_parameters(self) -> None:
    """Warn about use of the individual schema fields.

    Nothing is logged when ``entities``, ``relations`` and
    ``potential_schema`` are all falsy. Otherwise exactly one warning is
    logged: a precedence notice when ``schema`` is also set, or a
    deprecation notice when it is not.
    """
    legacy_fields_used = bool(
        self.entities or self.relations or self.potential_schema
    )
    if not legacy_fields_used:
        return

    if self.schema is not None:
        # both styles supplied: tell the user which one wins
        message = (
            "Both 'schema' and individual schema components (entities, relations, potential_schema) "
            "were provided. The 'schema' parameter takes precedence. In the future, individual "
            "components will be removed. Please use only the 'schema' parameter."
        )
    else:
        # only the legacy fields supplied: plain deprecation notice
        message = (
            "The 'entities', 'relations', and 'potential_schema' parameters are deprecated "
            "and will be removed in a future version. "
            "Please use the 'schema' parameter instead."
        )

    logger.warning(message, stacklevel=2)
91127

92128
def has_user_provided_schema(self) -> bool:
93129
"""Report whether any schema input was supplied on this config.

The removed (`-`) return considered only `entities`, `relations` and
`potential_schema`. The added (`+`) return additionally yields True
whenever `schema` is not None, whatever the other three fields hold.
"""
94-
return bool(self.entities or self.relations or self.potential_schema)
130+
return bool(self.entities or self.relations or self.potential_schema or self.schema is not None)
95131

96132
def _get_pdf_loader(self) -> Optional[PdfLoader]:
97133
if not self.from_pdf:
@@ -129,16 +165,48 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]:
129165
return SchemaFromText(llm=self.get_default_llm())
130166
return SchemaBuilder()
131167

168+
def _process_schema_with_precedence(
    self,
) -> tuple[
    list[SchemaEntity],
    list[SchemaRelation],
    Optional[list[tuple[str, str, str]]],
]:
    """Resolve the effective schema inputs for this config.

    Precedence:
        1. ``schema`` given as a ``SchemaConfig``: its values are used.
        2. ``schema`` given as a dictionary: the "entities", "relations"
           and "potential_schema" keys are read from it.
        3. ``schema`` is None: the individual ``entities``, ``relations``
           and ``potential_schema`` fields are used.

    Returns:
        Tuple of (entities, relations, potential_schema).
    """
    schema = self.schema

    if schema is None:
        # no unified schema: fall back to the individual fields
        return (
            [SchemaEntity.from_text_or_dict(item) for item in self.entities or []],
            [SchemaRelation.from_text_or_dict(item) for item in self.relations or []],
            self.potential_schema,
        )

    if isinstance(schema, SchemaConfig):
        # NOTE(review): assumes SchemaConfig.entities and .relations are
        # mappings that are never None - confirm against SchemaConfig.
        return (
            list(schema.entities.values()),
            list(schema.relations.values()),
            schema.potential_schema,
        )

    # dictionary form: missing keys yield empty lists / None
    return (
        [SchemaEntity.from_text_or_dict(item) for item in schema.get("entities", [])],
        [SchemaRelation.from_text_or_dict(item) for item in schema.get("relations", [])],
        schema.get("potential_schema"),
    )
197+
132198
def _get_run_params_for_schema(self) -> dict[str, Any]:
"""Build the run parameters for the schema component.

Returns an empty dict when `auto_schema_extraction` is set and
`has_user_provided_schema()` is false. Otherwise returns a dict with the
keys "entities", "relations" and "potential_schema". In the added (`+`)
lines these values come from `_process_schema_with_precedence()`; the
removed (`-`) lines built them directly from the individual fields.
"""
133199
if self.auto_schema_extraction and not self.has_user_provided_schema():
134200
# for automatic extraction, the text parameter is needed (will flow through the pipeline connections)
135201
return {}
136202
else:
137-
# for manual schema, use the provided entities/relations/potential_schema
203+
# process schema components according to precedence rules
204+
entities, relations, potential_schema = self._process_schema_with_precedence()
205+
138206
return {
139-
"entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities],
140-
"relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations],
141-
"potential_schema": self.potential_schema,
207+
"entities": entities,
208+
"relations": relations,
209+
"potential_schema": potential_schema,
142210
}
143211

144212
def _get_extractor(self) -> EntityRelationExtractor:

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 6 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
from typing import List, Optional, Sequence, Union
1919
import logging
20-
import warnings
2120

2221
import neo4j
2322
from pydantic import ValidationError
@@ -44,7 +43,7 @@
4443
)
4544
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
4645
from neo4j_graphrag.llm.base import LLMInterface
47-
from neo4j_graphrag.experimental.components.schema import SchemaConfig, SchemaBuilder
46+
from neo4j_graphrag.experimental.components.schema import SchemaConfig
4847

4948
logger = logging.getLogger(__name__)
5049

@@ -104,65 +103,16 @@ def __init__(
104103
lexical_graph_config: Optional[LexicalGraphConfig] = None,
105104
neo4j_database: Optional[str] = None,
106105
):
107-
# deprecation warnings for old parameters
108-
if any([entities, relations, potential_schema]) and schema is not None:
109-
logger.warning(
110-
"Both 'schema' and individual schema components (entities, relations, potential_schema) "
111-
"were provided. The 'schema' parameter takes precedence. In the future, individual "
112-
"components will be removed. Please use only the 'schema' parameter."
113-
)
114-
# emit a DeprecationWarning for tools that might be monitoring for it
115-
warnings.warn(
116-
"Both 'schema' and individual schema components are provided. Use only 'schema'.",
117-
DeprecationWarning,
118-
stacklevel=2,
119-
)
120-
elif any([entities, relations, potential_schema]):
121-
logger.warning(
122-
"The 'entities', 'relations', and 'potential_schema' parameters are deprecated "
123-
"and will be removed in a future version. "
124-
"Please use the 'schema' parameter instead."
125-
)
126-
warnings.warn(
127-
"The 'entities', 'relations', and 'potential_schema' parameters are deprecated.",
128-
DeprecationWarning,
129-
stacklevel=2,
130-
)
131-
132-
# handle schema precedence over individual schema components
133-
schema_entities = []
134-
schema_relations = []
135-
schema_potential = None
136-
137-
if schema is not None:
138-
# schema takes precedence over individual components
139-
if isinstance(schema, SchemaConfig):
140-
# use the SchemaConfig directly
141-
pass
142-
else:
143-
# convert dictionary to entity/relation lists
144-
schema_entities = schema.get("entities", [])
145-
schema_relations = schema.get("relations", [])
146-
schema_potential = schema.get("potential_schema")
147-
else:
148-
# Use the individual components if provided
149-
schema_entities = entities or []
150-
schema_relations = relations or []
151-
schema_potential = potential_schema
152-
153-
# determine if automatic schema extraction should be performed
154-
has_schema = bool(schema_entities or schema_relations or schema_potential or isinstance(schema, SchemaConfig))
155-
auto_schema_extraction = not has_schema
156-
157106
try:
158107
config = SimpleKGPipelineConfig(
159108
# argument type are fixed in the Config object
160109
llm_config=llm, # type: ignore[arg-type]
161110
neo4j_config=driver, # type: ignore[arg-type]
162111
embedder_config=embedder, # type: ignore[arg-type]
163-
entities=schema_entities,
164-
relations=schema_relations,
165-
potential_schema=schema_potential,
112+
entities=entities or [],
113+
relations=relations or [],
114+
potential_schema=potential_schema,
115+
schema=schema,
166116
enforce_schema=SchemaEnforcementMode(enforce_schema),
167117
from_pdf=from_pdf,
168118
pdf_loader=ComponentType(pdf_loader) if pdf_loader else None,
@@ -173,7 +123,7 @@ def __init__(
173123
perform_entity_resolution=perform_entity_resolution,
174124
lexical_graph_config=lexical_graph_config,
175125
neo4j_database=neo4j_database,
176-
auto_schema_extraction=auto_schema_extraction,
126+
auto_schema_extraction=not bool(schema or entities or relations or potential_schema),
177127
)
178128
except (ValidationError, ValueError) as e:
179129
raise PipelineDefinitionError() from e

0 commit comments

Comments
 (0)