diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst
index d7455a6f8..6af5e8cb3 100644
--- a/docs/source/user_guide_kg_builder.rst
+++ b/docs/source/user_guide_kg_builder.rst
@@ -832,8 +832,8 @@ You can also save and reload the extracted schema:
 .. code:: python
 
     # Save the schema to JSON or YAML files
-    extracted_schema.store_as_json("my_schema.json")
-    extracted_schema.store_as_yaml("my_schema.yaml")
+    extracted_schema.save("my_schema.json")
+    extracted_schema.save("my_schema.yaml")
 
     # Later, reload the schema from file
     from neo4j_graphrag.experimental.components.schema import GraphSchema
diff --git a/examples/customize/build_graph/components/schema_builders/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py
index a36ef90ec..947907dff 100644
--- a/examples/customize/build_graph/components/schema_builders/schema_from_text.py
+++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py
@@ -85,11 +85,11 @@ async def extract_and_save_schema() -> None:
 
     print(f"Saving schema to JSON file: {JSON_FILE_PATH}")
     # Save the schema to JSON file
-    inferred_schema.store_as_json(JSON_FILE_PATH)
+    inferred_schema.save(JSON_FILE_PATH)
 
     print(f"Saving schema to YAML file: {YAML_FILE_PATH}")
     # Save the schema to YAML file
-    inferred_schema.store_as_yaml(YAML_FILE_PATH)
+    inferred_schema.save(YAML_FILE_PATH)
 
     print("\nExtracted Schema Summary:")
     print(f"Node types: {list(inferred_schema.node_types)}")
diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py
index a4af30e5d..b2007ccb4 100644
--- a/src/neo4j_graphrag/experimental/components/schema.py
+++ b/src/neo4j_graphrag/experimental/components/schema.py
@@ -15,8 +15,8 @@
 from __future__ import annotations
 
 import json
-import yaml
 import logging
+import warnings
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
 from pathlib import Path
 
@@ -42,6 +42,7 @@
 )
 from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
 from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
 
 
 class PropertyType(BaseModel):
@@ -157,101 +158,68 @@ def node_type_from_label(self, label: str) -> Optional[NodeType]:
     def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]:
         return self._relationship_type_index.get(label)
 
-    def store_as_json(self, file_path: str) -> None:
+    def save(
+        self,
+        file_path: Union[str, Path],
+        overwrite: bool = False,
+        format: Optional[FileFormat] = None,
+    ) -> None:
         """
-        Save the schema configuration to a JSON file.
+        Save the schema configuration to file.
 
         Args:
             file_path (str): The path where the schema configuration will be saved.
+            overwrite (bool): If set to True, any existing file will be overwritten. Defaults to False.
+            format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from the file_path extension.
         """
-        with open(file_path, "w") as f:
-            json.dump(self.model_dump(), f, indent=2)
+        data = self.model_dump(mode="json")
+        file_handler = FileHandler()
+        file_handler.write(data, file_path, overwrite=overwrite, format=format)
 
-    def store_as_yaml(self, file_path: str) -> None:
-        """
-        Save the schema configuration to a YAML file.
+    def store_as_json(
+        self, file_path: Union[str, Path], overwrite: bool = False
+    ) -> None:
+        warnings.warn(
+            "Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning
+        )
+        return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON)
 
-        Args:
-            file_path (str): The path where the schema configuration will be saved.
-        """
-        # create a copy of the data and convert tuples to lists for YAML compatibility
-        data = self.model_dump()
-        if data.get("node_types"):
-            data["node_types"] = list(data["node_types"])
-        if data.get("relationship_types"):
-            data["relationship_types"] = list(data["relationship_types"])
-        if data.get("patterns"):
-            data["patterns"] = [list(item) for item in data["patterns"]]
-
-        with open(file_path, "w") as f:
-            yaml.dump(data, f, default_flow_style=False, sort_keys=False)
+    def store_as_yaml(
+        self, file_path: Union[str, Path], overwrite: bool = False
+    ) -> None:
+        warnings.warn(
+            "Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning
+        )
+        return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML)
 
     @classmethod
-    def from_file(cls, file_path: Union[str, Path]) -> Self:
+    def from_file(
+        cls, file_path: Union[str, Path], format: Optional[FileFormat] = None
+    ) -> Self:
         """
         Load a schema configuration from a file (either JSON or YAML).
-        The file format is automatically detected based on the file extension.
+        The file format is automatically detected based on the file extension,
+        unless the format parameter is set.
 
         Args:
             file_path (Union[str, Path]): The path to the schema configuration file.
+            format (Optional[FileFormat]): The format of the schema configuration file (json or yaml).
 
         Returns:
             GraphSchema: The loaded schema configuration.
         """
         file_path = Path(file_path)
+        file_handler = FileHandler()
+        try:
+            data = file_handler.read(file_path, format=format)
+        except ValueError:
+            raise
 
-        if not file_path.exists():
-            raise FileNotFoundError(f"Schema file not found: {file_path}")
-
-        if file_path.suffix.lower() in [".json"]:
-            return cls.from_json(file_path)
-        elif file_path.suffix.lower() in [".yaml", ".yml"]:
-            return cls.from_yaml(file_path)
-        else:
-            raise ValueError(
-                f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml"
-            )
-
-    @classmethod
-    def from_json(cls, file_path: Union[str, Path]) -> Self:
-        """
-        Load a schema configuration from a JSON file.
-
-        Args:
-            file_path (Union[str, Path]): The path to the JSON schema configuration file.
-
-        Returns:
-            GraphSchema: The loaded schema configuration.
-        """
-        with open(file_path, "r") as f:
-            try:
-                data = json.load(f)
-                return cls.model_validate(data)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"Invalid JSON file: {e}")
-            except ValidationError as e:
-                raise SchemaValidationError(f"Schema validation failed: {e}")
-
-    @classmethod
-    def from_yaml(cls, file_path: Union[str, Path]) -> Self:
-        """
-        Load a schema configuration from a YAML file.
-
-        Args:
-            file_path (Union[str, Path]): The path to the YAML schema configuration file.
-
-        Returns:
-            GraphSchema: The loaded schema configuration.
- """ - with open(file_path, "r") as f: - try: - data = yaml.safe_load(f) - return cls.model_validate(data) - except yaml.YAMLError as e: - raise ValueError(f"Invalid YAML file: {e}") - except ValidationError as e: - raise SchemaValidationError(f"Schema validation failed: {e}") + try: + return cls.model_validate(data) + except ValidationError as e: + raise SchemaValidationError(str(e)) from e class SchemaBuilder(Component): diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py deleted file mode 100644 index c3df28d9a..000000000 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Read JSON or YAML files and returns a dict. -No data validation performed at this stage. -""" - -import json -import logging -from pathlib import Path -from typing import Any, Optional - -import fsspec -import yaml -from fsspec.implementations.local import LocalFileSystem - -logger = logging.getLogger(__name__) - - -class ConfigReader: - """Reads config from a file (JSON or YAML format) - and returns a dict. - - File format is guessed from the extension. Supported extensions are - (lower or upper case): - - - .json - - .yaml, .yml - - Example: - - .. code-block:: python - - from pathlib import Path - from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader - reader = ConfigReader() - reader.read(Path("my_file.json")) - - If reading a file with a different extension but still in JSON or YAML format, - it is possible to call directly the `read_json` or `read_yaml` methods: - - .. 
-
-        reader.read_yaml(Path("my_file.txt"))
-
-    """
-
-    def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None:
-        self.fs = fs or LocalFileSystem()
-
-    def read_json(self, file_path: str) -> Any:
-        logger.debug(f"CONFIG_READER: read from json {file_path}")
-        with self.fs.open(file_path, "r") as f:
-            return json.load(f)
-
-    def read_yaml(self, file_path: str) -> Any:
-        logger.debug(f"CONFIG_READER: read from yaml {file_path}")
-        with self.fs.open(file_path, "r") as f:
-            return yaml.safe_load(f)
-
-    def _guess_format_and_read(self, file_path: str) -> dict[str, Any]:
-        p = Path(file_path)
-        extension = p.suffix.lower()
-        # Note: .suffix returns an empty string if Path has no extension
-        # if not returning a dict, parsing will fail later on
-        if extension in [".json"]:
-            return self.read_json(file_path)  # type: ignore[no-any-return]
-        if extension in [".yaml", ".yml"]:
-            return self.read_yaml(file_path)  # type: ignore[no-any-return]
-        raise ValueError(f"Unsupported extension: {extension}")
-
-    def read(self, file_path: str) -> dict[str, Any]:
-        data = self._guess_format_and_read(file_path)
-        return data
diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py
index 82e8175a9..fb0544bce 100644
--- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py
+++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py
@@ -37,7 +37,7 @@
 from typing_extensions import Self
 
 from neo4j_graphrag.experimental.pipeline import Pipeline
-from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader
+from neo4j_graphrag.utils.file_handler import FileHandler
 from neo4j_graphrag.experimental.pipeline.config.pipeline_config import (
     AbstractPipelineConfig,
     PipelineConfig,
@@ -113,7 +113,7 @@ def from_config_file(cls, file_path: Union[str, Path]) -> Self:
         logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}")
         if not isinstance(file_path, str):
             file_path = str(file_path)
-        data = ConfigReader().read(file_path)
+        data = FileHandler().read(file_path)
         return cls.from_config(data, do_cleaning=True)
 
     async def run(self, user_input: dict[str, Any]) -> PipelineResult:
diff --git a/src/neo4j_graphrag/utils/file_handler.py b/src/neo4j_graphrag/utils/file_handler.py
new file mode 100644
index 000000000..b351fe00a
--- /dev/null
+++ b/src/neo4j_graphrag/utils/file_handler.py
@@ -0,0 +1,239 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+#     https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import enum
+import json
+import logging
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import fsspec
+import yaml
+from fsspec.implementations.local import LocalFileSystem
+
+logger = logging.getLogger(__name__)
+
+
+class FileFormat(enum.Enum):
+    JSON = "json"
+    YAML = "yaml"
+
+    @classmethod
+    def json_valid_extension(cls) -> list[str]:
+        return [".json"]
+
+    @classmethod
+    def yaml_valid_extension(cls) -> list[str]:
+        return [".yaml", ".yml"]
+
+
+class FileHandler:
+    """Utility class to read/write JSON or YAML files.
+
+    File format is guessed from the extension. Supported extensions are
+    (lower or upper case):
+
+    - .json
+    - .yaml, .yml
+
+    Example:
+
+    .. code-block:: python
+
+        from neo4j_graphrag.utils.file_handler import FileHandler
+        handler = FileHandler()
+        handler.read("my_file.json")
+
+    If reading a file with a different extension but still in JSON or YAML format,
+    it is possible to pass the format explicitly:
+
+    .. code-block:: python
+
+        handler.read("my_file.txt", format=FileFormat.YAML)
+
+    """
+
+    def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None:
+        self.fs = fs or LocalFileSystem()
+
+    def _guess_file_format(self, path: Path) -> Optional[FileFormat]:
+        # Note: .suffix returns an empty string if Path has no extension
+        extension = path.suffix.lower()
+        if extension in FileFormat.json_valid_extension():
+            return FileFormat.JSON
+        if extension in FileFormat.yaml_valid_extension():
+            return FileFormat.YAML
+        return None
+
+    def _check_file_exists(self, path: Path) -> Path:
+        if not path.exists():
+            raise FileNotFoundError(f"File not found: {path}")
+        return path
+
+    def _read_json(self, path: Path) -> Any:
+        """Reads a JSON file. If file does not exist, raises FileNotFoundError.
+
+        Args:
+            path (Path): The path of the JSON file.
+
+        Raises:
+            FileNotFoundError: If file does not exist.
+
+        Returns:
+            The parsed content of the JSON file.
+        """
+        logger.debug(f"FILE_HANDLER: read from json {path}")
+        with self.fs.open(str(path), "r") as f:
+            try:
+                return json.load(f)
+            except json.JSONDecodeError as e:
+                raise ValueError("Invalid JSON file") from e
+
+    def _read_yaml(self, path: Path) -> Any:
+        """Reads a YAML file. If file does not exist, raises FileNotFoundError.
+
+        Args:
+            path (Path): The path of the YAML file.
+
+        Raises:
+            FileNotFoundError: If file does not exist.
+
+        Returns:
+            The parsed content of the YAML file.
+        """
+        logger.debug(f"FILE_HANDLER: read from yaml {path}")
+        with self.fs.open(str(path), "r") as f:
+            try:
+                return yaml.safe_load(f)
+            except yaml.YAMLError as e:
+                raise ValueError("Invalid YAML file") from e
+
+    def read(
+        self, file_path: Union[Path, str], format: Optional[FileFormat] = None
+    ) -> Any:
+        """Try to infer the file type from its extension and return its
+        parsed content.
+
+        Args:
+            file_path (Union[str, Path]): The path of the file to read.
+            format (Optional[FileFormat]): The file format to use when parsing the file.
+                If not set, the format is inferred from the extension.
+
+        Raises:
+            FileNotFoundError: If file does not exist.
+            ValueError: If the file extension is unsupported or the file can not be parsed
+                (e.g. invalid JSON or YAML)
+
+        Returns: the parsed content of the file.
+ """ + path = Path(file_path) + path = self._check_file_exists(path) + if not format: + format = self._guess_file_format(path) + if format == FileFormat.JSON: + return self._read_json(path) + if format == FileFormat.YAML: + return self._read_yaml(path) + raise ValueError(f"Unsupported file format: {format}") + + def _check_file_can_be_written(self, path: Path, overwrite: bool = False) -> None: + """Check whether the file can be written to path with the following conditions: + + - If overwrite is set to True, file can always be written, any existing file will be overwritten. + - If overwrite is set to False, and file already exists, file will not be overwritten. + + Args: + path (Path): The path of the target file. + overwrite (bool): If set to True, existing file will be overwritten. Default to False. + + Raises: + ValueError: If file can not be written according to the above rules. + """ + if overwrite: + # we can overwrite, so no matter if file already exists or not + return + try: + self._check_file_exists(path) + # did not raise, meaning the file exists + raise ValueError("File already exists. Use overwrite=True to overwrite.") + except FileNotFoundError: + # file not found all godo + pass + + def _write_json( + self, + data: Any, + file_path: Path, + **extra_kwargs: Any, + ) -> None: + """Writes data to a JSON file + + Args: + data (Any): The data to write. + file_path (Path): The path of the JSON file. + extra_kwargs (Any): Additional arguments passed to json.dump (e.g.: indent...). Note: a default indent=4 is applied. + + Raises: + ValueError: If file can not be written according to the above rules. + """ + fp = str(file_path) + kwargs: dict[str, Any] = { + "indent": 2, + } + kwargs.update(extra_kwargs) + with self.fs.open(fp, "w") as f: + json.dump(data, f, **kwargs) + + def _write_yaml( + self, + data: Any, + file_path: Path, + **extra_kwargs: Any, + ) -> None: + """Writes data to a YAML file + + Args: + data (Any): The data to write. + file_path (Path): The path of the YAML file. + extra_kwargs (Any): Additional arguments passed to yaml.safe_dump. 
+                Note that we apply the following defaults:
+                - "default_flow_style": False
+                - "sort_keys": True
+        """
+        fp = str(file_path)
+        kwargs: dict[str, Any] = {
+            "default_flow_style": False,
+            "sort_keys": True,
+        }
+        kwargs.update(extra_kwargs)
+        with self.fs.open(fp, "w") as f:
+            yaml.safe_dump(data, f, **kwargs)
+
+    def write(
+        self,
+        data: Any,
+        file_path: Union[Path, str],
+        overwrite: bool = False,
+        format: Optional[FileFormat] = None,
+        **extra_kwargs: Any,
+    ) -> None:
+        """Guess the file format from file_path (unless format is set) and write data to it."""
+        path = Path(file_path)
+        self._check_file_can_be_written(path, overwrite)
+        if not format:
+            format = self._guess_file_format(path)
+        if format == FileFormat.JSON:
+            return self._write_json(data, path, **extra_kwargs)
+        if format == FileFormat.YAML:
+            return self._write_yaml(data, path, **extra_kwargs)
+        raise ValueError(f"Unsupported file format: {format}")
diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py
index e8fc670c2..1aec51261 100644
--- a/tests/unit/experimental/components/test_schema.py
+++ b/tests/unit/experimental/components/test_schema.py
@@ -35,6 +35,7 @@
 
 from neo4j_graphrag.generation import PromptTemplate
 from neo4j_graphrag.llm.types import LLMResponse
+from neo4j_graphrag.utils.file_handler import FileFormat
 
 
 @pytest.fixture
@@ -383,13 +384,13 @@ async def test_schema_from_text_llm_params(
 
 
 @pytest.mark.asyncio
-async def test_schema_config_store_as_json(graph_schema: GraphSchema) -> None:
+async def test_schema_config_save_json(graph_schema: GraphSchema) -> None:
     with tempfile.TemporaryDirectory() as temp_dir:
         # create file path
         json_path = os.path.join(temp_dir, "schema.json")
 
         # store the schema config
-        graph_schema.store_as_json(json_path)
+        graph_schema.save(json_path)
 
         # verify the file exists and has content
         assert os.path.exists(json_path)
@@ -403,13 +404,13 @@ async def test_schema_config_store_as_json(graph_schema: GraphSchema) -> None:
 
 
 @pytest.mark.asyncio
-async def test_schema_config_store_as_yaml(graph_schema: GraphSchema) -> None:
+async def test_schema_config_save_yaml(graph_schema: GraphSchema) -> None:
     with tempfile.TemporaryDirectory() as temp_dir:
         # Create file path
         yaml_path = os.path.join(temp_dir, "schema.yaml")
 
         # Store the schema config
-        graph_schema.store_as_yaml(yaml_path)
+        graph_schema.save(yaml_path)
 
         # Verify the file exists and has content
         assert os.path.exists(yaml_path)
@@ -431,9 +432,9 @@ async def test_schema_config_from_file(graph_schema: GraphSchema) -> None:
         yml_path = os.path.join(temp_dir, "schema.yml")
 
         # store the schema config in the different formats
-        graph_schema.store_as_json(json_path)
-        graph_schema.store_as_yaml(yaml_path)
-        graph_schema.store_as_yaml(yml_path)
+        graph_schema.save(json_path)
+        graph_schema.save(yaml_path)
+        graph_schema.save(yml_path)
 
         # load using from_file which should detect the format based on extension
         json_schema = GraphSchema.from_file(json_path)
@@ -452,9 +453,11 @@ async def test_schema_config_from_file(graph_schema: GraphSchema) -> None:
 
         # verify an unsupported extension raises the correct error
         txt_path = os.path.join(temp_dir, "schema.txt")
-        graph_schema.store_as_json(txt_path)  # Store as JSON but with .txt extension
+        graph_schema.save(
+            txt_path, format=FileFormat.JSON
+        )  # Store as JSON but with .txt extension
 
-        with pytest.raises(ValueError, match="Unsupported file format"):
+        with pytest.raises(ValueError, match="Unsupported file format: None"):
             GraphSchema.from_file(txt_path)
 
 
diff --git a/tests/unit/experimental/pipeline/config/test_runner.py b/tests/unit/experimental/pipeline/config/test_runner.py
index 620796cb1..c20c4cf48 100644
--- a/tests/unit/experimental/pipeline/config/test_runner.py
+++ b/tests/unit/experimental/pipeline/config/test_runner.py
@@ -45,7 +45,7 @@ def test_pipeline_runner_from_config() -> None:
 
 
 @patch("neo4j_graphrag.experimental.pipeline.config.runner.PipelineRunner.from_config")
-@patch("neo4j_graphrag.experimental.pipeline.config.config_reader.ConfigReader.read")
+@patch("neo4j_graphrag.utils.file_handler.FileHandler.read")
 def test_pipeline_runner_from_config_file(
     mock_read: Mock, mock_from_config: Mock
 ) -> None:
diff --git a/tests/unit/utils/test_file_handler.py b/tests/unit/utils/test_file_handler.py
new file mode 100644
index 000000000..5d12c05cb
--- /dev/null
+++ b/tests/unit/utils/test_file_handler.py
@@ -0,0 +1,293 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+#     https://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+from unittest.mock import patch, Mock, mock_open
+
+import pytest
+
+from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
+
+
+def test_file_handler_guess_format() -> None:
+    handler = FileHandler()
+
+    assert handler._guess_file_format(Path("file.json")) == FileFormat.JSON
+    assert handler._guess_file_format(Path("file.JSON")) == FileFormat.JSON
+    assert handler._guess_file_format(Path("file.yaml")) == FileFormat.YAML
+    assert handler._guess_file_format(Path("file.YAML")) == FileFormat.YAML
+    assert handler._guess_file_format(Path("file.yml")) == FileFormat.YAML
+    assert handler._guess_file_format(Path("file.YML")) == FileFormat.YAML
+    assert handler._guess_file_format(Path("file.txt")) is None
+
+
+@patch("neo4j_graphrag.utils.file_handler.FileHandler._read_json")
+@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists")
+def test_file_handler_read_json_from_read_method_happy_path(
+    mock_file_exists: Mock,
+    mock_read_json: Mock,
+) -> None:
+    handler = FileHandler()
+
+    mock_file_exists.return_value = Path("file.json")
+    mock_read_json.return_value = {}
+    data = handler.read("file.json")
+    mock_read_json.assert_called_with(Path("file.json"))
+    assert data == {}
+
+    mock_file_exists.return_value = Path("file.JSON")
+    mock_read_json.return_value = {}
+    data = handler.read("file.JSON")
+    mock_read_json.assert_called_with(Path("file.JSON"))
+    assert data == {}
+
+
+@patch("neo4j_graphrag.utils.file_handler.FileHandler._read_yaml")
+@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists")
+def test_file_handler_read_yaml_from_read_method_happy_path(
+    mock_file_exists: Mock,
+    mock_read_yaml: Mock,
+) -> None:
+    mock_file_exists.return_value = Path("file.yaml")
+    handler = FileHandler()
+    mock_read_yaml.return_value = {}
+    data = handler.read("file.yaml")
+    mock_read_yaml.assert_called_with(Path("file.yaml"))
+    assert data == {}
+
+    mock_file_exists.return_value = Path("file.yml")
+    mock_read_yaml.return_value = {}
+    data = handler.read("file.yml")
handler.read("file.yml") + mock_read_yaml.assert_called_with(Path("file.yml")) + assert data == {} + + mock_file_exists.return_value = Path("file.YAML") + mock_read_yaml.return_value = {} + data = handler.read("file.YAML") + mock_read_yaml.assert_called_with(Path("file.YAML")) + assert data == {} + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_json_method_happy_path( + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open(read_data='{"data": 1}') + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + data = handler._read_json(Path("file.json")) + mock_fs_open.assert_called_once_with("file.json", "r") + assert data == {"data": 1} + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_yaml_method_happy_path( + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open( + read_data=""" + data: 1 + """ + ) + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + data = handler._read_yaml(Path("file.yaml")) + mock_fs_open.assert_called_once_with("file.yaml", "r") + assert data == {"data": 1} + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_json_file_does_not_exist( + mock_fs: Mock, +) -> None: + mock_fs.return_value.open.side_effect = FileNotFoundError + handler = FileHandler() + with pytest.raises(FileNotFoundError): + handler._read_json(Path("file.json")) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_json_invalid_json( + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open(read_data="{") + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + with pytest.raises(ValueError, match="Invalid JSON"): + handler._read_json(Path("file.json")) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_yaml_file_does_not_exist( + mock_fs: Mock, +) -> None: + mock_fs.return_value.open.side_effect = FileNotFoundError + handler = FileHandler() + with pytest.raises(FileNotFoundError): + handler._read_yaml(Path("file.yaml")) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_yaml_invalid_yaml( + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open( + read_data=""" + data: [ + """ + ) + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + with pytest.raises(ValueError, match="Invalid YAML"): + handler._read_yaml(Path("file.yaml")) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +def test_file_handler_file_can_be_written_file(mock_file_exists: Mock) -> None: + # file does not exist + mock_file_exists.side_effect = FileNotFoundError() + handler = FileHandler() + # nothing happens = all good + handler._check_file_can_be_written(Path("file.json")) + # file exist, overwrite is False + mock_file_exists.side_effect = None + handler = FileHandler() + with pytest.raises(ValueError): + handler._check_file_can_be_written(Path("file.json"), overwrite=False) + # file exists, overwrite is True + mock_file_exists.side_effect = None + handler = FileHandler() + # nothing happens = all good + handler._check_file_can_be_written(Path("file.json"), overwrite=True) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.json") +def test_file_handler_write_json_happy_path( + mock_json_module: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + 
file_handler._write_json({"some": "data"}, Path("file.json")) + mock_fs_open.assert_called_once_with("file.json", "w") + mock_json_module.dump.assert_called_with( + {"some": "data"}, mock_fs_open.return_value, indent=2 + ) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.json") +def test_file_handler_write_json_extra_kwargs_happy_path( + mock_json_module: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + file_handler._write_json({"some": "data"}, Path("file.json"), indent=4, default=str) + mock_fs_open.assert_called_once_with("file.json", "w") + mock_json_module.dump.assert_called_with( + {"some": "data"}, mock_fs_open.return_value, indent=4, default=str + ) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.yaml") +def test_file_handler_write_yaml_happy_path( + mock_yaml_module: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + file_handler._write_yaml({"some": "data"}, Path("file.yaml")) + mock_fs_open.assert_called_once_with("file.yaml", "w") + mock_yaml_module.safe_dump.assert_called_with( + {"some": "data"}, + mock_fs_open.return_value, + default_flow_style=False, + sort_keys=True, + ) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.yaml") +def test_file_handler_write_yaml_extra_kwargs_happy_path( + mock_yaml_module: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + file_handler._write_yaml( + {"some": "data"}, Path("file.json"), default_flow_style="toto", other_keyword=42 + ) + mock_fs_open.assert_called_once_with("file.json", "w") + mock_yaml_module.safe_dump.assert_called_with( + {"some": "data"}, + mock_fs_open.return_value, + default_flow_style="toto", + sort_keys=True, + other_keyword=42, + ) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_json") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +def test_file_handler_write_json_from_write_method_happy_path( + mock_file_can_be_written: Mock, + mock_write_json: Mock, +) -> None: + handler = FileHandler() + handler.write("data", "file.json") + mock_write_json.assert_called_with("data", Path("file.json")) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_yaml") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +def test_file_handler_write_yaml_from_write_method_happy_path( + mock_file_can_be_written: Mock, + mock_write_yaml: Mock, +) -> None: + handler = FileHandler() + handler.write("data", "file.yaml") + mock_write_yaml.assert_called_with("data", Path("file.yaml")) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_yaml") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +def test_file_handler_write_yaml_from_write_method_overwrite_format_happy_path( + mock_file_can_be_written: Mock, + mock_write_yaml: Mock, +) -> None: + handler = FileHandler() + handler.write("data", "file.txt", format=FileFormat.YAML) + mock_write_yaml.assert_called_with("data", Path("file.txt")) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_json") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +def 
+    mock_file_can_be_written: Mock,
+    mock_write_json: Mock,
+) -> None:
+    handler = FileHandler()
+    handler.write("data", "file.txt", format=FileFormat.JSON)
+    mock_write_json.assert_called_with("data", Path("file.txt"))
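
For reference, a minimal sketch of how the save/from_file API introduced by this patch is meant to be called; file names are illustrative and `schema` is assumed to be an existing GraphSchema instance:

    from neo4j_graphrag.experimental.components.schema import GraphSchema
    from neo4j_graphrag.utils.file_handler import FileFormat

    # format is inferred from the extension
    schema.save("my_schema.yaml")
    # or forced explicitly, overwriting any existing file
    schema.save("my_schema.txt", overwrite=True, format=FileFormat.JSON)

    # reload the schema later; pass format= when the extension is not .json/.yaml/.yml
    reloaded = GraphSchema.from_file("my_schema.yaml")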