|
16 | 16 |
|
17 | 17 | import json
|
18 | 18 | import logging
|
19 |
| -import warnings |
20 |
| -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence |
21 |
| -from pathlib import Path |
| 19 | +from typing import Any, Dict, List, Optional, Tuple, Sequence |
22 | 20 |
|
23 | 21 | from pydantic import (
|
24 |
| - BaseModel, |
25 |
| - PrivateAttr, |
26 |
| - model_validator, |
27 | 22 | validate_call,
|
28 |
| - ConfigDict, |
29 | 23 | ValidationError,
|
30 | 24 | )
|
31 |
| -from typing_extensions import Self |
32 | 25 |
|
33 | 26 | from neo4j_graphrag.exceptions import (
|
34 | 27 | SchemaValidationError,
|
35 | 28 | LLMGenerationError,
|
36 | 29 | SchemaExtractionError,
|
37 | 30 | )
|
38 |
| -from neo4j_graphrag.experimental.pipeline.component import Component, DataModel |
39 |
| -from neo4j_graphrag.experimental.pipeline.types.schema import ( |
40 |
| - EntityInputType, |
41 |
| - RelationInputType, |
42 |
| -) |
| 31 | +from neo4j_graphrag.experimental.pipeline.component import Component |
43 | 32 | from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
|
44 | 33 | from neo4j_graphrag.llm import LLMInterface
|
45 |
| -from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat |
46 |
| - |
47 |
| - |
48 |
| -class PropertyType(BaseModel): |
49 |
| - """ |
50 |
| - Represents a property on a node or relationship in the graph. |
51 |
| - """ |
52 |
| - |
53 |
| - name: str |
54 |
| - # See https://neo4j.com/docs/cypher-manual/current/values-and-types/property-structural-constructed/#property-types |
55 |
| - type: Literal[ |
56 |
| - "BOOLEAN", |
57 |
| - "DATE", |
58 |
| - "DURATION", |
59 |
| - "FLOAT", |
60 |
| - "INTEGER", |
61 |
| - "LIST", |
62 |
| - "LOCAL_DATETIME", |
63 |
| - "LOCAL_TIME", |
64 |
| - "POINT", |
65 |
| - "STRING", |
66 |
| - "ZONED_DATETIME", |
67 |
| - "ZONED_TIME", |
68 |
| - ] |
69 |
| - description: str = "" |
70 |
| - required: bool = False |
71 |
| - |
72 |
| - model_config = ConfigDict( |
73 |
| - frozen=True, |
74 |
| - ) |
75 |
| - |
76 |
| - |
77 |
| -class NodeType(BaseModel): |
78 |
| - """ |
79 |
| - Represents a possible node in the graph. |
80 |
| - """ |
81 |
| - |
82 |
| - label: str |
83 |
| - description: str = "" |
84 |
| - properties: list[PropertyType] = [] |
85 |
| - additional_properties: bool = True |
86 |
| - |
87 |
| - @model_validator(mode="before") |
88 |
| - @classmethod |
89 |
| - def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType: |
90 |
| - if isinstance(data, str): |
91 |
| - return {"label": data} |
92 |
| - return data |
93 |
| - |
94 |
| - @model_validator(mode="after") |
95 |
| - def validate_additional_properties(self) -> Self: |
96 |
| - if len(self.properties) == 0 and not self.additional_properties: |
97 |
| - raise ValueError( |
98 |
| - "Using `additional_properties=False` with no defined " |
99 |
| - "properties will cause the model to be pruned during graph cleaning.", |
100 |
| - ) |
101 |
| - return self |
102 |
| - |
103 |
| - |
104 |
| -class RelationshipType(BaseModel): |
105 |
| - """ |
106 |
| - Represents a possible relationship between nodes in the graph. |
107 |
| - """ |
108 |
| - |
109 |
| - label: str |
110 |
| - description: str = "" |
111 |
| - properties: list[PropertyType] = [] |
112 |
| - additional_properties: bool = True |
113 |
| - |
114 |
| - @model_validator(mode="before") |
115 |
| - @classmethod |
116 |
| - def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType: |
117 |
| - if isinstance(data, str): |
118 |
| - return {"label": data} |
119 |
| - return data |
120 |
| - |
121 |
| - @model_validator(mode="after") |
122 |
| - def validate_additional_properties(self) -> Self: |
123 |
| - if len(self.properties) == 0 and not self.additional_properties: |
124 |
| - raise ValueError( |
125 |
| - "Using `additional_properties=False` with no defined " |
126 |
| - "properties will cause the model to be pruned during graph cleaning.", |
127 |
| - ) |
128 |
| - return self |
129 |
| - |
130 |
| - |
131 |
| -class GraphSchema(DataModel): |
132 |
| - """This model represents the expected |
133 |
| - node and relationship types in the graph. |
134 |
| -
|
135 |
| - It is used both for guiding the LLM in the entity and relation |
136 |
| - extraction component, and for cleaning the extracted graph in a |
137 |
| - post-processing step. |
138 |
| -
|
139 |
| - .. warning:: |
140 |
| -
|
141 |
| - This model is immutable. |
142 |
| - """ |
143 |
| - |
144 |
| - node_types: Tuple[NodeType, ...] |
145 |
| - relationship_types: Tuple[RelationshipType, ...] = tuple() |
146 |
| - patterns: Tuple[Tuple[str, str, str], ...] = tuple() |
147 |
| - |
148 |
| - additional_node_types: bool = True |
149 |
| - additional_relationship_types: bool = True |
150 |
| - additional_patterns: bool = True |
151 |
| - |
152 |
| - _node_type_index: dict[str, NodeType] = PrivateAttr() |
153 |
| - _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() |
154 |
| - |
155 |
| - model_config = ConfigDict( |
156 |
| - frozen=True, |
157 |
| - ) |
158 |
| - |
159 |
| - @model_validator(mode="after") |
160 |
| - def validate_patterns_against_node_and_rel_types(self) -> Self: |
161 |
| - self._node_type_index = {node.label: node for node in self.node_types} |
162 |
| - self._relationship_type_index = ( |
163 |
| - {r.label: r for r in self.relationship_types} |
164 |
| - if self.relationship_types |
165 |
| - else {} |
166 |
| - ) |
167 |
| - |
168 |
| - relationship_types = self.relationship_types |
169 |
| - patterns = self.patterns |
170 |
| - |
171 |
| - if patterns: |
172 |
| - if not relationship_types: |
173 |
| - raise SchemaValidationError( |
174 |
| - "Relationship types must also be provided when using patterns." |
175 |
| - ) |
176 |
| - for entity1, relation, entity2 in patterns: |
177 |
| - if entity1 not in self._node_type_index: |
178 |
| - raise SchemaValidationError( |
179 |
| - f"Node type '{entity1}' is not defined in the provided node_types." |
180 |
| - ) |
181 |
| - if relation not in self._relationship_type_index: |
182 |
| - raise SchemaValidationError( |
183 |
| - f"Relationship type '{relation}' is not defined in the provided relationship_types." |
184 |
| - ) |
185 |
| - if entity2 not in self._node_type_index: |
186 |
| - raise ValueError( |
187 |
| - f"Node type '{entity2}' is not defined in the provided node_types." |
188 |
| - ) |
189 |
| - |
190 |
| - return self |
191 |
| - |
192 |
| - @model_validator(mode="after") |
193 |
| - def validate_additional_parameters(self) -> Self: |
194 |
| - if ( |
195 |
| - self.additional_patterns is False |
196 |
| - and self.additional_relationship_types is True |
197 |
| - ): |
198 |
| - raise ValueError( |
199 |
| - "`additional_relationship_types` must be set to False when using `additional_patterns=False`" |
200 |
| - ) |
201 |
| - return self |
202 |
| - |
203 |
| - def node_type_from_label(self, label: str) -> Optional[NodeType]: |
204 |
| - return self._node_type_index.get(label) |
205 |
| - |
206 |
| - def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]: |
207 |
| - return self._relationship_type_index.get(label) |
208 |
| - |
209 |
| - def save( |
210 |
| - self, |
211 |
| - file_path: Union[str, Path], |
212 |
| - overwrite: bool = False, |
213 |
| - format: Optional[FileFormat] = None, |
214 |
| - ) -> None: |
215 |
| - """ |
216 |
| - Save the schema configuration to file. |
217 |
| -
|
218 |
| - Args: |
219 |
| - file_path (str): The path where the schema configuration will be saved. |
220 |
| - overwrite (bool): If set to True, existing file will be overwritten. Default to False. |
221 |
| - format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from file_path extension. |
222 |
| - """ |
223 |
| - data = self.model_dump(mode="json") |
224 |
| - file_handler = FileHandler() |
225 |
| - file_handler.write(data, file_path, overwrite=overwrite, format=format) |
226 |
| - |
227 |
| - def store_as_json( |
228 |
| - self, file_path: Union[str, Path], overwrite: bool = False |
229 |
| - ) -> None: |
230 |
| - warnings.warn( |
231 |
| - "Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning |
232 |
| - ) |
233 |
| - return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON) |
234 |
| - |
235 |
| - def store_as_yaml( |
236 |
| - self, file_path: Union[str, Path], overwrite: bool = False |
237 |
| - ) -> None: |
238 |
| - warnings.warn( |
239 |
| - "Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning |
240 |
| - ) |
241 |
| - return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML) |
242 |
| - |
243 |
| - @classmethod |
244 |
| - def from_file( |
245 |
| - cls, file_path: Union[str, Path], format: Optional[FileFormat] = None |
246 |
| - ) -> Self: |
247 |
| - """ |
248 |
| - Load a schema configuration from a file (either JSON or YAML). |
249 |
| -
|
250 |
| - The file format is automatically detected based on the file extension, |
251 |
| - unless the format parameter is set. |
252 |
| -
|
253 |
| - Args: |
254 |
| - file_path (Union[str, Path]): The path to the schema configuration file. |
255 |
| - format (Optional[FileFormat]): The format of the schema configuration file (json or yaml). |
256 |
| -
|
257 |
| - Returns: |
258 |
| - GraphSchema: The loaded schema configuration. |
259 |
| - """ |
260 |
| - file_path = Path(file_path) |
261 |
| - file_handler = FileHandler() |
262 |
| - try: |
263 |
| - data = file_handler.read(file_path, format=format) |
264 |
| - except ValueError: |
265 |
| - raise |
266 |
| - |
267 |
| - try: |
268 |
| - return cls.model_validate(data) |
269 |
| - except ValidationError as e: |
270 |
| - raise SchemaValidationError(str(e)) from e |
| 34 | +from neo4j_graphrag.experimental.components.types import ( |
| 35 | + NodeType, |
| 36 | + RelationshipType, |
| 37 | + GraphSchema, |
| 38 | +) |
271 | 39 |
|
272 | 40 |
|
273 | 41 | class SchemaBuilder(Component):
|
|
0 commit comments