From 030ac64a3a9f7d6f04dff6607367bd00c94af3ad Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 15 May 2025 13:40:37 +0200 Subject: [PATCH 01/10] Use ConfigReader in SchemaConfig --- .../experimental/components/schema.py | 71 +++---------------- .../experimental/pipeline/config/runner.py | 4 +- .../config_reader.py => utils/file_reader.py} | 39 ++++++---- .../experimental/components/test_schema.py | 2 +- .../pipeline/config/test_runner.py | 2 +- 5 files changed, 39 insertions(+), 79 deletions(-) rename src/neo4j_graphrag/{experimental/pipeline/config/config_reader.py => utils/file_reader.py} (65%) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index a4af30e5d..8f962eb5f 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -42,6 +42,7 @@ ) from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate from neo4j_graphrag.llm import LLMInterface +from neo4j_graphrag.utils.file_reader import FileReader class PropertyType(BaseModel): @@ -174,15 +175,7 @@ def store_as_yaml(self, file_path: str) -> None: Args: file_path (str): The path where the schema configuration will be saved. """ - # create a copy of the data and convert tuples to lists for YAML compatibility - data = self.model_dump() - if data.get("node_types"): - data["node_types"] = list(data["node_types"]) - if data.get("relationship_types"): - data["relationship_types"] = list(data["relationship_types"]) - if data.get("patterns"): - data["patterns"] = [list(item) for item in data["patterns"]] - + data = self.model_dump(mode="json") with open(file_path, "w") as f: yaml.dump(data, f, default_flow_style=False, sort_keys=False) @@ -200,58 +193,16 @@ def from_file(cls, file_path: Union[str, Path]) -> Self: GraphSchema: The loaded schema configuration. """ file_path = Path(file_path) + reader = FileReader() + try: + data = reader.read(file_path) + except ValueError: + raise - if not file_path.exists(): - raise FileNotFoundError(f"Schema file not found: {file_path}") - - if file_path.suffix.lower() in [".json"]: - return cls.from_json(file_path) - elif file_path.suffix.lower() in [".yaml", ".yml"]: - return cls.from_yaml(file_path) - else: - raise ValueError( - f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml" - ) - - @classmethod - def from_json(cls, file_path: Union[str, Path]) -> Self: - """ - Load a schema configuration from a JSON file. - - Args: - file_path (Union[str, Path]): The path to the JSON schema configuration file. - - Returns: - GraphSchema: The loaded schema configuration. - """ - with open(file_path, "r") as f: - try: - data = json.load(f) - return cls.model_validate(data) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON file: {e}") - except ValidationError as e: - raise SchemaValidationError(f"Schema validation failed: {e}") - - @classmethod - def from_yaml(cls, file_path: Union[str, Path]) -> Self: - """ - Load a schema configuration from a YAML file. - - Args: - file_path (Union[str, Path]): The path to the YAML schema configuration file. - - Returns: - GraphSchema: The loaded schema configuration. - """ - with open(file_path, "r") as f: - try: - data = yaml.safe_load(f) - return cls.model_validate(data) - except yaml.YAMLError as e: - raise ValueError(f"Invalid YAML file: {e}") - except ValidationError as e: - raise SchemaValidationError(f"Schema validation failed: {e}") + try: + return cls.model_validate(data) + except ValidationError as e: + raise SchemaValidationError(str(e)) from e class SchemaBuilder(Component): diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index 82e8175a9..c656a6009 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -37,7 +37,7 @@ from typing_extensions import Self from neo4j_graphrag.experimental.pipeline import Pipeline -from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader +from neo4j_graphrag.utils.file_reader import FileReader from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( AbstractPipelineConfig, PipelineConfig, @@ -113,7 +113,7 @@ def from_config_file(cls, file_path: Union[str, Path]) -> Self: logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}") if not isinstance(file_path, str): file_path = str(file_path) - data = ConfigReader().read(file_path) + data = FileReader().read(file_path) return cls.from_config(data, do_cleaning=True) async def run(self, user_input: dict[str, Any]) -> PipelineResult: diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py b/src/neo4j_graphrag/utils/file_reader.py similarity index 65% rename from src/neo4j_graphrag/experimental/pipeline/config/config_reader.py rename to src/neo4j_graphrag/utils/file_reader.py index c3df28d9a..9c8a6347c 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py +++ b/src/neo4j_graphrag/utils/file_reader.py @@ -19,7 +19,7 @@ import json import logging from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, Union import fsspec import yaml @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -class ConfigReader: +class FileReader: """Reads config from a file (JSON or YAML format) and returns a dict. @@ -43,8 +43,8 @@ class ConfigReader: .. code-block:: python from pathlib import Path - from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader - reader = ConfigReader() + from neo4j_graphrag.utils.file_reader import FileReader + reader = FileReader() reader.read(Path("my_file.json")) If reading a file with a different extension but still in JSON or YAML format, @@ -60,26 +60,35 @@ def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None: self.fs = fs or LocalFileSystem() def read_json(self, file_path: str) -> Any: - logger.debug(f"CONFIG_READER: read from json {file_path}") + logger.debug(f"FILE_READER: read from json {file_path}") with self.fs.open(file_path, "r") as f: - return json.load(f) + try: + return json.load(f) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON file") from e def read_yaml(self, file_path: str) -> Any: - logger.debug(f"CONFIG_READER: read from yaml {file_path}") + logger.debug(f"FILE_READER: read from yaml {file_path}") with self.fs.open(file_path, "r") as f: - return yaml.safe_load(f) - - def _guess_format_and_read(self, file_path: str) -> dict[str, Any]: - p = Path(file_path) - extension = p.suffix.lower() + try: + return yaml.safe_load(f) + except yaml.YAMLError as e: + raise ValueError("Invalid YAML file") from e + + def _guess_format_and_read(self, file_path: Union[Path, str]) -> dict[str, Any]: + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + extension = path.suffix.lower() # Note: .suffix returns an empty string if Path has no extension # if not returning a dict, parsing will fail later on + path_as_string = str(file_path) if extension in [".json"]: - return self.read_json(file_path) # type: ignore[no-any-return] + return self.read_json(path_as_string) # type: ignore[no-any-return] if extension in [".yaml", ".yml"]: - return self.read_yaml(file_path) # type: ignore[no-any-return] + return self.read_yaml(path_as_string) # type: ignore[no-any-return] raise ValueError(f"Unsupported extension: {extension}") - def read(self, file_path: str) -> dict[str, Any]: + def read(self, file_path: Union[Path, str]) -> dict[str, Any]: data = self._guess_format_and_read(file_path) return data diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index e8fc670c2..15e0d674c 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -454,7 +454,7 @@ async def test_schema_config_from_file(graph_schema: GraphSchema) -> None: txt_path = os.path.join(temp_dir, "schema.txt") graph_schema.store_as_json(txt_path) # Store as JSON but with .txt extension - with pytest.raises(ValueError, match="Unsupported file format"): + with pytest.raises(ValueError, match="Unsupported extension: .txt"): GraphSchema.from_file(txt_path) diff --git a/tests/unit/experimental/pipeline/config/test_runner.py b/tests/unit/experimental/pipeline/config/test_runner.py index 620796cb1..430462285 100644 --- a/tests/unit/experimental/pipeline/config/test_runner.py +++ b/tests/unit/experimental/pipeline/config/test_runner.py @@ -45,7 +45,7 @@ def test_pipeline_runner_from_config() -> None: @patch("neo4j_graphrag.experimental.pipeline.config.runner.PipelineRunner.from_config") -@patch("neo4j_graphrag.experimental.pipeline.config.config_reader.ConfigReader.read") +@patch("neo4j_graphrag.utils.file_reader.FileReader.read") def test_pipeline_runner_from_config_file( mock_read: Mock, mock_from_config: Mock ) -> None: From 3d57653b0c50cd4e5d5b089f12e9d172c6acc761 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 15 May 2025 13:59:32 +0200 Subject: [PATCH 02/10] Rename to FileHandler --- .../experimental/components/schema.py | 4 ++-- .../experimental/pipeline/config/runner.py | 4 ++-- .../utils/{file_reader.py => file_handler.py} | 21 +++++++------------ .../pipeline/config/test_runner.py | 2 +- 4 files changed, 13 insertions(+), 18 deletions(-) rename src/neo4j_graphrag/utils/{file_reader.py => file_handler.py} (84%) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 8f962eb5f..0cf10b381 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -42,7 +42,7 @@ ) from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate from neo4j_graphrag.llm import LLMInterface -from neo4j_graphrag.utils.file_reader import FileReader +from neo4j_graphrag.utils.file_handler import FileHandler class PropertyType(BaseModel): @@ -193,7 +193,7 @@ def from_file(cls, file_path: Union[str, Path]) -> Self: GraphSchema: The loaded schema configuration. """ file_path = Path(file_path) - reader = FileReader() + reader = FileHandler() try: data = reader.read(file_path) except ValueError: diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index c656a6009..fb0544bce 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -37,7 +37,7 @@ from typing_extensions import Self from neo4j_graphrag.experimental.pipeline import Pipeline -from neo4j_graphrag.utils.file_reader import FileReader +from neo4j_graphrag.utils.file_handler import FileHandler from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( AbstractPipelineConfig, PipelineConfig, @@ -113,7 +113,7 @@ def from_config_file(cls, file_path: Union[str, Path]) -> Self: logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}") if not isinstance(file_path, str): file_path = str(file_path) - data = FileReader().read(file_path) + data = FileHandler().read(file_path) return cls.from_config(data, do_cleaning=True) async def run(self, user_input: dict[str, Any]) -> PipelineResult: diff --git a/src/neo4j_graphrag/utils/file_reader.py b/src/neo4j_graphrag/utils/file_handler.py similarity index 84% rename from src/neo4j_graphrag/utils/file_reader.py rename to src/neo4j_graphrag/utils/file_handler.py index 9c8a6347c..c8689bcab 100644 --- a/src/neo4j_graphrag/utils/file_reader.py +++ b/src/neo4j_graphrag/utils/file_handler.py @@ -12,10 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Read JSON or YAML files and returns a dict. -No data validation performed at this stage. -""" - import json import logging from pathlib import Path @@ -28,9 +24,8 @@ logger = logging.getLogger(__name__) -class FileReader: - """Reads config from a file (JSON or YAML format) - and returns a dict. +class FileHandler: + """Utility class to read JSON or YAML files. File format is guessed from the extension. Supported extensions are (lower or upper case): @@ -43,16 +38,16 @@ class FileReader: .. code-block:: python from pathlib import Path - from neo4j_graphrag.utils.file_reader import FileReader - reader = FileReader() - reader.read(Path("my_file.json")) + from neo4j_graphrag.utils.file_handler import FileHandler + handler = FileHandler() + handler.read(Path("my_file.json")) If reading a file with a different extension but still in JSON or YAML format, it is possible to call directly the `read_json` or `read_yaml` methods: .. code-block:: python - reader.read_yaml(Path("my_file.txt")) + handler.read_yaml(Path("my_file.txt")) """ @@ -60,7 +55,7 @@ def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None: self.fs = fs or LocalFileSystem() def read_json(self, file_path: str) -> Any: - logger.debug(f"FILE_READER: read from json {file_path}") + logger.debug(f"FILE_HANDLER: read from json {file_path}") with self.fs.open(file_path, "r") as f: try: return json.load(f) @@ -68,7 +63,7 @@ def read_json(self, file_path: str) -> Any: raise ValueError("Invalid JSON file") from e def read_yaml(self, file_path: str) -> Any: - logger.debug(f"FILE_READER: read from yaml {file_path}") + logger.debug(f"FILE_HANDLER: read from yaml {file_path}") with self.fs.open(file_path, "r") as f: try: return yaml.safe_load(f) diff --git a/tests/unit/experimental/pipeline/config/test_runner.py b/tests/unit/experimental/pipeline/config/test_runner.py index 430462285..c20c4cf48 100644 --- a/tests/unit/experimental/pipeline/config/test_runner.py +++ b/tests/unit/experimental/pipeline/config/test_runner.py @@ -45,7 +45,7 @@ def test_pipeline_runner_from_config() -> None: @patch("neo4j_graphrag.experimental.pipeline.config.runner.PipelineRunner.from_config") -@patch("neo4j_graphrag.utils.file_reader.FileReader.read") +@patch("neo4j_graphrag.utils.file_handler.FileHandler.read") def test_pipeline_runner_from_config_file( mock_read: Mock, mock_from_config: Mock ) -> None: From d7b4aab053314e8f98443df9e0190f59cd0a6ceb Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 15 May 2025 14:37:07 +0200 Subject: [PATCH 03/10] Tests --- src/neo4j_graphrag/utils/file_handler.py | 33 ++++--- tests/unit/utils/test_file_handler.py | 112 +++++++++++++++++++++++ 2 files changed, 131 insertions(+), 14 deletions(-) create mode 100644 tests/unit/utils/test_file_handler.py diff --git a/src/neo4j_graphrag/utils/file_handler.py b/src/neo4j_graphrag/utils/file_handler.py index c8689bcab..2e4f3a9b5 100644 --- a/src/neo4j_graphrag/utils/file_handler.py +++ b/src/neo4j_graphrag/utils/file_handler.py @@ -54,36 +54,41 @@ class FileHandler: def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None: self.fs = fs or LocalFileSystem() - def read_json(self, file_path: str) -> Any: + def read_json(self, file_path: Union[str, Path]) -> Any: logger.debug(f"FILE_HANDLER: read from json {file_path}") - with self.fs.open(file_path, "r") as f: + path = self._check_file_exists(file_path) + with self.fs.open(str(path), "r") as f: try: return json.load(f) except json.JSONDecodeError as e: raise ValueError("Invalid JSON file") from e - def read_yaml(self, file_path: str) -> Any: + def read_yaml(self, file_path: Union[str, Path]) -> Any: logger.debug(f"FILE_HANDLER: read from yaml {file_path}") - with self.fs.open(file_path, "r") as f: + path = self._check_file_exists(file_path) + with self.fs.open(str(path), "r") as f: try: return yaml.safe_load(f) except yaml.YAMLError as e: raise ValueError("Invalid YAML file") from e - def _guess_format_and_read(self, file_path: Union[Path, str]) -> dict[str, Any]: - path = Path(file_path) - if not path.exists(): - raise FileNotFoundError(f"File not found: {file_path}") - extension = path.suffix.lower() + def _check_file_exists(self, path: Union[str, Path]) -> Path: + file_path = Path(path) + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {path}") + return file_path + + def _guess_format_and_read(self, file_path: Path) -> Any: + extension = file_path.suffix.lower() # Note: .suffix returns an empty string if Path has no extension - # if not returning a dict, parsing will fail later on path_as_string = str(file_path) if extension in [".json"]: - return self.read_json(path_as_string) # type: ignore[no-any-return] + return self.read_json(path_as_string) if extension in [".yaml", ".yml"]: - return self.read_yaml(path_as_string) # type: ignore[no-any-return] + return self.read_yaml(path_as_string) raise ValueError(f"Unsupported extension: {extension}") - def read(self, file_path: Union[Path, str]) -> dict[str, Any]: - data = self._guess_format_and_read(file_path) + def read(self, file_path: Union[Path, str]) -> Any: + path = Path(file_path) + data = self._guess_format_and_read(path) return data diff --git a/tests/unit/utils/test_file_handler.py b/tests/unit/utils/test_file_handler.py new file mode 100644 index 000000000..482cda556 --- /dev/null +++ b/tests/unit/utils/test_file_handler.py @@ -0,0 +1,112 @@ +from pathlib import Path +from unittest.mock import patch, Mock, mock_open + +import pytest + +from neo4j_graphrag.utils.file_handler import FileHandler + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler.read_json") +def test_file_handler_read_json_from_read_method_happy_path(mock_read_json: Mock) -> None: + handler = FileHandler() + + mock_read_json.return_value = {} + data = handler.read("file.json") + mock_read_json.assert_called_with("file.json") + assert data == {} + + mock_read_json.return_value = {} + data = handler.read("file.JSON") + mock_read_json.assert_called_with("file.JSON") + assert data == {} + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler.read_yaml") +def test_file_handler_read_yaml_from_read_method_happy_path(mock_read_yaml: Mock) -> None: + handler = FileHandler() + mock_read_yaml.return_value = {} + data = handler.read("file.yaml") + mock_read_yaml.assert_called_with("file.yaml") + assert data == {} + + mock_read_yaml.return_value = {} + data = handler.read("file.yml") + mock_read_yaml.assert_called_with("file.yml") + assert data == {} + + mock_read_yaml.return_value = {} + data = handler.read("file.YAML") + mock_read_yaml.assert_called_with("file.YAML") + assert data == {} + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_json_method_happy_path(mock_fs: Mock, mock_file_exists: Mock) -> None: + mock_file_exists.return_value = Path("file.json") + mock_fs_open = mock_open(read_data="{}") + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + data = handler.read_json("file.json") + mock_fs_open.assert_called_once_with("file.json", "r") + assert data == {} + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_yaml_method_happy_path(mock_fs: Mock, mock_file_exists: Mock) -> None: + mock_file_exists.return_value = Path("file.yaml") + mock_fs_open = mock_open(read_data=""" + data: 1 + """) + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + data = handler.read_yaml("file.yaml") + mock_fs_open.assert_called_once_with("file.yaml", "r") + assert data == {"data": 1} + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +def test_file_handler_read_json_file_does_not_exist(mock_file_exists: Mock) -> None: + mock_file_exists.side_effect = FileNotFoundError() + + handler = FileHandler() + with pytest.raises(FileNotFoundError): + handler.read_json("file.json") + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_json_invalid_json(mock_fs: Mock, mock_file_exists: Mock) -> None: + mock_file_exists.return_value = True + mock_fs_open = mock_open(read_data="{") + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + with pytest.raises(ValueError, match="Invalid JSON"): + handler.read_json("file.json") + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +def test_file_handler_read_yaml_file_does_not_exist(mock_file_exists: Mock) -> None: + mock_file_exists.side_effect = FileNotFoundError() + + handler = FileHandler() + with pytest.raises(FileNotFoundError): + handler.read_json("file.yaml") + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_yaml_invalid_yaml(mock_fs: Mock, mock_file_exists: Mock) -> None: + mock_file_exists.return_value = True + mock_fs_open = mock_open(read_data=""" + data: [ + """) + mock_fs.return_value.open = mock_fs_open + + handler = FileHandler() + with pytest.raises(ValueError, match="Invalid YAML"): + handler.read_yaml("file.yaml") From 4153011652b99b73c958bd5f019501f7fc5a4945 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 15 May 2025 14:39:10 +0200 Subject: [PATCH 04/10] Ruff --- tests/unit/utils/test_file_handler.py | 36 +++++++++++++++++++-------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/tests/unit/utils/test_file_handler.py b/tests/unit/utils/test_file_handler.py index 482cda556..d9573683f 100644 --- a/tests/unit/utils/test_file_handler.py +++ b/tests/unit/utils/test_file_handler.py @@ -7,7 +7,9 @@ @patch("neo4j_graphrag.utils.file_handler.FileHandler.read_json") -def test_file_handler_read_json_from_read_method_happy_path(mock_read_json: Mock) -> None: +def test_file_handler_read_json_from_read_method_happy_path( + mock_read_json: Mock, +) -> None: handler = FileHandler() mock_read_json.return_value = {} @@ -22,7 +24,9 @@ def test_file_handler_read_json_from_read_method_happy_path(mock_read_json: Mock @patch("neo4j_graphrag.utils.file_handler.FileHandler.read_yaml") -def test_file_handler_read_yaml_from_read_method_happy_path(mock_read_yaml: Mock) -> None: +def test_file_handler_read_yaml_from_read_method_happy_path( + mock_read_yaml: Mock, +) -> None: handler = FileHandler() mock_read_yaml.return_value = {} data = handler.read("file.yaml") @@ -42,7 +46,9 @@ def test_file_handler_read_yaml_from_read_method_happy_path(mock_read_yaml: Mock @patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -def test_file_handler_read_json_method_happy_path(mock_fs: Mock, mock_file_exists: Mock) -> None: +def test_file_handler_read_json_method_happy_path( + mock_fs: Mock, mock_file_exists: Mock +) -> None: mock_file_exists.return_value = Path("file.json") mock_fs_open = mock_open(read_data="{}") mock_fs.return_value.open = mock_fs_open @@ -55,11 +61,15 @@ def test_file_handler_read_json_method_happy_path(mock_fs: Mock, mock_file_exist @patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -def test_file_handler_read_yaml_method_happy_path(mock_fs: Mock, mock_file_exists: Mock) -> None: +def test_file_handler_read_yaml_method_happy_path( + mock_fs: Mock, mock_file_exists: Mock +) -> None: mock_file_exists.return_value = Path("file.yaml") - mock_fs_open = mock_open(read_data=""" + mock_fs_open = mock_open( + read_data=""" data: 1 - """) + """ + ) mock_fs.return_value.open = mock_fs_open handler = FileHandler() @@ -79,7 +89,9 @@ def test_file_handler_read_json_file_does_not_exist(mock_file_exists: Mock) -> N @patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -def test_file_handler_read_json_invalid_json(mock_fs: Mock, mock_file_exists: Mock) -> None: +def test_file_handler_read_json_invalid_json( + mock_fs: Mock, mock_file_exists: Mock +) -> None: mock_file_exists.return_value = True mock_fs_open = mock_open(read_data="{") mock_fs.return_value.open = mock_fs_open @@ -100,11 +112,15 @@ def test_file_handler_read_yaml_file_does_not_exist(mock_file_exists: Mock) -> N @patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -def test_file_handler_read_yaml_invalid_yaml(mock_fs: Mock, mock_file_exists: Mock) -> None: +def test_file_handler_read_yaml_invalid_yaml( + mock_fs: Mock, mock_file_exists: Mock +) -> None: mock_file_exists.return_value = True - mock_fs_open = mock_open(read_data=""" + mock_fs_open = mock_open( + read_data=""" data: [ - """) + """ + ) mock_fs.return_value.open = mock_fs_open handler = FileHandler() From a3914ef59c4aeb87a4794eccc4a5696a2ad9c982 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 15 May 2025 19:13:11 +0200 Subject: [PATCH 05/10] Add writer capability --- .../experimental/components/schema.py | 14 +- src/neo4j_graphrag/utils/file_handler.py | 97 +++++++++++--- tests/unit/utils/test_file_handler.py | 120 +++++++++++++++++- 3 files changed, 208 insertions(+), 23 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 0cf10b381..eb5cfd123 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -15,7 +15,6 @@ from __future__ import annotations import json -import yaml import logging from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence from pathlib import Path @@ -165,8 +164,9 @@ def store_as_json(self, file_path: str) -> None: Args: file_path (str): The path where the schema configuration will be saved. """ - with open(file_path, "w") as f: - json.dump(self.model_dump(), f, indent=2) + data = self.model_dump(mode="json") + file_handler = FileHandler() + file_handler.write_json(data, file_path) def store_as_yaml(self, file_path: str) -> None: """ @@ -176,8 +176,8 @@ def store_as_yaml(self, file_path: str) -> None: file_path (str): The path where the schema configuration will be saved. """ data = self.model_dump(mode="json") - with open(file_path, "w") as f: - yaml.dump(data, f, default_flow_style=False, sort_keys=False) + file_handler = FileHandler() + file_handler.write_yaml(data, file_path) @classmethod def from_file(cls, file_path: Union[str, Path]) -> Self: @@ -193,9 +193,9 @@ def from_file(cls, file_path: Union[str, Path]) -> Self: GraphSchema: The loaded schema configuration. """ file_path = Path(file_path) - reader = FileHandler() + file_handler = FileHandler() try: - data = reader.read(file_path) + data = file_handler.read(file_path) except ValueError: raise diff --git a/src/neo4j_graphrag/utils/file_handler.py b/src/neo4j_graphrag/utils/file_handler.py index 2e4f3a9b5..9ffa9acd2 100644 --- a/src/neo4j_graphrag/utils/file_handler.py +++ b/src/neo4j_graphrag/utils/file_handler.py @@ -51,9 +51,23 @@ class FileHandler: """ + JSON_VALID_EXTENSIONS = (".json",) + YALM_VALID_EXTENSIONS = (".yaml", ".yml") + def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None: self.fs = fs or LocalFileSystem() + def _get_file_extension(self, path: Union[str, Path]) -> str: + p = Path(path) + extension = p.suffix.lower() + return extension + + def _check_file_exists(self, path: Union[str, Path]) -> Path: + file_path = Path(path) + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {path}") + return file_path + def read_json(self, file_path: Union[str, Path]) -> Any: logger.debug(f"FILE_HANDLER: read from json {file_path}") path = self._check_file_exists(file_path) @@ -72,23 +86,76 @@ def read_yaml(self, file_path: Union[str, Path]) -> Any: except yaml.YAMLError as e: raise ValueError("Invalid YAML file") from e - def _check_file_exists(self, path: Union[str, Path]) -> Path: - file_path = Path(path) - if not file_path.exists(): - raise FileNotFoundError(f"File not found: {path}") - return file_path - - def _guess_format_and_read(self, file_path: Path) -> Any: - extension = file_path.suffix.lower() + def _guess_format_and_read(self, file_path: Union[str, Path]) -> Any: + extension = self._get_file_extension(file_path) # Note: .suffix returns an empty string if Path has no extension - path_as_string = str(file_path) - if extension in [".json"]: - return self.read_json(path_as_string) - if extension in [".yaml", ".yml"]: - return self.read_yaml(path_as_string) + if extension in self.JSON_VALID_EXTENSIONS: + return self.read_json(file_path) + if extension in self.YALM_VALID_EXTENSIONS: + return self.read_yaml(file_path) raise ValueError(f"Unsupported extension: {extension}") def read(self, file_path: Union[Path, str]) -> Any: - path = Path(file_path) - data = self._guess_format_and_read(path) + data = self._guess_format_and_read(file_path) return data + + def _check_file_can_be_written( + self, path: Union[str, Path], overwrite: bool = False + ) -> None: + if overwrite: + # we can overwrite, so no matter if file already exists or not + return + try: + self._check_file_exists(path) + # did not raise, meaning the file exists + raise ValueError("File already exists. Use overwrite=True to overwrite.") + except FileNotFoundError: + # file not found all godo + pass + + def write_json( + self, + data: Any, + file_path: Union[Path, str], + overwrite: bool = False, + **extra_kwargs: Any, + ) -> None: + self._check_file_can_be_written(file_path, overwrite) + fp = str(file_path) + kwargs: dict[str, Any] = { + "indent": 2, + } + kwargs.update(extra_kwargs) + with self.fs.open(fp, "w") as f: + json.dump(data, f, **kwargs) + + def write_yaml( + self, + data: Any, + file_path: Union[Path, str], + overwrite: bool = False, + **extra_kwargs: Any, + ) -> None: + self._check_file_can_be_written(file_path, overwrite) + fp = str(file_path) + kwargs: dict[str, Any] = { + "default_flow_style": False, + "sort_keys": True, + } + kwargs.update(extra_kwargs) + with self.fs.open(fp, "w") as f: + yaml.safe_dump(data, f, **kwargs) + + def write( + self, + data: Any, + file_path: Union[Path, str], + overwrite: bool = False, + **extra_kwargs: Any, + ) -> None: + extension = self._get_file_extension(file_path) + if extension in self.JSON_VALID_EXTENSIONS: + return self.write_json(data, file_path, overwrite=overwrite, **extra_kwargs) + if extension in self.YALM_VALID_EXTENSIONS: + return self.write_yaml(data, file_path, overwrite=overwrite, **extra_kwargs) + raise ValueError(f"Unsupported extension: {extension}") diff --git a/tests/unit/utils/test_file_handler.py b/tests/unit/utils/test_file_handler.py index d9573683f..a66d23a34 100644 --- a/tests/unit/utils/test_file_handler.py +++ b/tests/unit/utils/test_file_handler.py @@ -107,7 +107,7 @@ def test_file_handler_read_yaml_file_does_not_exist(mock_file_exists: Mock) -> N handler = FileHandler() with pytest.raises(FileNotFoundError): - handler.read_json("file.yaml") + handler.read_yaml("file.yaml") @patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @@ -126,3 +126,121 @@ def test_file_handler_read_yaml_invalid_yaml( handler = FileHandler() with pytest.raises(ValueError, match="Invalid YAML"): handler.read_yaml("file.yaml") + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") +def test_file_handler_file_can_be_written_file(mock_file_exists: Mock) -> None: + # file does not exist + mock_file_exists.side_effect = FileNotFoundError() + handler = FileHandler() + # nothing happens = all good + handler._check_file_can_be_written(Path("file.json")) + # file exist, overwrite is False + mock_file_exists.side_effect = None + handler = FileHandler() + with pytest.raises(ValueError): + handler._check_file_can_be_written(Path("file.json"), overwrite=False) + # file exists, overwrite is True + mock_file_exists.side_effect = None + handler = FileHandler() + # nothing happens = all good + handler._check_file_can_be_written(Path("file.json"), overwrite=True) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +@patch("neo4j_graphrag.utils.file_handler.json") +def test_file_handler_write_json_happy_path( + mock_json_module: Mock, + mock_file_can_be_written: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + file_handler.write_json({"some": "data"}, "file.json") + mock_fs_open.assert_called_once_with("file.json", "w") + mock_json_module.dump.assert_called_with( + {"some": "data"}, mock_fs_open.return_value, indent=2 + ) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +@patch("neo4j_graphrag.utils.file_handler.json") +def test_file_handler_write_json_extra_kwargs_happy_path( + mock_json_module: Mock, + mock_file_can_be_written: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + file_handler.write_json({"some": "data"}, "file.json", indent=4, default=str) + mock_fs_open.assert_called_once_with("file.json", "w") + mock_json_module.dump.assert_called_with( + {"some": "data"}, mock_fs_open.return_value, indent=4, default=str + ) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +@patch("neo4j_graphrag.utils.file_handler.yaml") +def test_file_handler_write_yaml_happy_path( + mock_yaml_module: Mock, + mock_file_can_be_written: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + file_handler.write_yaml({"some": "data"}, "file.yaml") + mock_fs_open.assert_called_once_with("file.yaml", "w") + mock_yaml_module.safe_dump.assert_called_with( + {"some": "data"}, + mock_fs_open.return_value, + default_flow_style=False, + sort_keys=True, + ) + + +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +@patch("neo4j_graphrag.utils.file_handler.yaml") +def test_file_handler_write_yaml_extra_kwargs_happy_path( + mock_yaml_module: Mock, + mock_file_can_be_written: Mock, + mock_fs: Mock, +) -> None: + mock_fs_open = mock_open() + mock_fs.return_value.open = mock_fs_open + file_handler = FileHandler() + file_handler.write_yaml( + {"some": "data"}, "file.json", default_flow_style="toto", other_keyword=42 + ) + mock_fs_open.assert_called_once_with("file.json", "w") + mock_yaml_module.safe_dump.assert_called_with( + {"some": "data"}, + mock_fs_open.return_value, + default_flow_style="toto", + sort_keys=True, + other_keyword=42, + ) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler.write_json") +def test_file_handler_write_json_from_write_method_happy_path( + mock_write_json: Mock, +) -> None: + handler = FileHandler() + handler.write("data", "file.json") + mock_write_json.assert_called_with("data", "file.json", overwrite=False) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler.write_yaml") +def test_file_handler_write_yaml_from_write_method_happy_path( + mock_write_yaml: Mock, +) -> None: + handler = FileHandler() + handler.write("data", "file.yaml") + mock_write_yaml.assert_called_with("data", "file.yaml", overwrite=False) From 1fecf7dd427e8eb15bed741899b5e74efbaf1944 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 15 May 2025 19:24:15 +0200 Subject: [PATCH 06/10] Add docstrings to FileHandler --- .../experimental/components/schema.py | 11 ++- src/neo4j_graphrag/utils/file_handler.py | 82 +++++++++++++++++-- 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index eb5cfd123..64be4140c 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -157,27 +157,30 @@ def node_type_from_label(self, label: str) -> Optional[NodeType]: def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]: return self._relationship_type_index.get(label) - def store_as_json(self, file_path: str) -> None: + def store_as_json(self, file_path: str, overwrite: bool = False) -> None: """ Save the schema configuration to a JSON file. Args: file_path (str): The path where the schema configuration will be saved. + overwrite (bool): If set to True, existing file will be overwritten. Default to False. """ data = self.model_dump(mode="json") file_handler = FileHandler() - file_handler.write_json(data, file_path) + file_handler.write_json(data, file_path, overwrite=overwrite) - def store_as_yaml(self, file_path: str) -> None: + def store_as_yaml(self, file_path: str, overwrite: bool = False) -> None: """ Save the schema configuration to a YAML file. Args: file_path (str): The path where the schema configuration will be saved. + overwrite (bool): If set to True, existing file will be overwritten. Default to False. + """ data = self.model_dump(mode="json") file_handler = FileHandler() - file_handler.write_yaml(data, file_path) + file_handler.write_yaml(data, file_path, overwrite=overwrite) @classmethod def from_file(cls, file_path: Union[str, Path]) -> Self: diff --git a/src/neo4j_graphrag/utils/file_handler.py b/src/neo4j_graphrag/utils/file_handler.py index 9ffa9acd2..c5a56dac9 100644 --- a/src/neo4j_graphrag/utils/file_handler.py +++ b/src/neo4j_graphrag/utils/file_handler.py @@ -25,7 +25,7 @@ class FileHandler: - """Utility class to read JSON or YAML files. + """Utility class to read/write JSON or YAML files. File format is guessed from the extension. Supported extensions are (lower or upper case): @@ -37,17 +37,16 @@ class FileHandler: .. code-block:: python - from pathlib import Path from neo4j_graphrag.utils.file_handler import FileHandler handler = FileHandler() - handler.read(Path("my_file.json")) + handler.read("my_file.json") If reading a file with a different extension but still in JSON or YAML format, it is possible to call directly the `read_json` or `read_yaml` methods: .. code-block:: python - handler.read_yaml(Path("my_file.txt")) + handler.read_yaml("my_file.txt") """ @@ -69,6 +68,17 @@ def _check_file_exists(self, path: Union[str, Path]) -> Path: return file_path def read_json(self, file_path: Union[str, Path]) -> Any: + """Reads a JSON file. If file does not exist, raises FileNotFoundError. + + Args: + file_path (Union[str, Path]): The path of the JSON file. + + Raises: + FileNotFoundError: If file does not exist. + + Returns: + The parsed content of the JSON file. + """ logger.debug(f"FILE_HANDLER: read from json {file_path}") path = self._check_file_exists(file_path) with self.fs.open(str(path), "r") as f: @@ -78,6 +88,17 @@ def read_json(self, file_path: Union[str, Path]) -> Any: raise ValueError("Invalid JSON file") from e def read_yaml(self, file_path: Union[str, Path]) -> Any: + """Reads a YAML file. If file does not exist, raises FileNotFoundError. + + Args: + file_path (Union[str, Path]): The path of the YAML file. + + Raises: + FileNotFoundError: If file does not exist. + + Returns: + The parsed content of the YAML file. + """ logger.debug(f"FILE_HANDLER: read from yaml {file_path}") path = self._check_file_exists(file_path) with self.fs.open(str(path), "r") as f: @@ -86,7 +107,20 @@ def read_yaml(self, file_path: Union[str, Path]) -> Any: except yaml.YAMLError as e: raise ValueError("Invalid YAML file") from e - def _guess_format_and_read(self, file_path: Union[str, Path]) -> Any: + def read(self, file_path: Union[Path, str]) -> Any: + """Try to infer file type from its extension and returns its + parsed content. + + Args: + file_path (Union[str, Path]): The path of the JSON file.: + + Raises: + FileNotFoundError: If file does not exist. + ValueError: If file extension is invalid or the file can not be parsed + (e.g. invalid JSON or YAML) + + Returns: the parsed content of the file. + """ extension = self._get_file_extension(file_path) # Note: .suffix returns an empty string if Path has no extension if extension in self.JSON_VALID_EXTENSIONS: @@ -95,13 +129,21 @@ def _guess_format_and_read(self, file_path: Union[str, Path]) -> Any: return self.read_yaml(file_path) raise ValueError(f"Unsupported extension: {extension}") - def read(self, file_path: Union[Path, str]) -> Any: - data = self._guess_format_and_read(file_path) - return data - def _check_file_can_be_written( self, path: Union[str, Path], overwrite: bool = False ) -> None: + """Check whether the file can be written to path with the following conditions: + + - If overwrite is set to True, file can always be written, any existing file will be overwritten. + - If overwrite is set to False, and file already exists, file will not be overwritten. + + Args: + path (Union[str, Path]): The path of the target file. + overwrite (bool): If set to True, existing file will be overwritten. Default to False. + + Raises: + ValueError: If file can not be written according to the above rules. + """ if overwrite: # we can overwrite, so no matter if file already exists or not return @@ -120,6 +162,17 @@ def write_json( overwrite: bool = False, **extra_kwargs: Any, ) -> None: + """Writes data to a JSON file + + Args: + data (Any): The data to write. + file_path (Union[str, Path]): The path of the JSON file. + overwrite (bool): If set to True, existing file will be overwritten. Default to False. + extra_kwargs (Any): Additional arguments passed to json.dump (e.g.: indent...). Note: a default indent=4 is applied. + + Raises: + ValueError: If file can not be written according to the above rules. + """ self._check_file_can_be_written(file_path, overwrite) fp = str(file_path) kwargs: dict[str, Any] = { @@ -136,6 +189,16 @@ def write_yaml( overwrite: bool = False, **extra_kwargs: Any, ) -> None: + """Writes data to a YAML file + + Args: + data (Any): The data to write. + file_path (Union[str, Path]): The path of the YAML file. + overwrite (bool): If set to True, existing file will be overwritten. Default to False. + extra_kwargs (Any): Additional arguments passed to yaml.safe_dump. Note that we apply the following defaults: + - "default_flow_style": False + - "sort_keys": True + """ self._check_file_can_be_written(file_path, overwrite) fp = str(file_path) kwargs: dict[str, Any] = { @@ -153,6 +216,7 @@ def write( overwrite: bool = False, **extra_kwargs: Any, ) -> None: + """Guess file type and write it.""" extension = self._get_file_extension(file_path) if extension in self.JSON_VALID_EXTENSIONS: return self.write_json(data, file_path, overwrite=overwrite, **extra_kwargs) From b1f8301322f1f6a81876b5e285d404c46e0596e3 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 14:51:01 +0200 Subject: [PATCH 07/10] Introduce a 'format' parameter to reduce code duplication --- docs/source/user_guide_kg_builder.rst | 4 +- .../schema_builders/schema_from_text.py | 4 +- .../experimental/components/schema.py | 50 +++--- src/neo4j_graphrag/utils/file_handler.py | 114 +++++++------ .../experimental/components/test_schema.py | 21 +-- tests/unit/utils/test_file_handler.py | 153 ++++++++++++------ 6 files changed, 214 insertions(+), 132 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index d7455a6f8..6af5e8cb3 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -832,8 +832,8 @@ You can also save and reload the extracted schema: .. code:: python # Save the schema to JSON or YAML files - extracted_schema.store_as_json("my_schema.json") - extracted_schema.store_as_yaml("my_schema.yaml") + extracted_schema.save("my_schema.json") + extracted_schema.save("my_schema.yaml") # Later, reload the schema from file from neo4j_graphrag.experimental.components.schema import GraphSchema diff --git a/examples/customize/build_graph/components/schema_builders/schema_from_text.py b/examples/customize/build_graph/components/schema_builders/schema_from_text.py index a36ef90ec..947907dff 100644 --- a/examples/customize/build_graph/components/schema_builders/schema_from_text.py +++ b/examples/customize/build_graph/components/schema_builders/schema_from_text.py @@ -85,11 +85,11 @@ async def extract_and_save_schema() -> None: print(f"Saving schema to JSON file: {JSON_FILE_PATH}") # Save the schema to JSON file - inferred_schema.store_as_json(JSON_FILE_PATH) + inferred_schema.save(JSON_FILE_PATH) print(f"Saving schema to YAML file: {YAML_FILE_PATH}") # Save the schema to YAML file - inferred_schema.store_as_yaml(YAML_FILE_PATH) + inferred_schema.save(YAML_FILE_PATH) print("\nExtracted Schema Summary:") print(f"Node types: {list(inferred_schema.node_types)}") diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 64be4140c..b2007ccb4 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -16,6 +16,7 @@ import json import logging +import warnings from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence from pathlib import Path @@ -41,7 +42,7 @@ ) from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate from neo4j_graphrag.llm import LLMInterface -from neo4j_graphrag.utils.file_handler import FileHandler +from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat class PropertyType(BaseModel): @@ -157,40 +158,53 @@ def node_type_from_label(self, label: str) -> Optional[NodeType]: def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]: return self._relationship_type_index.get(label) - def store_as_json(self, file_path: str, overwrite: bool = False) -> None: + def save( + self, + file_path: Union[str, Path], + overwrite: bool = False, + format: Optional[FileFormat] = None, + ) -> None: """ - Save the schema configuration to a JSON file. + Save the schema configuration to file. Args: file_path (str): The path where the schema configuration will be saved. overwrite (bool): If set to True, existing file will be overwritten. Default to False. + format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from file_path extension. """ data = self.model_dump(mode="json") file_handler = FileHandler() - file_handler.write_json(data, file_path, overwrite=overwrite) - - def store_as_yaml(self, file_path: str, overwrite: bool = False) -> None: - """ - Save the schema configuration to a YAML file. + file_handler.write(data, file_path, overwrite=overwrite, format=format) - Args: - file_path (str): The path where the schema configuration will be saved. - overwrite (bool): If set to True, existing file will be overwritten. Default to False. + def store_as_json( + self, file_path: Union[str, Path], overwrite: bool = False + ) -> None: + warnings.warn( + "Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning + ) + return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON) - """ - data = self.model_dump(mode="json") - file_handler = FileHandler() - file_handler.write_yaml(data, file_path, overwrite=overwrite) + def store_as_yaml( + self, file_path: Union[str, Path], overwrite: bool = False + ) -> None: + warnings.warn( + "Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning + ) + return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML) @classmethod - def from_file(cls, file_path: Union[str, Path]) -> Self: + def from_file( + cls, file_path: Union[str, Path], format: Optional[FileFormat] = None + ) -> Self: """ Load a schema configuration from a file (either JSON or YAML). - The file format is automatically detected based on the file extension. + The file format is automatically detected based on the file extension, + unless the format parameter is set. Args: file_path (Union[str, Path]): The path to the schema configuration file. + format (Optional[FileFormat]): The format of the schema configuration file (json or yaml). Returns: GraphSchema: The loaded schema configuration. @@ -198,7 +212,7 @@ def from_file(cls, file_path: Union[str, Path]) -> Self: file_path = Path(file_path) file_handler = FileHandler() try: - data = file_handler.read(file_path) + data = file_handler.read(file_path, format=format) except ValueError: raise diff --git a/src/neo4j_graphrag/utils/file_handler.py b/src/neo4j_graphrag/utils/file_handler.py index c5a56dac9..4c21a0094 100644 --- a/src/neo4j_graphrag/utils/file_handler.py +++ b/src/neo4j_graphrag/utils/file_handler.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import json import logging from pathlib import Path @@ -24,6 +25,19 @@ logger = logging.getLogger(__name__) +class FileFormat(enum.Enum): + JSON = "json" + YAML = "yaml" + + @classmethod + def json_valid_extension(cls) -> list[str]: + return [".json"] + + @classmethod + def yaml_valid_extension(cls) -> list[str]: + return [".yaml", ".yml"] + + class FileHandler: """Utility class to read/write JSON or YAML files. @@ -50,28 +64,28 @@ class FileHandler: """ - JSON_VALID_EXTENSIONS = (".json",) - YALM_VALID_EXTENSIONS = (".yaml", ".yml") - def __init__(self, fs: Optional[fsspec.AbstractFileSystem] = None) -> None: self.fs = fs or LocalFileSystem() - def _get_file_extension(self, path: Union[str, Path]) -> str: - p = Path(path) - extension = p.suffix.lower() - return extension - - def _check_file_exists(self, path: Union[str, Path]) -> Path: - file_path = Path(path) - if not file_path.exists(): + def _guess_file_format(self, path: Path) -> Optional[FileFormat]: + # Note: .suffix returns an empty string if Path has no extension + extension = path.suffix.lower() + if extension in FileFormat.json_valid_extension(): + return FileFormat.JSON + if extension in FileFormat.yaml_valid_extension(): + return FileFormat.YAML + return None + + def _check_file_exists(self, path: Path) -> Path: + if not path.exists(): raise FileNotFoundError(f"File not found: {path}") - return file_path + return path - def read_json(self, file_path: Union[str, Path]) -> Any: + def _read_json(self, path: Path) -> Any: """Reads a JSON file. If file does not exist, raises FileNotFoundError. Args: - file_path (Union[str, Path]): The path of the JSON file. + path (Path): The path of the JSON file. Raises: FileNotFoundError: If file does not exist. @@ -79,19 +93,18 @@ def read_json(self, file_path: Union[str, Path]) -> Any: Returns: The parsed content of the JSON file. """ - logger.debug(f"FILE_HANDLER: read from json {file_path}") - path = self._check_file_exists(file_path) + logger.debug(f"FILE_HANDLER: read from json {path}") with self.fs.open(str(path), "r") as f: try: return json.load(f) except json.JSONDecodeError as e: raise ValueError("Invalid JSON file") from e - def read_yaml(self, file_path: Union[str, Path]) -> Any: + def _read_yaml(self, path: Path) -> Any: """Reads a YAML file. If file does not exist, raises FileNotFoundError. Args: - file_path (Union[str, Path]): The path of the YAML file. + path (Path): The path of the YAML file. Raises: FileNotFoundError: If file does not exist. @@ -99,20 +112,23 @@ def read_yaml(self, file_path: Union[str, Path]) -> Any: Returns: The parsed content of the YAML file. """ - logger.debug(f"FILE_HANDLER: read from yaml {file_path}") - path = self._check_file_exists(file_path) + logger.debug(f"FILE_HANDLER: read from yaml {path}") with self.fs.open(str(path), "r") as f: try: return yaml.safe_load(f) except yaml.YAMLError as e: raise ValueError("Invalid YAML file") from e - def read(self, file_path: Union[Path, str]) -> Any: + def read( + self, file_path: Union[Path, str], format: Optional[FileFormat] = None + ) -> Any: """Try to infer file type from its extension and returns its parsed content. Args: - file_path (Union[str, Path]): The path of the JSON file.: + file_path (Union[str, Path]): The path of the JSON file. + format (Optional[FileFormat]): The file format to infer the file type from. + If not set, the format is inferred from the extension. Raises: FileNotFoundError: If file does not exist. @@ -121,24 +137,24 @@ def read(self, file_path: Union[Path, str]) -> Any: Returns: the parsed content of the file. """ - extension = self._get_file_extension(file_path) - # Note: .suffix returns an empty string if Path has no extension - if extension in self.JSON_VALID_EXTENSIONS: - return self.read_json(file_path) - if extension in self.YALM_VALID_EXTENSIONS: - return self.read_yaml(file_path) - raise ValueError(f"Unsupported extension: {extension}") - - def _check_file_can_be_written( - self, path: Union[str, Path], overwrite: bool = False - ) -> None: + path = Path(file_path) + path = self._check_file_exists(path) + if not format: + format = self._guess_file_format(path) + if format == FileFormat.JSON: + return self._read_json(path) + if format == FileFormat.YAML: + return self._read_yaml(path) + raise ValueError(f"Unsupported file format: {format}") + + def _check_file_can_be_written(self, path: Path, overwrite: bool = False) -> None: """Check whether the file can be written to path with the following conditions: - If overwrite is set to True, file can always be written, any existing file will be overwritten. - If overwrite is set to False, and file already exists, file will not be overwritten. Args: - path (Union[str, Path]): The path of the target file. + path (Path): The path of the target file. overwrite (bool): If set to True, existing file will be overwritten. Default to False. Raises: @@ -155,10 +171,10 @@ def _check_file_can_be_written( # file not found all godo pass - def write_json( + def _write_json( self, data: Any, - file_path: Union[Path, str], + file_path: Path, overwrite: bool = False, **extra_kwargs: Any, ) -> None: @@ -166,14 +182,13 @@ def write_json( Args: data (Any): The data to write. - file_path (Union[str, Path]): The path of the JSON file. + file_path (Path): The path of the JSON file. overwrite (bool): If set to True, existing file will be overwritten. Default to False. extra_kwargs (Any): Additional arguments passed to json.dump (e.g.: indent...). Note: a default indent=4 is applied. Raises: ValueError: If file can not be written according to the above rules. """ - self._check_file_can_be_written(file_path, overwrite) fp = str(file_path) kwargs: dict[str, Any] = { "indent": 2, @@ -182,10 +197,10 @@ def write_json( with self.fs.open(fp, "w") as f: json.dump(data, f, **kwargs) - def write_yaml( + def _write_yaml( self, data: Any, - file_path: Union[Path, str], + file_path: Path, overwrite: bool = False, **extra_kwargs: Any, ) -> None: @@ -193,13 +208,12 @@ def write_yaml( Args: data (Any): The data to write. - file_path (Union[str, Path]): The path of the YAML file. + file_path (Path): The path of the YAML file. overwrite (bool): If set to True, existing file will be overwritten. Default to False. extra_kwargs (Any): Additional arguments passed to yaml.safe_dump. Note that we apply the following defaults: - "default_flow_style": False - "sort_keys": True """ - self._check_file_can_be_written(file_path, overwrite) fp = str(file_path) kwargs: dict[str, Any] = { "default_flow_style": False, @@ -214,12 +228,16 @@ def write( data: Any, file_path: Union[Path, str], overwrite: bool = False, + format: Optional[FileFormat] = None, **extra_kwargs: Any, ) -> None: """Guess file type and write it.""" - extension = self._get_file_extension(file_path) - if extension in self.JSON_VALID_EXTENSIONS: - return self.write_json(data, file_path, overwrite=overwrite, **extra_kwargs) - if extension in self.YALM_VALID_EXTENSIONS: - return self.write_yaml(data, file_path, overwrite=overwrite, **extra_kwargs) - raise ValueError(f"Unsupported extension: {extension}") + path = Path(file_path) + self._check_file_can_be_written(path, overwrite) + if not format: + format = self._guess_file_format(path) + if format == FileFormat.JSON: + return self._write_json(data, path, overwrite=overwrite, **extra_kwargs) + if format == FileFormat.YAML: + return self._write_yaml(data, path, overwrite=overwrite, **extra_kwargs) + raise ValueError(f"Unsupported file format: {format}") diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 15e0d674c..1aec51261 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -35,6 +35,7 @@ from neo4j_graphrag.generation import PromptTemplate from neo4j_graphrag.llm.types import LLMResponse +from neo4j_graphrag.utils.file_handler import FileFormat @pytest.fixture @@ -383,13 +384,13 @@ async def test_schema_from_text_llm_params( @pytest.mark.asyncio -async def test_schema_config_store_as_json(graph_schema: GraphSchema) -> None: +async def test_schema_config_save_json(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # create file path json_path = os.path.join(temp_dir, "schema.json") # store the schema config - graph_schema.store_as_json(json_path) + graph_schema.save(json_path) # verify the file exists and has content assert os.path.exists(json_path) @@ -403,13 +404,13 @@ async def test_schema_config_store_as_json(graph_schema: GraphSchema) -> None: @pytest.mark.asyncio -async def test_schema_config_store_as_yaml(graph_schema: GraphSchema) -> None: +async def test_schema_config_save_yaml(graph_schema: GraphSchema) -> None: with tempfile.TemporaryDirectory() as temp_dir: # Create file path yaml_path = os.path.join(temp_dir, "schema.yaml") # Store the schema config - graph_schema.store_as_yaml(yaml_path) + graph_schema.save(yaml_path) # Verify the file exists and has content assert os.path.exists(yaml_path) @@ -431,9 +432,9 @@ async def test_schema_config_from_file(graph_schema: GraphSchema) -> None: yml_path = os.path.join(temp_dir, "schema.yml") # store the schema config in the different formats - graph_schema.store_as_json(json_path) - graph_schema.store_as_yaml(yaml_path) - graph_schema.store_as_yaml(yml_path) + graph_schema.save(json_path) + graph_schema.save(yaml_path) + graph_schema.save(yml_path) # load using from_file which should detect the format based on extension json_schema = GraphSchema.from_file(json_path) @@ -452,9 +453,11 @@ async def test_schema_config_from_file(graph_schema: GraphSchema) -> None: # verify an unsupported extension raises the correct error txt_path = os.path.join(temp_dir, "schema.txt") - graph_schema.store_as_json(txt_path) # Store as JSON but with .txt extension + graph_schema.save( + txt_path, format=FileFormat.JSON + ) # Store as JSON but with .txt extension - with pytest.raises(ValueError, match="Unsupported extension: .txt"): + with pytest.raises(ValueError, match="Unsupported file format: None"): GraphSchema.from_file(txt_path) diff --git a/tests/unit/utils/test_file_handler.py b/tests/unit/utils/test_file_handler.py index a66d23a34..294624fb4 100644 --- a/tests/unit/utils/test_file_handler.py +++ b/tests/unit/utils/test_file_handler.py @@ -1,70 +1,101 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pathlib import Path from unittest.mock import patch, Mock, mock_open import pytest -from neo4j_graphrag.utils.file_handler import FileHandler +from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat -@patch("neo4j_graphrag.utils.file_handler.FileHandler.read_json") +def test_file_handler_guess_format() -> None: + handler = FileHandler() + + assert handler._guess_file_format(Path("file.json")) == FileFormat.JSON + assert handler._guess_file_format(Path("file.JSON")) == FileFormat.JSON + assert handler._guess_file_format(Path("file.yaml")) == FileFormat.YAML + assert handler._guess_file_format(Path("file.YAML")) == FileFormat.YAML + assert handler._guess_file_format(Path("file.yml")) == FileFormat.YAML + assert handler._guess_file_format(Path("file.YML")) == FileFormat.YAML + assert handler._guess_file_format(Path("file.txt")) is None + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._read_json") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") def test_file_handler_read_json_from_read_method_happy_path( + mock_file_exists: Mock, mock_read_json: Mock, ) -> None: handler = FileHandler() + mock_file_exists.return_value = Path("file.json") mock_read_json.return_value = {} data = handler.read("file.json") - mock_read_json.assert_called_with("file.json") + mock_read_json.assert_called_with(Path("file.json")) assert data == {} + mock_file_exists.return_value = Path("file.JSON") mock_read_json.return_value = {} data = handler.read("file.JSON") - mock_read_json.assert_called_with("file.JSON") + mock_read_json.assert_called_with(Path("file.JSON")) assert data == {} -@patch("neo4j_graphrag.utils.file_handler.FileHandler.read_yaml") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._read_yaml") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") def test_file_handler_read_yaml_from_read_method_happy_path( + mock_file_exists: Mock, mock_read_yaml: Mock, ) -> None: + mock_file_exists.return_value = Path("file.yaml") handler = FileHandler() mock_read_yaml.return_value = {} data = handler.read("file.yaml") - mock_read_yaml.assert_called_with("file.yaml") + mock_read_yaml.assert_called_with(Path("file.yaml")) assert data == {} + mock_file_exists.return_value = Path("file.yml") mock_read_yaml.return_value = {} data = handler.read("file.yml") - mock_read_yaml.assert_called_with("file.yml") + mock_read_yaml.assert_called_with(Path("file.yml")) assert data == {} + mock_file_exists.return_value = Path("file.YAML") mock_read_yaml.return_value = {} data = handler.read("file.YAML") - mock_read_yaml.assert_called_with("file.YAML") + mock_read_yaml.assert_called_with(Path("file.YAML")) assert data == {} -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") def test_file_handler_read_json_method_happy_path( - mock_fs: Mock, mock_file_exists: Mock + mock_fs: Mock, ) -> None: - mock_file_exists.return_value = Path("file.json") - mock_fs_open = mock_open(read_data="{}") + mock_fs_open = mock_open(read_data='{"data": 1}') mock_fs.return_value.open = mock_fs_open handler = FileHandler() - data = handler.read_json("file.json") + data = handler._read_json(Path("file.json")) mock_fs_open.assert_called_once_with("file.json", "r") - assert data == {} + assert data == {"data": 1} -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") def test_file_handler_read_yaml_method_happy_path( - mock_fs: Mock, mock_file_exists: Mock + mock_fs: Mock, ) -> None: - mock_file_exists.return_value = Path("file.yaml") mock_fs_open = mock_open( read_data=""" data: 1 @@ -73,49 +104,47 @@ def test_file_handler_read_yaml_method_happy_path( mock_fs.return_value.open = mock_fs_open handler = FileHandler() - data = handler.read_yaml("file.yaml") + data = handler._read_yaml(Path("file.yaml")) mock_fs_open.assert_called_once_with("file.yaml", "r") assert data == {"data": 1} -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") -def test_file_handler_read_json_file_does_not_exist(mock_file_exists: Mock) -> None: - mock_file_exists.side_effect = FileNotFoundError() - +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_json_file_does_not_exist( + mock_fs: Mock, +) -> None: + mock_fs.return_value.open.side_effect = FileNotFoundError handler = FileHandler() with pytest.raises(FileNotFoundError): - handler.read_json("file.json") + handler._read_json(Path("file.json")) -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") def test_file_handler_read_json_invalid_json( - mock_fs: Mock, mock_file_exists: Mock + mock_fs: Mock, ) -> None: - mock_file_exists.return_value = True mock_fs_open = mock_open(read_data="{") mock_fs.return_value.open = mock_fs_open handler = FileHandler() with pytest.raises(ValueError, match="Invalid JSON"): - handler.read_json("file.json") + handler._read_json(Path("file.json")) -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") -def test_file_handler_read_yaml_file_does_not_exist(mock_file_exists: Mock) -> None: - mock_file_exists.side_effect = FileNotFoundError() - +@patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") +def test_file_handler_read_yaml_file_does_not_exist( + mock_fs: Mock, +) -> None: + mock_fs.return_value.open.side_effect = FileNotFoundError handler = FileHandler() with pytest.raises(FileNotFoundError): - handler.read_yaml("file.yaml") + handler._read_yaml(Path("file.yaml")) -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") def test_file_handler_read_yaml_invalid_yaml( - mock_fs: Mock, mock_file_exists: Mock + mock_fs: Mock, ) -> None: - mock_file_exists.return_value = True mock_fs_open = mock_open( read_data=""" data: [ @@ -125,7 +154,7 @@ def test_file_handler_read_yaml_invalid_yaml( handler = FileHandler() with pytest.raises(ValueError, match="Invalid YAML"): - handler.read_yaml("file.yaml") + handler._read_yaml(Path("file.yaml")) @patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_exists") @@ -148,17 +177,15 @@ def test_file_handler_file_can_be_written_file(mock_file_exists: Mock) -> None: @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") @patch("neo4j_graphrag.utils.file_handler.json") def test_file_handler_write_json_happy_path( mock_json_module: Mock, - mock_file_can_be_written: Mock, mock_fs: Mock, ) -> None: mock_fs_open = mock_open() mock_fs.return_value.open = mock_fs_open file_handler = FileHandler() - file_handler.write_json({"some": "data"}, "file.json") + file_handler._write_json({"some": "data"}, Path("file.json")) mock_fs_open.assert_called_once_with("file.json", "w") mock_json_module.dump.assert_called_with( {"some": "data"}, mock_fs_open.return_value, indent=2 @@ -166,17 +193,15 @@ def test_file_handler_write_json_happy_path( @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") @patch("neo4j_graphrag.utils.file_handler.json") def test_file_handler_write_json_extra_kwargs_happy_path( mock_json_module: Mock, - mock_file_can_be_written: Mock, mock_fs: Mock, ) -> None: mock_fs_open = mock_open() mock_fs.return_value.open = mock_fs_open file_handler = FileHandler() - file_handler.write_json({"some": "data"}, "file.json", indent=4, default=str) + file_handler._write_json({"some": "data"}, Path("file.json"), indent=4, default=str) mock_fs_open.assert_called_once_with("file.json", "w") mock_json_module.dump.assert_called_with( {"some": "data"}, mock_fs_open.return_value, indent=4, default=str @@ -184,17 +209,15 @@ def test_file_handler_write_json_extra_kwargs_happy_path( @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") @patch("neo4j_graphrag.utils.file_handler.yaml") def test_file_handler_write_yaml_happy_path( mock_yaml_module: Mock, - mock_file_can_be_written: Mock, mock_fs: Mock, ) -> None: mock_fs_open = mock_open() mock_fs.return_value.open = mock_fs_open file_handler = FileHandler() - file_handler.write_yaml({"some": "data"}, "file.yaml") + file_handler._write_yaml({"some": "data"}, Path("file.yaml")) mock_fs_open.assert_called_once_with("file.yaml", "w") mock_yaml_module.safe_dump.assert_called_with( {"some": "data"}, @@ -205,18 +228,16 @@ def test_file_handler_write_yaml_happy_path( @patch("neo4j_graphrag.utils.file_handler.LocalFileSystem") -@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") @patch("neo4j_graphrag.utils.file_handler.yaml") def test_file_handler_write_yaml_extra_kwargs_happy_path( mock_yaml_module: Mock, - mock_file_can_be_written: Mock, mock_fs: Mock, ) -> None: mock_fs_open = mock_open() mock_fs.return_value.open = mock_fs_open file_handler = FileHandler() - file_handler.write_yaml( - {"some": "data"}, "file.json", default_flow_style="toto", other_keyword=42 + file_handler._write_yaml( + {"some": "data"}, Path("file.json"), default_flow_style="toto", other_keyword=42 ) mock_fs_open.assert_called_once_with("file.json", "w") mock_yaml_module.safe_dump.assert_called_with( @@ -228,19 +249,45 @@ def test_file_handler_write_yaml_extra_kwargs_happy_path( ) -@patch("neo4j_graphrag.utils.file_handler.FileHandler.write_json") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_json") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") def test_file_handler_write_json_from_write_method_happy_path( + mock_file_can_be_written: Mock, mock_write_json: Mock, ) -> None: handler = FileHandler() handler.write("data", "file.json") - mock_write_json.assert_called_with("data", "file.json", overwrite=False) + mock_write_json.assert_called_with("data", Path("file.json"), overwrite=False) -@patch("neo4j_graphrag.utils.file_handler.FileHandler.write_yaml") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_yaml") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") def test_file_handler_write_yaml_from_write_method_happy_path( + mock_file_can_be_written: Mock, mock_write_yaml: Mock, ) -> None: handler = FileHandler() handler.write("data", "file.yaml") - mock_write_yaml.assert_called_with("data", "file.yaml", overwrite=False) + mock_write_yaml.assert_called_with("data", Path("file.yaml"), overwrite=False) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_yaml") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +def test_file_handler_write_yaml_from_write_method_overwrite_format_happy_path( + mock_file_can_be_written: Mock, + mock_write_yaml: Mock, +) -> None: + handler = FileHandler() + handler.write("data", "file.txt", format=FileFormat.YAML) + mock_write_yaml.assert_called_with("data", Path("file.txt"), overwrite=False) + + +@patch("neo4j_graphrag.utils.file_handler.FileHandler._write_json") +@patch("neo4j_graphrag.utils.file_handler.FileHandler._check_file_can_be_written") +def test_file_handler_write_json_from_write_method_overwrite_format_happy_path( + mock_file_can_be_written: Mock, + mock_write_json: Mock, +) -> None: + handler = FileHandler() + handler.write("data", "file.txt", format=FileFormat.JSON) + mock_write_json.assert_called_with("data", Path("file.txt"), overwrite=False) From ffa4a39a97b55ceb5c79c2cbfba6725c87bfe79f Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 09:26:58 +0200 Subject: [PATCH 08/10] Remove unused 'overwrite' argument --- src/neo4j_graphrag/utils/file_handler.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/neo4j_graphrag/utils/file_handler.py b/src/neo4j_graphrag/utils/file_handler.py index 4c21a0094..7c5963f14 100644 --- a/src/neo4j_graphrag/utils/file_handler.py +++ b/src/neo4j_graphrag/utils/file_handler.py @@ -175,7 +175,6 @@ def _write_json( self, data: Any, file_path: Path, - overwrite: bool = False, **extra_kwargs: Any, ) -> None: """Writes data to a JSON file @@ -183,7 +182,6 @@ def _write_json( Args: data (Any): The data to write. file_path (Path): The path of the JSON file. - overwrite (bool): If set to True, existing file will be overwritten. Default to False. extra_kwargs (Any): Additional arguments passed to json.dump (e.g.: indent...). Note: a default indent=4 is applied. Raises: @@ -201,7 +199,6 @@ def _write_yaml( self, data: Any, file_path: Path, - overwrite: bool = False, **extra_kwargs: Any, ) -> None: """Writes data to a YAML file @@ -209,7 +206,6 @@ def _write_yaml( Args: data (Any): The data to write. file_path (Path): The path of the YAML file. - overwrite (bool): If set to True, existing file will be overwritten. Default to False. extra_kwargs (Any): Additional arguments passed to yaml.safe_dump. Note that we apply the following defaults: - "default_flow_style": False - "sort_keys": True From 91efe7e085f90d9ab1161f8e16668d4fec65a556 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 09:59:59 +0200 Subject: [PATCH 09/10] Fix tests --- src/neo4j_graphrag/utils/file_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/utils/file_handler.py b/src/neo4j_graphrag/utils/file_handler.py index 7c5963f14..b351fe00a 100644 --- a/src/neo4j_graphrag/utils/file_handler.py +++ b/src/neo4j_graphrag/utils/file_handler.py @@ -233,7 +233,7 @@ def write( if not format: format = self._guess_file_format(path) if format == FileFormat.JSON: - return self._write_json(data, path, overwrite=overwrite, **extra_kwargs) + return self._write_json(data, path, **extra_kwargs) if format == FileFormat.YAML: - return self._write_yaml(data, path, overwrite=overwrite, **extra_kwargs) + return self._write_yaml(data, path, **extra_kwargs) raise ValueError(f"Unsupported file format: {format}") From 5c1c8b63d319be9141255d2f1eed37a16c0bccbf Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 10:21:55 +0200 Subject: [PATCH 10/10] Fix tests --- tests/unit/utils/test_file_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/utils/test_file_handler.py b/tests/unit/utils/test_file_handler.py index 294624fb4..5d12c05cb 100644 --- a/tests/unit/utils/test_file_handler.py +++ b/tests/unit/utils/test_file_handler.py @@ -257,7 +257,7 @@ def test_file_handler_write_json_from_write_method_happy_path( ) -> None: handler = FileHandler() handler.write("data", "file.json") - mock_write_json.assert_called_with("data", Path("file.json"), overwrite=False) + mock_write_json.assert_called_with("data", Path("file.json")) @patch("neo4j_graphrag.utils.file_handler.FileHandler._write_yaml") @@ -268,7 +268,7 @@ def test_file_handler_write_yaml_from_write_method_happy_path( ) -> None: handler = FileHandler() handler.write("data", "file.yaml") - mock_write_yaml.assert_called_with("data", Path("file.yaml"), overwrite=False) + mock_write_yaml.assert_called_with("data", Path("file.yaml")) @patch("neo4j_graphrag.utils.file_handler.FileHandler._write_yaml") @@ -279,7 +279,7 @@ def test_file_handler_write_yaml_from_write_method_overwrite_format_happy_path( ) -> None: handler = FileHandler() handler.write("data", "file.txt", format=FileFormat.YAML) - mock_write_yaml.assert_called_with("data", Path("file.txt"), overwrite=False) + mock_write_yaml.assert_called_with("data", Path("file.txt")) @patch("neo4j_graphrag.utils.file_handler.FileHandler._write_json") @@ -290,4 +290,4 @@ def test_file_handler_write_json_from_write_method_overwrite_format_happy_path( ) -> None: handler = FileHandler() handler.write("data", "file.txt", format=FileFormat.JSON) - mock_write_json.assert_called_with("data", Path("file.txt"), overwrite=False) + mock_write_json.assert_called_with("data", Path("file.txt"))