Skip to content

Commit 987abf6

Browse files
authored
Use ConfigReader in SchemaConfig (neo4j#339)
* Use ConfigReader in SchemaConfig
* Rename to FileHandler
* Tests
* Ruff
* Add writer capability
* Add docstrings to FileHandler
* Introduce a 'format' parameter to reduce code duplication
* Remove unused 'overwrite' argument
* Fix tests
* Fix tests
1 parent 18963b2 commit 987abf6

File tree

9 files changed: +594 -176 lines changed

docs/source/user_guide_kg_builder.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,8 +832,8 @@ You can also save and reload the extracted schema:
832832
.. code:: python
833833
834834
# Save the schema to JSON or YAML files
835-
extracted_schema.store_as_json("my_schema.json")
836-
extracted_schema.store_as_yaml("my_schema.yaml")
835+
extracted_schema.save("my_schema.json")
836+
extracted_schema.save("my_schema.yaml")
837837
838838
# Later, reload the schema from file
839839
from neo4j_graphrag.experimental.components.schema import GraphSchema

examples/customize/build_graph/components/schema_builders/schema_from_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ async def extract_and_save_schema() -> None:
8585

8686
print(f"Saving schema to JSON file: {JSON_FILE_PATH}")
8787
# Save the schema to JSON file
88-
inferred_schema.store_as_json(JSON_FILE_PATH)
88+
inferred_schema.save(JSON_FILE_PATH)
8989

9090
print(f"Saving schema to YAML file: {YAML_FILE_PATH}")
9191
# Save the schema to YAML file
92-
inferred_schema.store_as_yaml(YAML_FILE_PATH)
92+
inferred_schema.save(YAML_FILE_PATH)
9393

9494
print("\nExtracted Schema Summary:")
9595
print(f"Node types: {list(inferred_schema.node_types)}")

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 43 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from __future__ import annotations
1616

1717
import json
18-
import yaml
1918
import logging
19+
import warnings
2020
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
2121
from pathlib import Path
2222

@@ -42,6 +42,7 @@
4242
)
4343
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
4444
from neo4j_graphrag.llm import LLMInterface
45+
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
4546

4647

4748
class PropertyType(BaseModel):
@@ -157,101 +158,68 @@ def node_type_from_label(self, label: str) -> Optional[NodeType]:
157158
def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]:
158159
return self._relationship_type_index.get(label)
159160

160-
def store_as_json(self, file_path: str) -> None:
161+
def save(
162+
self,
163+
file_path: Union[str, Path],
164+
overwrite: bool = False,
165+
format: Optional[FileFormat] = None,
166+
) -> None:
161167
"""
162-
Save the schema configuration to a JSON file.
168+
Save the schema configuration to file.
163169
164170
Args:
165171
file_path (str): The path where the schema configuration will be saved.
172+
overwrite (bool): If set to True, existing file will be overwritten. Default to False.
173+
format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from file_path extension.
166174
"""
167-
with open(file_path, "w") as f:
168-
json.dump(self.model_dump(), f, indent=2)
175+
data = self.model_dump(mode="json")
176+
file_handler = FileHandler()
177+
file_handler.write(data, file_path, overwrite=overwrite, format=format)
169178

170-
def store_as_yaml(self, file_path: str) -> None:
171-
"""
172-
Save the schema configuration to a YAML file.
179+
def store_as_json(
180+
self, file_path: Union[str, Path], overwrite: bool = False
181+
) -> None:
182+
warnings.warn(
183+
"Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning
184+
)
185+
return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON)
173186

174-
Args:
175-
file_path (str): The path where the schema configuration will be saved.
176-
"""
177-
# create a copy of the data and convert tuples to lists for YAML compatibility
178-
data = self.model_dump()
179-
if data.get("node_types"):
180-
data["node_types"] = list(data["node_types"])
181-
if data.get("relationship_types"):
182-
data["relationship_types"] = list(data["relationship_types"])
183-
if data.get("patterns"):
184-
data["patterns"] = [list(item) for item in data["patterns"]]
185-
186-
with open(file_path, "w") as f:
187-
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
187+
def store_as_yaml(
188+
self, file_path: Union[str, Path], overwrite: bool = False
189+
) -> None:
190+
warnings.warn(
191+
"Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning
192+
)
193+
return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML)
188194

