12
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
- from typing import Any , ClassVar , Literal , Optional , Sequence , Union
15
+ from typing import Any , ClassVar , Literal , Optional , Sequence , Union , TypeVar
16
+ import logging
16
17
17
- from pydantic import ConfigDict
18
+ from pydantic import ConfigDict , model_validator , Field
18
19
19
20
from neo4j_graphrag .experimental .components .embedder import TextChunkEmbedder
20
21
from neo4j_graphrag .experimental .components .entity_relation_extractor import (
33
34
SchemaEntity ,
34
35
SchemaRelation ,
35
36
SchemaFromText ,
37
+ SchemaConfig ,
36
38
)
37
39
from neo4j_graphrag .experimental .components .text_splitters .base import TextSplitter
38
40
from neo4j_graphrag .experimental .components .text_splitters .fixed_size_splitter import (
55
57
)
56
58
from neo4j_graphrag .generation .prompts import ERExtractionTemplate
57
59
60
# Module-level logger, named after this module via __name__.
+ logger = logging .getLogger (__name__ )
61
+
62
# TypeVar bounded by SimpleKGPipelineConfig; used as the return annotation of handle_schema_precedence.
+ T = TypeVar ('T' , bound = 'SimpleKGPipelineConfig' )
58
63
59
64
class SimpleKGPipelineConfig (TemplatePipelineConfig ):
60
65
COMPONENTS : ClassVar [list [str ]] = [
@@ -75,6 +80,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
75
80
entities : Sequence [EntityInputType ] = []
76
81
relations : Sequence [RelationInputType ] = []
77
82
potential_schema : Optional [list [tuple [str , str , str ]]] = None
83
+ schema : Optional [Union [SchemaConfig , dict [str , list ]]] = None
78
84
enforce_schema : SchemaEnforcementMode = SchemaEnforcementMode .NONE
79
85
on_error : OnError = OnError .IGNORE
80
86
prompt_template : Union [ERExtractionTemplate , str ] = ERExtractionTemplate ()
@@ -88,10 +94,40 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
88
94
text_splitter : Optional [ComponentType ] = None
89
95
90
96
model_config = ConfigDict (arbitrary_types_allowed = True )
97
+
98
@model_validator(mode='after')
def handle_schema_precedence(self: T) -> T:
    """Run the schema-precedence checks after field validation.

    Registered as a pydantic ``mode='after'`` model validator. It delegates
    to ``_process_schema_parameters`` and then hands the instance back.

    Returns:
        The same instance that was passed in as ``self``.
    """
    # ``self`` is annotated with the TypeVar so that it is bound by an
    # argument; a TypeVar that only appears in the return position cannot
    # be solved by a type checker.
    self._process_schema_parameters()
    return self
103
+
104
def _process_schema_parameters(self) -> None:
    """Log a warning describing how the schema inputs were combined.

    Two situations produce a warning:

    * ``schema`` is set together with any of ``entities``, ``relations`` or
      ``potential_schema``: the message states that ``schema`` takes
      precedence.
    * only the individual components are set: the message flags them as
      deprecated.

    Nothing is logged when none of the individual components is set.
    """
    uses_individual_components = bool(
        self.entities or self.relations or self.potential_schema
    )
    if not uses_individual_components:
        return

    if self.schema is not None:
        message = (
            "Both 'schema' and individual schema components (entities, relations, potential_schema) "
            "were provided. The 'schema' parameter takes precedence. In the future, individual "
            "components will be removed. Please use only the 'schema' parameter."
        )
    else:
        message = (
            "The 'entities', 'relations', and 'potential_schema' parameters are deprecated "
            "and will be removed in a future version. "
            "Please use the 'schema' parameter instead."
        )

    # stacklevel=2 attributes the log record to this method's caller.
    logger.warning(message, stacklevel=2)
91
127
92
128
def has_user_provided_schema(self) -> bool:
    """Tell whether any schema information was supplied by the user.

    Returns:
        True when ``schema`` is not None, or when any of ``entities``,
        ``relations`` or ``potential_schema`` is truthy; False otherwise.
    """
    # Any non-None ``schema`` counts, even an empty one.
    if self.schema is not None:
        return True
    return any((self.entities, self.relations, self.potential_schema))
95
131
96
132
def _get_pdf_loader (self ) -> Optional [PdfLoader ]:
97
133
if not self .from_pdf :
@@ -129,16 +165,48 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromText]:
129
165
return SchemaFromText (llm = self .get_default_llm ())
130
166
return SchemaBuilder ()
131
167
168
def _process_schema_with_precedence(
    self,
) -> tuple[
    list[SchemaEntity],
    list[SchemaRelation],
    Optional[list[tuple[str, str, str]]],
]:
    """Resolve the effective schema inputs according to precedence rules.

    Precedence:
        1. ``schema`` given as a ``SchemaConfig`` object: its entities,
           relations and potential_schema are used.
        2. ``schema`` given as a dictionary: the ``"entities"``,
           ``"relations"`` and ``"potential_schema"`` keys are read.
        3. Otherwise the individual ``entities``, ``relations`` and
           ``potential_schema`` fields are used.

    Missing or ``None`` collections are treated as empty instead of raising.

    Returns:
        Tuple of (entities, relations, potential_schema).
    """
    schema = self.schema
    if schema is None:
        # No 'schema' parameter: fall back to the individual components.
        raw_entities = self.entities or []
        raw_relations = self.relations or []
        potential_schema = self.potential_schema
    elif isinstance(schema, SchemaConfig):
        # NOTE(review): SchemaConfig is defined outside this file; the
        # ``or {}`` guards assume its mappings may be None - confirm.
        raw_entities = list((schema.entities or {}).values())
        raw_relations = list((schema.relations or {}).values())
        potential_schema = schema.potential_schema
    else:
        # ``.get(key, [])`` would still return None for an explicit None
        # value, so fall back with ``or`` instead.
        raw_entities = schema.get("entities") or []
        raw_relations = schema.get("relations") or []
        potential_schema = schema.get("potential_schema")

    # Normalise every item so the declared return type holds for all three
    # branches; items that already have the target type are kept as-is.
    entities = [
        e if isinstance(e, SchemaEntity) else SchemaEntity.from_text_or_dict(e)
        for e in raw_entities
    ]
    relations = [
        r if isinstance(r, SchemaRelation) else SchemaRelation.from_text_or_dict(r)
        for r in raw_relations
    ]
    return entities, relations, potential_schema
197
+
132
198
def _get_run_params_for_schema(self) -> dict[str, Any]:
    """Build the run parameters handed to the schema component.

    Returns:
        An empty dict when automatic schema extraction is enabled and the
        user supplied no schema. Otherwise a dict with the keys
        ``"entities"``, ``"relations"`` and ``"potential_schema"`` holding
        the values resolved by ``_process_schema_with_precedence``.
    """
    extract_automatically = (
        self.auto_schema_extraction and not self.has_user_provided_schema()
    )
    if extract_automatically:
        # For automatic extraction the text parameter is needed instead; it
        # will flow through the pipeline connections.
        return {}

    # Manual schema: resolve the components according to precedence rules.
    resolved = self._process_schema_with_precedence()
    entities, relations, potential_schema = resolved
    return {
        "entities": entities,
        "relations": relations,
        "potential_schema": potential_schema,
    }
143
211
144
212
def _get_extractor (self ) -> EntityRelationExtractor :
0 commit comments