14
14
# limitations under the License.
15
15
from __future__ import annotations
16
16
17
+ import copy
17
18
import json
18
19
import logging
19
- from typing import Any , Dict , List , Optional , Tuple , Sequence
20
+ from typing import Any , Dict , List , Optional , Tuple , Sequence , Literal
20
21
22
+ import neo4j
21
23
from pydantic import (
22
24
validate_call ,
23
25
ValidationError ,
29
31
SchemaExtractionError ,
30
32
)
31
33
from neo4j_graphrag .experimental .pipeline .component import Component
34
+ from neo4j_graphrag .experimental .pipeline .types .schema import (
35
+ EntityInputType ,
36
+ RelationInputType ,
37
+ )
32
38
from neo4j_graphrag .generation import SchemaExtractionTemplate , PromptTemplate
33
39
from neo4j_graphrag .llm import LLMInterface
34
40
from neo4j_graphrag .experimental .components .types import (
35
- NodeType ,
36
41
RelationshipType ,
37
42
GraphSchema ,
43
+ SchemaConstraint ,
44
+ ConstraintTypeEnum ,
45
+ Neo4jConstraintTypeEnum ,
46
+ GraphEntityType ,
47
+ Neo4jPropertyType ,
38
48
)
49
+ from neo4j_graphrag .schema import get_constraints
39
50
40
51
41
52
class SchemaBuilder (Component ):
@@ -97,9 +108,83 @@ class SchemaBuilder(Component):
97
108
pipe.run(pipe_inputs)
98
109
"""
99
110
111
def __init__(
    self, driver: neo4j.Driver, neo4j_database: Optional[str] = None
) -> None:
    """Keep the connection settings on the instance.

    Args:
        driver (neo4j.Driver): Driver stored as ``self.driver``.
        neo4j_database (Optional[str]): Database name stored as
            ``self.neo4j_database``. Defaults to None.
    """
    self.neo4j_database = neo4j_database
    self.driver = driver
116
+
117
def _get_constraints_from_db(self) -> list[dict[str, Any]]:
    """Fetch constraints by delegating to ``get_constraints``.

    The call is made with this instance's driver and database name, and
    with ``sanitize=False``.

    Returns:
        list[dict[str, Any]]: The value returned by ``get_constraints``,
        passed through unchanged.
    """
    return get_constraints(
        self.driver,
        database=self.neo4j_database,
        sanitize=False,
    )
122
+
123
def _apply_all_constraints_from_db(
    self,
    node_or_relationship_type: Literal["NODE", "RELATIONSHIP"],
    constraints: list[dict[str, Any]],
    entities: tuple[GraphEntityType, ...],
) -> list[GraphEntityType]:
    """Return copies of ``entities`` with the matching constraints applied.

    Each entity is deep-copied before anything is applied, so the objects
    passed in are left untouched. A constraint is applied to an entity when
    its ``"entityType"`` equals ``node_or_relationship_type`` and the first
    item of its ``"labelsOrTypes"`` equals the entity's ``label``.

    Args:
        node_or_relationship_type (Literal["NODE", "RELATIONSHIP"]): Value
            compared against each constraint's ``"entityType"``.
        constraints (list[dict[str, Any]]): Constraint records to consider.
        entities (tuple[GraphEntityType, ...]): Entity types to process.

    Returns:
        list[GraphEntityType]: One copy per input entity, in input order.
    """
    result = []
    for original in entities:
        enriched = copy.deepcopy(original)
        # Lazy filter: each record is tested right before it is applied,
        # in the order the records appear in ``constraints``.
        applicable = (
            record
            for record in constraints
            if record["entityType"] == node_or_relationship_type
            and record["labelsOrTypes"][0] == original.label
        )
        for record in applicable:
            self._apply_constraint_from_db(enriched, record)
        result.append(enriched)
    return result
142
+
100
143
@staticmethod
def _parse_property_type(property_type: str) -> Neo4jPropertyType | None:
    """Convert a ``|``-separated type string into a ``Neo4jPropertyType``.

    The string is split on ``"|"`` and every piece is stripped of
    surrounding whitespace before being passed to ``Neo4jPropertyType``.
    Pieces that raise ``ValueError`` are ignored. When several pieces
    convert successfully, the last one is the one returned.

    Args:
        property_type (str): Raw type string; may be empty.

    Returns:
        Neo4jPropertyType | None: The last piece that converted, or None
        when the input is empty or no piece converted.
    """
    if not property_type:
        return None
    parsed: Neo4jPropertyType | None = None
    for candidate in map(str.strip, property_type.split("|")):
        try:
            parsed = Neo4jPropertyType(candidate)
        except ValueError:
            # Unrecognised piece: keep whatever was parsed so far.
            continue
    return parsed
155
+
156
def _apply_constraint_from_db(
    self, entity_type: GraphEntityType, constraint: dict[str, Any]
) -> None:
    """Attach one database constraint record to ``entity_type`` in place.

    A ``SchemaConstraint`` built from the record is appended to
    ``entity_type.constraints``. When the converted constraint type is
    ``ConstraintTypeEnum.PROPERTY_EXISTENCE``, every property named by the
    record gets ``required = True``.

    Args:
        entity_type (GraphEntityType): Entity type to modify.
        constraint (dict[str, Any]): Constraint record. The keys ``"type"``,
            ``"properties"`` and ``"name"`` are read directly;
            ``"propertyType"`` is optional and treated as absent when
            missing.

    Raises:
        ValueError: If a property named by the record does not exist on
            ``entity_type``. Presumably also raised by
            ``Neo4jConstraintTypeEnum`` for an unknown ``"type"`` value if
            it is a standard enum (to be confirmed).
    """
    neo4j_constraint_type = Neo4jConstraintTypeEnum(constraint["type"])
    # TODO: detect potential conflict with constraints already declared on
    # the same properties and raise ValueError if any.
    constraint_properties = constraint["properties"]

    # Resolve each property once; the objects are reused below for the
    # existence flag instead of being looked up a second time.
    resolved_properties = []
    for property_name in constraint_properties:
        prop = entity_type.get_property_by_name(property_name)
        if prop is None:
            raise ValueError(
                f"Can not add constraint {constraint} on non existing property"
            )
        resolved_properties.append(prop)

    constraint_type = neo4j_constraint_type.to_constraint_type()
    entity_type.constraints.append(
        SchemaConstraint(
            type=constraint_type,
            properties=constraint_properties,
            # ``.get``: a record without a "propertyType" entry must not
            # fail with KeyError; the parser maps a missing value to None.
            property_type=self._parse_property_type(
                constraint.get("propertyType")
            ),
            name=constraint["name"],
        )
    )

    # An existence constraint makes its properties mandatory. Flag all of
    # them, which is also a safe no-op for an empty property list.
    if constraint_type == ConstraintTypeEnum.PROPERTY_EXISTENCE:
        for prop in resolved_properties:
            prop.required = True
184
+
185
+ def _create_schema_model (
186
+ self ,
187
+ node_types : Sequence [EntityInputType ],
103
188
relationship_types : Optional [Sequence [RelationshipType ]] = None ,
104
189
patterns : Optional [Sequence [Tuple [str , str , str ]]] = None ,
105
190
** kwargs : Any ,
@@ -118,7 +203,7 @@ def create_schema_model(
118
203
GraphSchema: A configured schema object.
119
204
"""
120
205
try :
121
- return GraphSchema .model_validate (
206
+ schema = GraphSchema .model_validate (
122
207
dict (
123
208
node_types = node_types ,
124
209
relationship_types = relationship_types or (),
@@ -129,11 +214,39 @@ def create_schema_model(
129
214
except ValidationError as e :
130
215
raise SchemaValidationError () from e
131
216
217
+ constraints = self ._get_constraints_from_db ()
218
+ # apply constraints
219
+ constrained_node_types = self ._apply_all_constraints_from_db (
220
+ "NODE" ,
221
+ constraints ,
222
+ schema .node_types ,
223
+ )
224
+ constrained_relationship_types = self ._apply_all_constraints_from_db (
225
+ "RELATIONSHIP" ,
226
+ constraints ,
227
+ schema .relationship_types ,
228
+ )
229
+
230
+ try :
231
+ constrained_schema = GraphSchema .model_validate (
232
+ dict (
233
+ node_types = constrained_node_types ,
234
+ relationship_types = constrained_relationship_types ,
235
+ patterns = patterns ,
236
+ ** kwargs ,
237
+ )
238
+ )
239
+ except ValidationError as e :
240
+ raise SchemaValidationError (
241
+ "Error when applying constraints from database"
242
+ ) from e
243
+ return constrained_schema
244
+
132
245
@validate_call
133
246
async def run (
134
247
self ,
135
- node_types : Sequence [NodeType ],
136
- relationship_types : Optional [Sequence [RelationshipType ]] = None ,
248
+ node_types : Sequence [EntityInputType ],
249
+ relationship_types : Optional [Sequence [RelationInputType ]] = None ,
137
250
patterns : Optional [Sequence [Tuple [str , str , str ]]] = None ,
138
251
** kwargs : Any ,
139
252
) -> GraphSchema :
@@ -148,7 +261,7 @@ async def run(
148
261
Returns:
149
262
GraphSchema: A configured schema object, constructed asynchronously.
150
263
"""
151
- return self .create_schema_model (
264
+ return self ._create_schema_model (
152
265
node_types ,
153
266
relationship_types ,
154
267
patterns ,
@@ -164,10 +277,12 @@ class SchemaFromTextExtractor(Component):
164
277
165
278
def __init__ (
166
279
self ,
280
+ driver : neo4j .Driver ,
167
281
llm : LLMInterface ,
168
282
prompt_template : Optional [PromptTemplate ] = None ,
169
283
llm_params : Optional [Dict [str , Any ]] = None ,
170
284
) -> None :
285
+ self .driver = driver
171
286
self ._llm : LLMInterface = llm
172
287
self ._prompt_template : PromptTemplate = (
173
288
prompt_template or SchemaExtractionTemplate ()
0 commit comments