Skip to content

Commit f2a942c

Browse files
committed
WIP: apply constraints from DB
1 parent 2fc620f commit f2a942c

File tree

8 files changed

+250
-64
lines changed

8 files changed

+250
-64
lines changed

examples/customize/build_graph/components/schema_builders/schema.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,59 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import asyncio
16+
17+
import neo4j
18+
1519
from neo4j_graphrag.experimental.components.schema import (
1620
SchemaBuilder,
1721
)
18-
from neo4j_graphrag.experimental.components.types import (
19-
NodeType,
20-
PropertyType,
21-
RelationshipType,
22-
)
2322

2423

2524
async def main() -> None:
26-
schema_builder = SchemaBuilder()
25+
with neo4j.GraphDatabase.driver(
26+
"bolt://localhost:7687",
27+
auth=("neo4j", "password"),
28+
) as driver:
29+
schema_builder = SchemaBuilder(driver)
30+
31+
schema = await schema_builder.run(
32+
node_types=[
33+
{
34+
"label": "Person",
35+
"properties": [
36+
{"name": "name", "type": "STRING"},
37+
{"name": "place_of_birth", "type": "STRING"},
38+
{"name": "date_of_birth", "type": "DATE"},
39+
],
40+
},
41+
{
42+
"label": "Organization",
43+
"properties": [
44+
{"name": "name", "type": "STRING"},
45+
{"name": "country", "type": "STRING"},
46+
],
47+
},
48+
{
49+
"label": "Field",
50+
"properties": [
51+
{"name": "name", "type": "STRING"},
52+
],
53+
},
54+
],
55+
relationship_types=[
56+
"WORKED_ON",
57+
{
58+
"label": "WORKED_FOR",
59+
},
60+
],
61+
patterns=[
62+
("Person", "WORKED_ON", "Field"),
63+
("Person", "WORKED_FOR", "Organization"),
64+
],
65+
)
66+
print(schema)
67+
2768

28-
result = await schema_builder.run(
29-
node_types=[
30-
NodeType(
31-
label="Person",
32-
properties=[
33-
PropertyType(name="name", type="STRING"),
34-
PropertyType(name="place_of_birth", type="STRING"),
35-
PropertyType(name="date_of_birth", type="DATE"),
36-
],
37-
),
38-
NodeType(
39-
label="Organization",
40-
properties=[
41-
PropertyType(name="name", type="STRING"),
42-
PropertyType(name="country", type="STRING"),
43-
],
44-
),
45-
],
46-
relationship_types=[
47-
RelationshipType(
48-
label="WORKED_ON",
49-
),
50-
RelationshipType(
51-
label="WORKED_FOR",
52-
),
53-
],
54-
patterns=[
55-
("Person", "WORKED_ON", "Field"),
56-
("Person", "WORKED_FOR", "Organization"),
57-
],
58-
)
59-
print(result)
69+
if __name__ == "__main__":
70+
asyncio.run(main())

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525

