Skip to content

Commit 61c4f1a

Browse files
committed
WIP: refactoring and fix mypy
1 parent f2a942c commit 61c4f1a

File tree

16 files changed

+181
-151
lines changed

16 files changed

+181
-151
lines changed

examples/customize/build_graph/components/pruners/graph_pruner.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Neo4jGraph,
1212
Neo4jNode,
1313
Neo4jRelationship,
14+
Neo4jPropertyType,
1415
)
1516

1617
graph = Neo4jGraph(
@@ -66,24 +67,30 @@
6667
NodeType(
6768
label="Person",
6869
properties=[
69-
PropertyType(name="firstName", type="STRING", required=True),
70-
PropertyType(name="lastName", type="STRING", required=True),
71-
PropertyType(name="age", type="INTEGER"),
70+
PropertyType(
71+
name="firstName", type=Neo4jPropertyType.STRING, required=True
72+
),
73+
PropertyType(
74+
name="lastName", type=Neo4jPropertyType.STRING, required=True
75+
),
76+
PropertyType(name="age", type=Neo4jPropertyType.INTEGER),
7277
],
7378
additional_properties=False,
7479
),
7580
NodeType(
7681
label="Organization",
7782
properties=[
78-
PropertyType(name="name", type="STRING", required=True),
79-
PropertyType(name="address", type="STRING"),
83+
PropertyType(name="name", type=Neo4jPropertyType.STRING, required=True),
84+
PropertyType(name="address", type=Neo4jPropertyType.STRING),
8085
],
8186
),
8287
),
8388
relationship_types=(
8489
RelationshipType(
8590
label="WORKS_FOR",
86-
properties=[PropertyType(name="since", type="LOCAL_DATETIME")],
91+
properties=[
92+
PropertyType(name="since", type=Neo4jPropertyType.LOCAL_DATETIME)
93+
],
8794
),
8895
RelationshipType(
8996
label="KNOWS",

examples/customize/build_graph/components/schema_builders/schema_from_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def extract_and_save_schema() -> None:
7676

7777
try:
7878
# Create a SchemaFromTextExtractor component with the default template
79-
schema_extractor = SchemaFromTextExtractor(llm=llm)
79+
schema_extractor = SchemaFromTextExtractor(driver=None, llm=llm)
8080

8181
print("Extracting schema from text...")
8282
# Extract schema from text

examples/customize/build_graph/pipeline/kg_builder_from_pdf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
2626
from neo4j_graphrag.experimental.components.schema import (
2727
SchemaBuilder,
28+
)
29+
from neo4j_graphrag.experimental.components.types import (
2830
NodeType,
2931
RelationshipType,
3032
)
@@ -87,7 +89,7 @@ async def define_and_run_pipeline(
8789
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200, approximate=False),
8890
"splitter",
8991
)
90-
pipe.add_component(SchemaBuilder(), "schema")
92+
pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema")
9193
pipe.add_component(
9294
LLMEntityRelationExtractor(
9395
llm=llm,

examples/customize/build_graph/pipeline/kg_builder_from_text.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
2626
from neo4j_graphrag.experimental.components.schema import (
2727
SchemaBuilder,
28+
)
29+
from neo4j_graphrag.experimental.components.types import (
2830
NodeType,
2931
PropertyType,
3032
RelationshipType,
33+
Neo4jPropertyType,
3134
)
3235
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
3336
FixedSizeSplitter,
@@ -63,7 +66,7 @@ async def define_and_run_pipeline(
6366
"splitter",
6467
)
6568
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
66-
pipe.add_component(SchemaBuilder(), "schema")
69+
pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema")
6770
pipe.add_component(
6871
LLMEntityRelationExtractor(
6972
llm=llm,
@@ -99,22 +102,24 @@ async def define_and_run_pipeline(
99102
NodeType(
100103
label="Person",
101104
properties=[
102-
PropertyType(name="name", type="STRING"),
103-
PropertyType(name="place_of_birth", type="STRING"),
104-
PropertyType(name="date_of_birth", type="DATE"),
105+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
106+
PropertyType(
107+
name="place_of_birth", type=Neo4jPropertyType.STRING
108+
),
109+
PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE),
105110
],
106111
),
107112
NodeType(
108113
label="Organization",
109114
properties=[
110-
PropertyType(name="name", type="STRING"),
111-
PropertyType(name="country", type="STRING"),
115+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
116+
PropertyType(name="country", type=Neo4jPropertyType.STRING),
112117
],
113118
),
114119
NodeType(
115120
label="Field",
116121
properties=[
117-
PropertyType(name="name", type="STRING"),
122+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
118123
],
119124
),
120125
],

examples/customize/build_graph/pipeline/kg_builder_two_documents_entity_resolution.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@
2828
)
2929
from neo4j_graphrag.experimental.components.schema import (
3030
SchemaBuilder,
31+
)
32+
from neo4j_graphrag.experimental.components.types import (
3133
NodeType,
3234
PropertyType,
3335
RelationshipType,
36+
Neo4jPropertyType,
3437
)
3538
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
3639
FixedSizeSplitter,
@@ -61,7 +64,7 @@ async def define_and_run_pipeline(
6164
FixedSizeSplitter(),
6265
"splitter",
6366
)
64-
pipe.add_component(SchemaBuilder(), "schema")
67+
pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema")
6568
pipe.add_component(
6669
LLMEntityRelationExtractor(
6770
llm=llm,
@@ -96,16 +99,18 @@ async def define_and_run_pipeline(
9699
NodeType(
97100
label="Person",
98101
properties=[
99-
PropertyType(name="name", type="STRING"),
100-
PropertyType(name="place_of_birth", type="STRING"),
101-
PropertyType(name="date_of_birth", type="DATE"),
102+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
103+
PropertyType(
104+
name="place_of_birth", type=Neo4jPropertyType.STRING
105+
),
106+
PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE),
102107
],
103108
),
104109
NodeType(
105110
label="Organization",
106111
properties=[
107-
PropertyType(name="name", type="STRING"),
108-
PropertyType(name="country", type="STRING"),
112+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
113+
PropertyType(name="country", type=Neo4jPropertyType.STRING),
109114
],
110115
),
111116
],

examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_single_pipeline.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder
1717
from neo4j_graphrag.experimental.components.schema import (
1818
SchemaBuilder,
19+
)
20+
from neo4j_graphrag.experimental.components.types import (
1921
NodeType,
2022
PropertyType,
21-
RelationshipType,
23+
RelationshipType, Neo4jPropertyType,
2224
)
2325
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
2426
FixedSizeSplitter,
@@ -66,7 +68,7 @@ async def define_and_run_pipeline(
6668
"lexical_graph_builder",
6769
)
6870
pipe.add_component(Neo4jWriter(neo4j_driver), "lg_writer")
69-
pipe.add_component(SchemaBuilder(), "schema")
71+
pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema")
7072
pipe.add_component(
7173
LLMEntityRelationExtractor(
7274
llm=llm,
@@ -122,22 +124,22 @@ async def define_and_run_pipeline(
122124
NodeType(
123125
label="Person",
124126
properties=[
125-
PropertyType(name="name", type="STRING"),
126-
PropertyType(name="place_of_birth", type="STRING"),
127-
PropertyType(name="date_of_birth", type="DATE"),
127+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
128+
PropertyType(name="place_of_birth", type=Neo4jPropertyType.STRING),
129+
PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE),
128130
],
129131
),
130132
NodeType(
131133
label="Organization",
132134
properties=[
133-
PropertyType(name="name", type="STRING"),
134-
PropertyType(name="country", type="STRING"),
135+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
136+
PropertyType(name="country", type=Neo4jPropertyType.STRING),
135137
],
136138
),
137139
NodeType(
138140
label="Field",
139141
properties=[
140-
PropertyType(name="name", type="STRING"),
142+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
141143
],
142144
),
143145
],

examples/customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader
1919
from neo4j_graphrag.experimental.components.schema import (
2020
SchemaBuilder,
21+
)
22+
from neo4j_graphrag.experimental.components.types import (
2123
NodeType,
2224
PropertyType,
23-
RelationshipType,
25+
RelationshipType, Neo4jPropertyType,
2426
)
2527
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
2628
FixedSizeSplitter,
@@ -112,7 +114,7 @@ async def read_chunk_and_perform_entity_extraction(
112114
pipe = Pipeline()
113115
# define the components
114116
pipe.add_component(Neo4jChunkReader(neo4j_driver), "reader")
115-
pipe.add_component(SchemaBuilder(), "schema")
117+
pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema")
116118
pipe.add_component(
117119
LLMEntityRelationExtractor(
118120
llm=llm,
@@ -142,22 +144,22 @@ async def read_chunk_and_perform_entity_extraction(
142144
NodeType(
143145
label="Person",
144146
properties=[
145-
PropertyType(name="name", type="STRING"),
146-
PropertyType(name="place_of_birth", type="STRING"),
147-
PropertyType(name="date_of_birth", type="DATE"),
147+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
148+
PropertyType(name="place_of_birth", type=Neo4jPropertyType.STRING),
149+
PropertyType(name="date_of_birth", type=Neo4jPropertyType.DATE),
148150
],
149151
),
150152
NodeType(
151153
label="Organization",
152154
properties=[
153-
PropertyType(name="name", type="STRING"),
154-
PropertyType(name="country", type="STRING"),
155+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
156+
PropertyType(name="country", type=Neo4jPropertyType.STRING),
155157
],
156158
),
157159
NodeType(
158160
label="Field",
159161
properties=[
160-
PropertyType(name="name", type="STRING"),
162+
PropertyType(name="name", type=Neo4jPropertyType.STRING),
161163
],
162164
),
163165
],

examples/kg_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
3333
from neo4j_graphrag.experimental.components.schema import (
3434
SchemaBuilder,
35+
)
36+
from neo4j_graphrag.experimental.components.types import (
3537
NodeType,
3638
RelationshipType,
3739
)
@@ -91,7 +93,7 @@ async def define_and_run_pipeline(
9193
pipe.add_component(
9294
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter"
9395
)
94-
pipe.add_component(SchemaBuilder(), "schema")
96+
pipe.add_component(SchemaBuilder(driver=neo4j_driver), "schema")
9597
pipe.add_component(
9698
LLMEntityRelationExtractor(
9799
llm=llm,

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import copy
1818
import json
1919
import logging
20-
from typing import Any, Dict, List, Optional, Tuple, Sequence, Literal
20+
from typing import Any, Dict, List, Optional, Tuple, Sequence
2121

2222
import neo4j
2323
from pydantic import (
@@ -38,10 +38,8 @@
3838
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
3939
from neo4j_graphrag.llm import LLMInterface
4040
from neo4j_graphrag.experimental.components.types import (
41-
RelationshipType,
4241
GraphSchema,
4342
SchemaConstraint,
44-
ConstraintTypeEnum,
4543
Neo4jConstraintTypeEnum,
4644
GraphEntityType,
4745
Neo4jPropertyType,
@@ -122,7 +120,6 @@ def _get_constraints_from_db(self) -> list[dict[str, Any]]:
122120

123121
def _apply_all_constraints_from_db(
124122
self,
125-
node_or_relationship_type: Literal["NODE", "RELATIONSHIP"],
126123
constraints: list[dict[str, Any]],
127124
entities: tuple[GraphEntityType, ...],
128125
) -> list[GraphEntityType]:
@@ -131,7 +128,7 @@ def _apply_all_constraints_from_db(
131128
new_entity_type = copy.deepcopy(entity_type)
132129
# find constraints related to this node type
133130
for constraint in constraints:
134-
if constraint["entityType"] != node_or_relationship_type:
131+
if constraint["entityType"] != entity_type._name:
135132
continue
136133
if constraint["labelsOrTypes"][0] != entity_type.label:
137134
continue
@@ -157,26 +154,27 @@ def _apply_constraint_from_db(
157154
self, entity_type: GraphEntityType, constraint: dict[str, Any]
158155
) -> None:
159156
neo4j_constraint_type = Neo4jConstraintTypeEnum(constraint["type"])
160-
# TODO: detect potential conflict and raise ValueError if any
161-
# existing_schema_constraints_on_property = node_type.get_constraints_on_properties(constraint["properties"])
162157
constraint_properties = constraint["properties"]
163158
for p in constraint_properties:
164159
if entity_type.get_property_by_name(p) is None:
165160
raise ValueError(
166161
f"Can not add constraint {constraint} on non existing property"
167162
)
168-
constraint_type = neo4j_constraint_type.to_constraint_type()
169-
entity_type.constraints.append(
170-
SchemaConstraint(
171-
type=constraint_type,
172-
properties=constraint["properties"],
173-
property_type=self._parse_property_type(constraint["propertyType"]),
174-
name=constraint["name"],
175-
)
163+
schema_constraint = SchemaConstraint(
164+
entity_type=entity_type._name,
165+
label_or_type=entity_type.label,
166+
type=neo4j_constraint_type,
167+
properties=constraint["properties"],
168+
property_type=self._parse_property_type(constraint["propertyType"]),
169+
name=constraint["name"],
176170
)
171+
177172
# if property required constraint, make sure the flag is set properly on
178173
# the PropertyType
179-
if constraint_type == ConstraintTypeEnum.PROPERTY_EXISTENCE:
174+
if schema_constraint.type in (
175+
Neo4jConstraintTypeEnum.NODE_PROPERTY_EXISTENCE,
176+
Neo4jConstraintTypeEnum.RELATIONSHIP_PROPERTY_EXISTENCE,
177+
):
180178
prop = entity_type.get_property_by_name(constraint["properties"][0])
181179
if prop:
182180
prop.required = True
@@ -185,7 +183,7 @@ def _apply_constraint_from_db(
185183
def _create_schema_model(
186184
self,
187185
node_types: Sequence[EntityInputType],
188-
relationship_types: Optional[Sequence[RelationshipType]] = None,
186+
relationship_types: Optional[Sequence[RelationInputType]] = None,
189187
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
190188
**kwargs: Any,
191189
) -> GraphSchema:
@@ -217,12 +215,10 @@ def _create_schema_model(
217215
constraints = self._get_constraints_from_db()
218216
# apply constraints
219217
constrained_node_types = self._apply_all_constraints_from_db(
220-
"NODE",
221218
constraints,
222219
schema.node_types,
223220
)
224221
constrained_relationship_types = self._apply_all_constraints_from_db(
225-
"RELATIONSHIP",
226222
constraints,
227223
schema.relationship_types,
228224
)

0 commit comments

Comments
 (0)