189195
@classmethod
190-
def from_file(cls, file_path: Union[str, Path]) -> Self:
196+
def from_file(
197+
cls, file_path: Union[str, Path], format: Optional[FileFormat] = None
198+
) -> Self:
191199
"""
192200
Load a schema configuration from a file (either JSON or YAML).
193201
194-
The file format is automatically detected based on the file extension.
202+
The file format is automatically detected based on the file extension,
203+
unless the format parameter is set.
195204
196205
Args:
197206
file_path (Union[str, Path]): The path to the schema configuration file.
207+
format (Optional[FileFormat]): The format of the schema configuration file (json or yaml).
198208
199209
Returns:
200210
GraphSchema: The loaded schema configuration.
201211
"""
202212
file_path = Path(file_path)
213+
file_handler = FileHandler()
214+
try:
215+
data = file_handler.read(file_path, format=format)
216+
except ValueError:
217+
raise
203218

204-
if not file_path.exists():
205-
raise FileNotFoundError(f"Schema file not found: {file_path}")
206-
207-
if file_path.suffix.lower() in [".json"]:
208-
return cls.from_json(file_path)
209-
elif file_path.suffix.lower() in [".yaml", ".yml"]:
210-
return cls.from_yaml(file_path)
211-
else:
212-
raise ValueError(
213-
f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml"
214-
)
215-
216-
@classmethod
217-
def from_json(cls, file_path: Union[str, Path]) -> Self:
218-
"""
219-
Load a schema configuration from a JSON file.
220-
221-
Args:
222-
file_path (Union[str, Path]): The path to the JSON schema configuration file.
223-
224-
Returns:
225-
GraphSchema: The loaded schema configuration.
226-
"""
227-
with open(file_path, "r") as f:
228-
try:
229-
data = json.load(f)
230-
return cls.model_validate(data)
231-
except json.JSONDecodeError as e:
232-
raise ValueError(f"Invalid JSON file: {e}")
233-
except ValidationError as e:
234-
raise SchemaValidationError(f"Schema validation failed: {e}")
235-
236-
@classmethod
237-
def from_yaml(cls, file_path: Union[str, Path]) -> Self:
238-
"""
239-
Load a schema configuration from a YAML file.
240-
241-
Args:
242-
file_path (Union[str, Path]): The path to the YAML schema configuration file.
243-
244-
Returns:
245-
GraphSchema: The loaded schema configuration.
246-
"""
247-
with open(file_path, "r") as f:
248-
try:
249-
data = yaml.safe_load(f)
250-
return cls.model_validate(data)
251-
except yaml.YAMLError as e:
252-
raise ValueError(f"Invalid YAML file: {e}")
253-
except ValidationError as e:
254-
raise SchemaValidationError(f"Schema validation failed: {e}")
219+
try:
220+
return cls.model_validate(data)
221+
except ValidationError as e:
222+
raise SchemaValidationError(str(e)) from e
255223

256224

257225
class SchemaBuilder(Component):

src/neo4j_graphrag/experimental/pipeline/config/config_reader.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

src/neo4j_graphrag/experimental/pipeline/config/runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from typing_extensions import Self
3838

3939
from neo4j_graphrag.experimental.pipeline import Pipeline
40-
from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader
40+
from neo4j_graphrag.utils.file_handler import FileHandler
4141
from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
4242
AbstractPipelineConfig,
4343
PipelineConfig,
@@ -113,7 +113,7 @@ def from_config_file(cls, file_path: Union[str, Path]) -> Self:
113113
logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}")
114114
if not isinstance(file_path, str):
115115
file_path = str(file_path)
116-
data = ConfigReader().read(file_path)
116+
data = FileHandler().read(file_path)
117117
return cls.from_config(data, do_cleaning=True)
118118

119119
async def run(self, user_input: dict[str, Any]) -> PipelineResult:

0 commit comments

Comments
 (0)