2626
from neo4j_graphrag.exceptions import LLMGenerationError
2727
from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder
28-
from neo4j_graphrag.experimental.components.schema import GraphSchema
2928
from neo4j_graphrag.experimental.components.types import (
3029
DocumentInfo,
3130
LexicalGraphConfig,
31+
GraphSchema,
3232
Neo4jGraph,
3333
TextChunk,
3434
TextChunks,

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 123 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import copy
1718
import json
1819
import logging
19-
from typing import Any, Dict, List, Optional, Tuple, Sequence
20+
from typing import Any, Dict, List, Optional, Tuple, Sequence, Literal
2021

22+
import neo4j
2123
from pydantic import (
2224
validate_call,
2325
ValidationError,
@@ -29,13 +31,22 @@
2931
SchemaExtractionError,
3032
)
3133
from neo4j_graphrag.experimental.pipeline.component import Component
34+
from neo4j_graphrag.experimental.pipeline.types.schema import (
35+
EntityInputType,
36+
RelationInputType,
37+
)
3238
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
3339
from neo4j_graphrag.llm import LLMInterface
3440
from neo4j_graphrag.experimental.components.types import (
35-
NodeType,
3641
RelationshipType,
3742
GraphSchema,
43+
SchemaConstraint,
44+
ConstraintTypeEnum,
45+
Neo4jConstraintTypeEnum,
46+
GraphEntityType,
47+
Neo4jPropertyType,
3848
)
49+
from neo4j_graphrag.schema import get_constraints
3950

4051

4152
class SchemaBuilder(Component):
@@ -97,9 +108,83 @@ class SchemaBuilder(Component):
97108
pipe.run(pipe_inputs)
98109
"""
99110

111+
def __init__(
112+
self, driver: neo4j.Driver, neo4j_database: Optional[str] = None
113+
) -> None:
114+
self.driver = driver
115+
self.neo4j_database = neo4j_database
116+
117+
def _get_constraints_from_db(self) -> list[dict[str, Any]]:
118+
constraints = get_constraints(
119+
self.driver, database=self.neo4j_database, sanitize=False
120+
)
121+
return constraints
122+
123+
def _apply_all_constraints_from_db(
124+
self,
125+
node_or_relationship_type: Literal["NODE", "RELATIONSHIP"],
126+
constraints: list[dict[str, Any]],
127+
entities: tuple[GraphEntityType, ...],
128+
) -> list[GraphEntityType]:
129+
constrained_entity_types = []
130+
for entity_type in entities:
131+
new_entity_type = copy.deepcopy(entity_type)
132+
# find constraints related to this node type
133+
for constraint in constraints:
134+
if constraint["entityType"] != node_or_relationship_type:
135+
continue
136+
if constraint["labelsOrTypes"][0] != entity_type.label:
137+
continue
138+
# now we can add the constraint to this node type
139+
self._apply_constraint_from_db(new_entity_type, constraint)
140+
constrained_entity_types.append(new_entity_type)
141+
return constrained_entity_types
142+
100143
@staticmethod
101-
def create_schema_model(
102-
node_types: Sequence[NodeType],
144+
def _parse_property_type(property_type: str) -> Neo4jPropertyType | None:
145+
if not property_type:
146+
return None
147+
prop = None
148+
for prop_str in property_type.split("|"):
149+
p = prop_str.strip()
150+
try:
151+
prop = Neo4jPropertyType(p)
152+
except ValueError:
153+
pass
154+
return prop
155+
156+
def _apply_constraint_from_db(
157+
self, entity_type: GraphEntityType, constraint: dict[str, Any]
158+
) -> None:
159+
neo4j_constraint_type = Neo4jConstraintTypeEnum(constraint["type"])
160+
# TODO: detect potential conflict and raise ValueError if any
161+
# existing_schema_constraints_on_property = node_type.get_constraints_on_properties(constraint["properties"])
162+
constraint_properties = constraint["properties"]
163+
for p in constraint_properties:
164+
if entity_type.get_property_by_name(p) is None:
165+
raise ValueError(
166+
f"Can not add constraint {constraint} on non existing property"
167+
)
168+
constraint_type = neo4j_constraint_type.to_constraint_type()
169+
entity_type.constraints.append(
170+
SchemaConstraint(
171+
type=constraint_type,
172+
properties=constraint["properties"],
173+
property_type=self._parse_property_type(constraint["propertyType"]),
174+
name=constraint["name"],
175+
)
176+
)
177+
# if property required constraint, make sure the flag is set properly on
178+
# the PropertyType
179+
if constraint_type == ConstraintTypeEnum.PROPERTY_EXISTENCE:
180+
prop = entity_type.get_property_by_name(constraint["properties"][0])
181+
if prop:
182+
prop.required = True
183+
return None
184+
185+
def _create_schema_model(
186+
self,
187+
node_types: Sequence[EntityInputType],
103188
relationship_types: Optional[Sequence[RelationshipType]] = None,
104189
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
105190
**kwargs: Any,
@@ -118,7 +203,7 @@ def create_schema_model(
118203
GraphSchema: A configured schema object.
119204
"""
120205
try:
121-
return GraphSchema.model_validate(
206+
schema = GraphSchema.model_validate(
122207
dict(
123208
node_types=node_types,
124209
relationship_types=relationship_types or (),
@@ -129,11 +214,39 @@ def create_schema_model(
129214
except ValidationError as e:
130215
raise SchemaValidationError() from e
131216

217+
constraints = self._get_constraints_from_db()
218+
# apply constraints
219+
constrained_node_types = self._apply_all_constraints_from_db(
220+
"NODE",
221+
constraints,
222+
schema.node_types,
223+
)
224+
constrained_relationship_types = self._apply_all_constraints_from_db(
225+
"RELATIONSHIP",
226+
constraints,
227+
schema.relationship_types,
228+
)
229+
230+
try:
231+
constrained_schema = GraphSchema.model_validate(
232+
dict(
233+
node_types=constrained_node_types,
234+
relationship_types=constrained_relationship_types,
235+
patterns=patterns,
236+
**kwargs,
237+
)
238+
)
239+
except ValidationError as e:
240+
raise SchemaValidationError(
241+
"Error when applying constraints from database"
242+
) from e
243+
return constrained_schema
244+
132245
@validate_call
133246
async def run(
134247
self,
135-
node_types: Sequence[NodeType],
136-
relationship_types: Optional[Sequence[RelationshipType]] = None,
248+
node_types: Sequence[EntityInputType],
249+
relationship_types: Optional[Sequence[RelationInputType]] = None,
137250
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
138251
**kwargs: Any,
139252
) -> GraphSchema:
@@ -148,7 +261,7 @@ async def run(
148261
Returns:
149262
GraphSchema: A configured schema object, constructed asynchronously.
150263
"""
151-
return self.create_schema_model(
264+
return self._create_schema_model(
152265
node_types,
153266
relationship_types,
154267
patterns,
@@ -164,10 +277,12 @@ class SchemaFromTextExtractor(Component):
164277

165278
def __init__(
166279
self,
280+
driver: neo4j.Driver,
167281
llm: LLMInterface,
168282
prompt_template: Optional[PromptTemplate] = None,
169283
llm_params: Optional[Dict[str, Any]] = None,
170284
) -> None:
285+
self.driver = driver
171286
self._llm: LLMInterface = llm
172287
self._prompt_template: PromptTemplate = (
173288
prompt_template or SchemaExtractionTemplate()

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,6 @@ class PropertyType(BaseModel):
232232
description: str = ""
233233
required: bool = False
234234

235-
model_config = ConfigDict(
236-
frozen=True,
237-
)
238-
239235

240236
class ConstraintTypeEnum(str, enum.Enum):
241237
# see: https://neo4j.com/docs/cypher-manual/current/constraints/
@@ -245,6 +241,44 @@ class ConstraintTypeEnum(str, enum.Enum):
245241
PROPERTY_TYPE = "PROPERTY_TYPE"
246242

247243

244+
class Neo4jConstraintTypeEnum(str, enum.Enum):
245+
NODE_KEY = "NODE_KEY"
246+
UNIQUENESS = "UNIQUENESS"
247+
NODE_PROPERTY_EXISTENCE = "NODE_PROPERTY_EXISTENCE"
248+
NODE_PROPERTY_UNIQUENESS = "NODE_PROPERTY_UNIQUENESS"
249+
NODE_PROPERTY_TYPE = "NODE_PROPERTY_TYPE"
250+
RELATIONSHIP_KEY = "RELATIONSHIP_KEY"
251+
RELATIONSHIP_UNIQUENESS = "RELATIONSHIP_UNIQUENESS"
252+
RELATIONSHIP_PROPERTY_EXISTENCE = "RELATIONSHIP_PROPERTY_EXISTENCE"
253+
RELATIONSHIP_PROPERTY_UNIQUENESS = "RELATIONSHIP_PROPERTY_UNIQUENESS"
254+
RELATIONSHIP_PROPERTY_TYPE = "RELATIONSHIP_PROPERTY_TYPE"
255+
256+
def to_constraint_type(self) -> ConstraintTypeEnum:
257+
if self in (
258+
Neo4jConstraintTypeEnum.NODE_KEY,
259+
Neo4jConstraintTypeEnum.RELATIONSHIP_KEY,
260+
):
261+
return ConstraintTypeEnum.KEY
262+
if self in (
263+
Neo4jConstraintTypeEnum.UNIQUENESS,
264+
Neo4jConstraintTypeEnum.NODE_PROPERTY_UNIQUENESS,
265+
Neo4jConstraintTypeEnum.RELATIONSHIP_UNIQUENESS,
266+
Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_UNIQUENESS,
267+
):
268+
return ConstraintTypeEnum.PROPERTY_UNIQUENESS
269+
if self in (
270+
Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE,
271+
Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE,
272+
):
273+
return ConstraintTypeEnum.PROPERTY_EXISTENCE
274+
if self in (
275+
Neo4jConstraintTypeEnum.NODE_PROPERTY_TYPE,
276+
Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_TYPE,
277+
):
278+
return ConstraintTypeEnum.PROPERTY_TYPE
279+
raise ValueError(f"Can't convert {self} to ConstraintTypeEnum")
280+
281+
248282
class SchemaConstraint(BaseModel):
249283
"""Constraints that can be applied either on a node or relationship property."""
250284

@@ -290,6 +324,21 @@ def validate_constraint_on_properties(self) -> Self:
290324
)
291325
return self
292326

327+
def get_property_by_name(self, name: str) -> PropertyType | None:
328+
for prop in self.properties:
329+
if prop.name == name:
330+
return prop
331+
return None
332+
333+
def get_constraints_on_properties(
334+
self, prop_names: list[str]
335+
) -> list[SchemaConstraint]:
336+
constraints = []
337+
for constraint in self.constraints:
338+
if set(prop_names) == set(constraint.properties):
339+
constraints.append(constraint)
340+
return constraints
341+
293342
def get_unique_properties(self) -> list[str]:
294343
for c in self.constraints:
295344
if c.type in (

0 commit comments

Comments
 (0)