12
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
- from typing import Any , ClassVar , Literal , Optional , Sequence , Union , TypeVar
15
+ from typing import Any , ClassVar , Literal , Optional , Sequence , Union , List , Tuple
16
16
import logging
17
17
18
- from pydantic import ConfigDict , model_validator
18
+ from pydantic import ConfigDict , Field , model_validator
19
+ from typing_extensions import Self
19
20
20
21
from neo4j_graphrag .experimental .components .embedder import TextChunkEmbedder
21
22
from neo4j_graphrag .experimental .components .entity_relation_extractor import (
59
60
60
61
logger = logging .getLogger (__name__ )
61
62
62
- T = TypeVar ("T" , bound = "SimpleKGPipelineConfig" )
63
-
64
63
65
64
class SimpleKGPipelineConfig (TemplatePipelineConfig ):
66
65
COMPONENTS : ClassVar [list [str ]] = [
@@ -81,7 +80,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
81
80
entities : Sequence [EntityInputType ] = []
82
81
relations : Sequence [RelationInputType ] = []
83
82
potential_schema : Optional [list [tuple [str , str , str ]]] = None
84
- schema : Optional [Union [SchemaConfig , dict [str , list [Any ]]]] = None # type: ignore
83
+ schema_ : Optional [Union [SchemaConfig , dict [str , list [Any ]]]] = Field ( default = None , alias = "schema" )
85
84
enforce_schema : SchemaEnforcementMode = SchemaEnforcementMode .NONE
86
85
on_error : OnError = OnError .IGNORE
87
86
prompt_template : Union [ERExtractionTemplate , str ] = ERExtractionTemplate ()
@@ -97,10 +96,10 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
97
96
model_config = ConfigDict (arbitrary_types_allowed = True )
98
97
99
98
@model_validator (mode = "after" )
100
- def handle_schema_precedence (self ) -> T : # type: ignore
99
+ def handle_schema_precedence (self ) -> Self :
101
100
"""Handle schema precedence and warnings"""
102
101
self ._process_schema_parameters ()
103
- return self # type: ignore
102
+ return self
104
103
105
104
def _process_schema_parameters (self ) -> None :
106
105
"""
@@ -112,7 +111,7 @@ def _process_schema_parameters(self) -> None:
112
111
[self .entities , self .relations , self .potential_schema ]
113
112
)
114
113
115
- if has_individual_schema_components and self .schema is not None :
114
+ if has_individual_schema_components and self .schema_ is not None :
116
115
logger .warning (
117
116
"Both 'schema' and individual schema components (entities, relations, potential_schema) "
118
117
"were provided. The 'schema' parameter takes precedence. In the future, individual "
@@ -134,7 +133,7 @@ def has_user_provided_schema(self) -> bool:
134
133
self .entities
135
134
or self .relations
136
135
or self .potential_schema
137
- or self .schema is not None
136
+ or self .schema_ is not None
138
137
)
139
138
140
139
def _get_pdf_loader (self ) -> Optional [PdfLoader ]:
@@ -175,8 +174,8 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
175
174
176
175
def _process_schema_with_precedence (
177
176
self ,
178
- ) -> tuple [
179
- list [SchemaEntity ], list [SchemaRelation ], Optional [list [ tuple [str , str , str ]]]
177
+ ) -> Tuple [
178
+ List [SchemaEntity ], List [SchemaRelation ], Optional [List [ Tuple [str , str , str ]]]
180
179
]:
181
180
"""
182
181
Process schema inputs according to precedence rules:
@@ -187,28 +186,37 @@ def _process_schema_with_precedence(
187
186
Returns:
188
187
Tuple of (entities, relations, potential_schema)
189
188
"""
190
- if self .schema is not None :
189
+ if self .schema_ is not None :
191
190
# schema takes precedence over individual components
192
- if isinstance (self .schema , SchemaConfig ):
191
+ if isinstance (self .schema_ , SchemaConfig ):
193
192
# extract components from SchemaConfig
194
- entities = list (self .schema .entities .values ())
195
- relations = list (self .schema .relations .values ()) # type: ignore
196
- potential_schema = self .schema .potential_schema
193
+ entity_dicts = list (self .schema_ .entities .values ())
194
+ # convert dict values to SchemaEntity objects
195
+ entities = [SchemaEntity .model_validate (e ) for e in entity_dicts ]
196
+
197
+ # handle case where relations could be None
198
+ if self .schema_ .relations is not None :
199
+ relation_dicts = list (self .schema_ .relations .values ())
200
+ relations = [SchemaRelation .model_validate (r ) for r in relation_dicts ]
201
+ else :
202
+ relations = []
203
+
204
+ potential_schema = self .schema_ .potential_schema
197
205
else :
198
206
# extract from dictionary
199
207
entities = [
200
- SchemaEntity .from_text_or_dict (e ) # type: ignore
201
- for e in self .schema .get ("entities" , [])
208
+ SchemaEntity .from_text_or_dict (e )
209
+ for e in self .schema_ .get ("entities" , [])
202
210
]
203
211
relations = [
204
212
SchemaRelation .from_text_or_dict (r )
205
- for r in self .schema .get ("relations" , [])
213
+ for r in self .schema_ .get ("relations" , [])
206
214
]
207
- potential_schema = self .schema .get ("potential_schema" )
215
+ potential_schema = self .schema_ .get ("potential_schema" )
208
216
else :
209
217
# use individual components
210
218
entities = (
211
- [SchemaEntity .from_text_or_dict (e ) for e in self .entities ] # type: ignore
219
+ [SchemaEntity .from_text_or_dict (e ) for e in self .entities ]
212
220
if self .entities
213
221
else []
214
222
)
@@ -219,7 +227,7 @@ def _process_schema_with_precedence(
219
227
)
220
228
potential_schema = self .potential_schema
221
229
222
- return entities , relations , potential_schema # type: ignore
230
+ return entities , relations , potential_schema
223
231
224
232
def _get_run_params_for_schema (self ) -> dict [str , Any ]:
225
233
if self .auto_schema_extraction and not self .has_user_provided_schema ():
0 commit comments