Skip to content

Commit 2fc620f

Browse files
committed
Add Constraint to GraphSchema - move types into the types.py file
1 parent 9dc9719 commit 2fc620f

File tree

8 files changed

+297
-251
lines changed

8 files changed

+297
-251
lines changed

examples/customize/build_graph/components/pruners/graph_pruner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
import asyncio
44

55
from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning
6-
from neo4j_graphrag.experimental.components.schema import (
6+
from neo4j_graphrag.experimental.components.types import (
77
GraphSchema,
88
NodeType,
99
PropertyType,
1010
RelationshipType,
11-
)
12-
from neo4j_graphrag.experimental.components.types import (
1311
Neo4jGraph,
1412
Neo4jNode,
1513
Neo4jRelationship,

examples/customize/build_graph/components/schema_builders/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
from neo4j_graphrag.experimental.components.schema import (
1616
SchemaBuilder,
17+
)
18+
from neo4j_graphrag.experimental.components.types import (
1719
NodeType,
1820
PropertyType,
1921
RelationshipType,

examples/customize/build_graph/components/schema_builders/schema_from_text.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from neo4j_graphrag.experimental.components.schema import (
1616
SchemaFromTextExtractor,
17+
)
18+
from neo4j_graphrag.experimental.components.types import (
1719
GraphSchema,
1820
)
1921
from neo4j_graphrag.llm import OpenAILLM

src/neo4j_graphrag/experimental/components/graph_pruning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pydantic import validate_call, BaseModel
2020

21-
from neo4j_graphrag.experimental.components.schema import (
21+
from neo4j_graphrag.experimental.components.types import (
2222
GraphSchema,
2323
PropertyType,
2424
NodeType,

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 7 additions & 239 deletions
Original file line numberDiff line numberDiff line change
@@ -16,258 +16,26 @@
1616

1717
import json
1818
import logging
19-
import warnings
20-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
21-
from pathlib import Path
19+
from typing import Any, Dict, List, Optional, Tuple, Sequence
2220

2321
from pydantic import (
24-
BaseModel,
25-
PrivateAttr,
26-
model_validator,
2722
validate_call,
28-
ConfigDict,
2923
ValidationError,
3024
)
31-
from typing_extensions import Self
3225

3326
from neo4j_graphrag.exceptions import (
3427
SchemaValidationError,
3528
LLMGenerationError,
3629
SchemaExtractionError,
3730
)
38-
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
39-
from neo4j_graphrag.experimental.pipeline.types.schema import (
40-
EntityInputType,
41-
RelationInputType,
42-
)
31+
from neo4j_graphrag.experimental.pipeline.component import Component
4332
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
4433
from neo4j_graphrag.llm import LLMInterface
45-
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
46-
47-
48-
class PropertyType(BaseModel):
49-
"""
50-
Represents a property on a node or relationship in the graph.
51-
"""
52-
53-
name: str
54-
# See https://neo4j.com/docs/cypher-manual/current/values-and-types/property-structural-constructed/#property-types
55-
type: Literal[
56-
"BOOLEAN",
57-
"DATE",
58-
"DURATION",
59-
"FLOAT",
60-
"INTEGER",
61-
"LIST",
62-
"LOCAL_DATETIME",
63-
"LOCAL_TIME",
64-
"POINT",
65-
"STRING",
66-
"ZONED_DATETIME",
67-
"ZONED_TIME",
68-
]
69-
description: str = ""
70-
required: bool = False
71-
72-
model_config = ConfigDict(
73-
frozen=True,
74-
)
75-
76-
77-
class NodeType(BaseModel):
78-
"""
79-
Represents a possible node in the graph.
80-
"""
81-
82-
label: str
83-
description: str = ""
84-
properties: list[PropertyType] = []
85-
additional_properties: bool = True
86-
87-
@model_validator(mode="before")
88-
@classmethod
89-
def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType:
90-
if isinstance(data, str):
91-
return {"label": data}
92-
return data
93-
94-
@model_validator(mode="after")
95-
def validate_additional_properties(self) -> Self:
96-
if len(self.properties) == 0 and not self.additional_properties:
97-
raise ValueError(
98-
"Using `additional_properties=False` with no defined "
99-
"properties will cause the model to be pruned during graph cleaning.",
100-
)
101-
return self
102-
103-
104-
class RelationshipType(BaseModel):
105-
"""
106-
Represents a possible relationship between nodes in the graph.
107-
"""
108-
109-
label: str
110-
description: str = ""
111-
properties: list[PropertyType] = []
112-
additional_properties: bool = True
113-
114-
@model_validator(mode="before")
115-
@classmethod
116-
def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType:
117-
if isinstance(data, str):
118-
return {"label": data}
119-
return data
120-
121-
@model_validator(mode="after")
122-
def validate_additional_properties(self) -> Self:
123-
if len(self.properties) == 0 and not self.additional_properties:
124-
raise ValueError(
125-
"Using `additional_properties=False` with no defined "
126-
"properties will cause the model to be pruned during graph cleaning.",
127-
)
128-
return self
129-
130-
131-
class GraphSchema(DataModel):
132-
"""This model represents the expected
133-
node and relationship types in the graph.
134-
135-
It is used both for guiding the LLM in the entity and relation
136-
extraction component, and for cleaning the extracted graph in a
137-
post-processing step.
138-
139-
.. warning::
140-
141-
This model is immutable.
142-
"""
143-
144-
node_types: Tuple[NodeType, ...]
145-
relationship_types: Tuple[RelationshipType, ...] = tuple()
146-
patterns: Tuple[Tuple[str, str, str], ...] = tuple()
147-
148-
additional_node_types: bool = True
149-
additional_relationship_types: bool = True
150-
additional_patterns: bool = True
151-
152-
_node_type_index: dict[str, NodeType] = PrivateAttr()
153-
_relationship_type_index: dict[str, RelationshipType] = PrivateAttr()
154-
155-
model_config = ConfigDict(
156-
frozen=True,
157-
)
158-
159-
@model_validator(mode="after")
160-
def validate_patterns_against_node_and_rel_types(self) -> Self:
161-
self._node_type_index = {node.label: node for node in self.node_types}
162-
self._relationship_type_index = (
163-
{r.label: r for r in self.relationship_types}
164-
if self.relationship_types
165-
else {}
166-
)
167-
168-
relationship_types = self.relationship_types
169-
patterns = self.patterns
170-
171-
if patterns:
172-
if not relationship_types:
173-
raise SchemaValidationError(
174-
"Relationship types must also be provided when using patterns."
175-
)
176-
for entity1, relation, entity2 in patterns:
177-
if entity1 not in self._node_type_index:
178-
raise SchemaValidationError(
179-
f"Node type '{entity1}' is not defined in the provided node_types."
180-
)
181-
if relation not in self._relationship_type_index:
182-
raise SchemaValidationError(
183-
f"Relationship type '{relation}' is not defined in the provided relationship_types."
184-
)
185-
if entity2 not in self._node_type_index:
186-
raise ValueError(
187-
f"Node type '{entity2}' is not defined in the provided node_types."
188-
)
189-
190-
return self
191-
192-
@model_validator(mode="after")
193-
def validate_additional_parameters(self) -> Self:
194-
if (
195-
self.additional_patterns is False
196-
and self.additional_relationship_types is True
197-
):
198-
raise ValueError(
199-
"`additional_relationship_types` must be set to False when using `additional_patterns=False`"
200-
)
201-
return self
202-
203-
def node_type_from_label(self, label: str) -> Optional[NodeType]:
204-
return self._node_type_index.get(label)
205-
206-
def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]:
207-
return self._relationship_type_index.get(label)
208-
209-
def save(
210-
self,
211-
file_path: Union[str, Path],
212-
overwrite: bool = False,
213-
format: Optional[FileFormat] = None,
214-
) -> None:
215-
"""
216-
Save the schema configuration to file.
217-
218-
Args:
219-
file_path (Union[str, Path]): The path where the schema configuration will be saved.
220-
overwrite (bool): If set to True, an existing file will be overwritten. Defaults to False.
221-
format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from the file_path extension.
222-
"""
223-
data = self.model_dump(mode="json")
224-
file_handler = FileHandler()
225-
file_handler.write(data, file_path, overwrite=overwrite, format=format)
226-
227-
def store_as_json(
228-
self, file_path: Union[str, Path], overwrite: bool = False
229-
) -> None:
230-
warnings.warn(
231-
"Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning
232-
)
233-
return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON)
234-
235-
def store_as_yaml(
236-
self, file_path: Union[str, Path], overwrite: bool = False
237-
) -> None:
238-
warnings.warn(
239-
"Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning
240-
)
241-
return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML)
242-
243-
@classmethod
244-
def from_file(
245-
cls, file_path: Union[str, Path], format: Optional[FileFormat] = None
246-
) -> Self:
247-
"""
248-
Load a schema configuration from a file (either JSON or YAML).
249-
250-
The file format is automatically detected based on the file extension,
251-
unless the format parameter is set.
252-
253-
Args:
254-
file_path (Union[str, Path]): The path to the schema configuration file.
255-
format (Optional[FileFormat]): The format of the schema configuration file (json or yaml).
256-
257-
Returns:
258-
GraphSchema: The loaded schema configuration.
259-
"""
260-
file_path = Path(file_path)
261-
file_handler = FileHandler()
262-
try:
263-
data = file_handler.read(file_path, format=format)
264-
except ValueError:
265-
raise
266-
267-
try:
268-
return cls.model_validate(data)
269-
except ValidationError as e:
270-
raise SchemaValidationError(str(e)) from e
34+
from neo4j_graphrag.experimental.components.types import (
35+
NodeType,
36+
RelationshipType,
37+
GraphSchema,
38+
)
27139

27240

27341
class SchemaBuilder(Component):

0 commit comments

Comments
 (0)