|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 | 17 | import json
|
18 |
| -import yaml |
19 | 18 | import logging
|
| 19 | +import warnings |
20 | 20 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
|
21 | 21 | from pathlib import Path
|
22 | 22 |
|
|
42 | 42 | )
|
43 | 43 | from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
|
44 | 44 | from neo4j_graphrag.llm import LLMInterface
|
| 45 | +from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat |
45 | 46 |
|
46 | 47 |
|
47 | 48 | class PropertyType(BaseModel):
|
@@ -157,101 +158,68 @@ def node_type_from_label(self, label: str) -> Optional[NodeType]:
|
157 | 158 | def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]:
|
158 | 159 | return self._relationship_type_index.get(label)
|
159 | 160 |
|
160 |
| - def store_as_json(self, file_path: str) -> None: |
| 161 | + def save( |
| 162 | + self, |
| 163 | + file_path: Union[str, Path], |
| 164 | + overwrite: bool = False, |
| 165 | + format: Optional[FileFormat] = None, |
| 166 | + ) -> None: |
161 | 167 | """
|
162 |
| - Save the schema configuration to a JSON file. |
| 168 | + Save the schema configuration to file. |
163 | 169 |
|
164 | 170 | Args:
|
165 | 171 | file_path (str): The path where the schema configuration will be saved.
|
| 172 | + overwrite (bool): If set to True, existing file will be overwritten. Default to False. |
| 173 | + format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from file_path extension. |
166 | 174 | """
|
167 |
| - with open(file_path, "w") as f: |
168 |
| - json.dump(self.model_dump(), f, indent=2) |
| 175 | + data = self.model_dump(mode="json") |
| 176 | + file_handler = FileHandler() |
| 177 | + file_handler.write(data, file_path, overwrite=overwrite, format=format) |
169 | 178 |
|
170 |
| - def store_as_yaml(self, file_path: str) -> None: |
171 |
| - """ |
172 |
| - Save the schema configuration to a YAML file. |
| 179 | + def store_as_json( |
| 180 | + self, file_path: Union[str, Path], overwrite: bool = False |
| 181 | + ) -> None: |
| 182 | + warnings.warn( |
| 183 | + "Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning |
| 184 | + ) |
| 185 | + return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON) |
173 | 186 |
|
174 |
| - Args: |
175 |
| - file_path (str): The path where the schema configuration will be saved. |
176 |
| - """ |
177 |
| - # create a copy of the data and convert tuples to lists for YAML compatibility |
178 |
| - data = self.model_dump() |
179 |
| - if data.get("node_types"): |
180 |
| - data["node_types"] = list(data["node_types"]) |
181 |
| - if data.get("relationship_types"): |
182 |
| - data["relationship_types"] = list(data["relationship_types"]) |
183 |
| - if data.get("patterns"): |
184 |
| - data["patterns"] = [list(item) for item in data["patterns"]] |
185 |
| - |
186 |
| - with open(file_path, "w") as f: |
187 |
| - yaml.dump(data, f, default_flow_style=False, sort_keys=False) |
| 187 | + def store_as_yaml( |
| 188 | + self, file_path: Union[str, Path], overwrite: bool = False |
| 189 | + ) -> None: |
| 190 | + warnings.warn( |
| 191 | + "Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning |
| 192 | + ) |
| 193 | + return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML) |
188 | 194 |
|
189 | 195 | @classmethod
|
190 |
| - def from_file(cls, file_path: Union[str, Path]) -> Self: |
| 196 | + def from_file( |
| 197 | + cls, file_path: Union[str, Path], format: Optional[FileFormat] = None |
| 198 | + ) -> Self: |
191 | 199 | """
|
192 | 200 | Load a schema configuration from a file (either JSON or YAML).
|
193 | 201 |
|
194 |
| - The file format is automatically detected based on the file extension. |
| 202 | + The file format is automatically detected based on the file extension, |
| 203 | + unless the format parameter is set. |
195 | 204 |
|
196 | 205 | Args:
|
197 | 206 | file_path (Union[str, Path]): The path to the schema configuration file.
|
| 207 | + format (Optional[FileFormat]): The format of the schema configuration file (json or yaml). |
198 | 208 |
|
199 | 209 | Returns:
|
200 | 210 | GraphSchema: The loaded schema configuration.
|
201 | 211 | """
|
202 | 212 | file_path = Path(file_path)
|
| 213 | + file_handler = FileHandler() |
| 214 | + try: |
| 215 | + data = file_handler.read(file_path, format=format) |
| 216 | + except ValueError: |
| 217 | + raise |
203 | 218 |
|
204 |
| - if not file_path.exists(): |
205 |
| - raise FileNotFoundError(f"Schema file not found: {file_path}") |
206 |
| - |
207 |
| - if file_path.suffix.lower() in [".json"]: |
208 |
| - return cls.from_json(file_path) |
209 |
| - elif file_path.suffix.lower() in [".yaml", ".yml"]: |
210 |
| - return cls.from_yaml(file_path) |
211 |
| - else: |
212 |
| - raise ValueError( |
213 |
| - f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml" |
214 |
| - ) |
215 |
| - |
216 |
| - @classmethod |
217 |
| - def from_json(cls, file_path: Union[str, Path]) -> Self: |
218 |
| - """ |
219 |
| - Load a schema configuration from a JSON file. |
220 |
| -
|
221 |
| - Args: |
222 |
| - file_path (Union[str, Path]): The path to the JSON schema configuration file. |
223 |
| -
|
224 |
| - Returns: |
225 |
| - GraphSchema: The loaded schema configuration. |
226 |
| - """ |
227 |
| - with open(file_path, "r") as f: |
228 |
| - try: |
229 |
| - data = json.load(f) |
230 |
| - return cls.model_validate(data) |
231 |
| - except json.JSONDecodeError as e: |
232 |
| - raise ValueError(f"Invalid JSON file: {e}") |
233 |
| - except ValidationError as e: |
234 |
| - raise SchemaValidationError(f"Schema validation failed: {e}") |
235 |
| - |
236 |
| - @classmethod |
237 |
| - def from_yaml(cls, file_path: Union[str, Path]) -> Self: |
238 |
| - """ |
239 |
| - Load a schema configuration from a YAML file. |
240 |
| -
|
241 |
| - Args: |
242 |
| - file_path (Union[str, Path]): The path to the YAML schema configuration file. |
243 |
| -
|
244 |
| - Returns: |
245 |
| - GraphSchema: The loaded schema configuration. |
246 |
| - """ |
247 |
| - with open(file_path, "r") as f: |
248 |
| - try: |
249 |
| - data = yaml.safe_load(f) |
250 |
| - return cls.model_validate(data) |
251 |
| - except yaml.YAMLError as e: |
252 |
| - raise ValueError(f"Invalid YAML file: {e}") |
253 |
| - except ValidationError as e: |
254 |
| - raise SchemaValidationError(f"Schema validation failed: {e}") |
| 219 | + try: |
| 220 | + return cls.model_validate(data) |
| 221 | + except ValidationError as e: |
| 222 | + raise SchemaValidationError(str(e)) from e |
255 | 223 |
|
256 | 224 |
|
257 | 225 | class SchemaBuilder(Component):
|
|
0 commit comments