Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
510 changes: 374 additions & 136 deletions pixi.lock

Large diffs are not rendered by default.

131 changes: 131 additions & 0 deletions src/readii/io/writers/correlation_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar

from pandas import DataFrame

from readii.io.writers.base_writer import BaseWriter
from readii.utils import logger


class CorrelationWriterError(Exception):
"""Base exception for CorrelationWriter errors."""

pass


class CorrelationWriterIOError(CorrelationWriterError):
"""Raised when I/O operations fail."""

pass


class CorrelationWriterValidationError(CorrelationWriterError):
"""Raised when validation of writer configuration fails."""

pass

@dataclass
class CorrelationWriter(BaseWriter):
"""Class for managing file writing with customizable paths and filenames for plot figure files."""

overwrite: bool = field(
default=False,
metadata={
"help": "If True, allows overwriting existing files. If False, raises CorrelationWriterIOError."
},
)

# Make extensions immutable
VALID_EXTENSIONS: ClassVar[list[str]] = (
".csv",
".xlsx"
)
Comment on lines +40 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix type annotation for VALID_EXTENSIONS

The type hint list[str] doesn't match the tuple literal. Either change the type hint to tuple[str, ...] or convert the tuple to a list.

-    VALID_EXTENSIONS: ClassVar[list[str]] = (
+    VALID_EXTENSIONS: ClassVar[tuple[str, ...]] = (
         ".csv",
         ".xlsx"
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
VALID_EXTENSIONS: ClassVar[list[str]] = (
".csv",
".xlsx"
)
VALID_EXTENSIONS: ClassVar[tuple[str, ...]] = (
".csv",
".xlsx"
)


def __post_init__(self) -> None:
"""Validate writer configuration."""
super().__post_init__()

if not any(self.filename_format.endswith(ext) for ext in self.VALID_EXTENSIONS):
msg = f"Invalid filename format {self.filename_format}. Must end with one of {self.VALID_EXTENSIONS}."
raise CorrelationWriterValidationError(msg)

def save(self, correlation_df:DataFrame, **kwargs: str) -> Path:
"""Save the correlation dataframe to a .csv file.

Parameters
----------
correlation_df : DataFrame
The correlation dataframe to save.
**kwargs : str
Additional keyword arguments to pass to the filename format.

Returns
-------
Path
The path to the saved file.

Raises
------
CorrelationWriterIOError
If an error occurs during file writing.
CorrelationWriterValidationError
If the filename format is invalid.
"""
logger.debug("Saving.", kwargs=kwargs)

# Generate the output path
out_path = self.resolve_path(**kwargs)

# Check if the output path already exists
if out_path.exists():
if not self.overwrite:
msg = f"File {out_path} already exists. \nSet {self.__class__.__name__}.overwrite to True to overwrite."
raise CorrelationWriterIOError(msg)
else:
logger.warning(f"File {out_path} already exists. Overwriting.")

# Check if the correlation dataframe is a DataFrame
if not isinstance(correlation_df, DataFrame):
msg = f"Correlation dataframe must be a pandas DataFrame, got {type(correlation_df)}"
raise CorrelationWriterValidationError(msg)

# Check if the correlation dataframe is empty
if correlation_df.empty:
msg = "Correlation dataframe is empty"
raise CorrelationWriterValidationError(msg)

# Check that the columns and index of the correlation dataframe are the same
if not correlation_df.columns.equals(correlation_df.index):
msg = "Correlation dataframe columns and index are not the same"
raise CorrelationWriterValidationError(msg)

logger.debug("Saving correlation dataframe to file", out_path=out_path)
try:
match out_path.suffix:
case ".csv":
correlation_df.to_csv(out_path, index=True, index_label="")
case ".xlsx":
correlation_df.to_excel(out_path, index=True, index_label="")
case _:
msg = f"Invalid file extension {out_path.suffix}. Must be one of {self.VALID_EXTENSIONS}."
raise CorrelationWriterValidationError(msg)
except Exception as e:
msg = f"Error saving correlation dataframe to file {out_path}: {e}"
raise CorrelationWriterIOError(msg) from e
else:
logger.info("Correlation dataframe saved successfully.", out_path=out_path)
return out_path


if __name__ == "__main__": # pragma: no cover
from rich import print # noqa

plot_writer = CorrelationWriter(
root_directory=Path("TRASH", "correlation_writer_examples"),
filename_format="{DatasetName}_{VerticalFeatureType}_{HorizontalFeatureType}_{CorrelationType}_correlations.csv",
overwrite=True,
create_dirs=True
)

print(plot_writer)
File renamed without changes.
59 changes: 59 additions & 0 deletions tests/io/writers/test_correlation_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

import pandas as pd
import numpy as np
from pathlib import Path
from readii.analyze.correlation import getFeatureCorrelations
from readii.io.writers.correlation_writer import CorrelationWriter, CorrelationWriterValidationError, CorrelationWriterError, CorrelationWriterIOError # type: ignore

@pytest.fixture
def random_feature_correlations():
# Create a 10x10 matrix with random float values between 0 and 1
random_matrix = np.random.default_rng(seed=10).random((10,10))
# Convert to dataframe and name the columns feature1, feature2, etc.
random_df = pd.DataFrame(random_matrix, columns=[f"feature_{i+1}" for i in range(10)])
# Calculate correlation
return getFeatureCorrelations(random_df, random_df)


@pytest.fixture
def corr_writer(tmp_path):
"""Fixture for creating a CorrelationWriter instance."""
return CorrelationWriter(
root_directory=tmp_path,
filename_format="{CorrelationType}_correlation_matrix.csv",
overwrite=False,
create_dirs=True,
)

@pytest.mark.parametrize("correlation_df", ["not_a_correlation_df", 12345, pd.DataFrame()])
def test_save_invalid_correlation(corr_writer, correlation_df):
"""Test saving an invalid image."""
with pytest.raises(CorrelationWriterValidationError):
corr_writer.save(correlation_df, CorrelationType="Pearson")

@pytest.mark.parametrize("correlation_df", ["random_feature_correlations"])
def test_save_valid_correlation(corr_writer, request, correlation_df):
"""Test saving a valid correlation dataframe."""
correlation_df = request.getfixturevalue(correlation_df)
out_path = corr_writer.save(correlation_df, CorrelationType="Pearson")
assert out_path.exists()

def test_save_existing_file_without_overwrite(corr_writer, random_feature_correlations):
"""Test saving when file already exists and overwrite is False."""
corr_writer.save(random_feature_correlations, CorrelationType="Pearson")
with pytest.raises(CorrelationWriterIOError):
corr_writer.save(random_feature_correlations, CorrelationType="Pearson")

def test_save_existing_file_with_overwrite(corr_writer, random_feature_correlations):
"""Test saving when file already exists and overwrite is True."""
corr_writer.overwrite = True
corr_writer.save(random_feature_correlations, CorrelationType="Pearson")
assert corr_writer.save(random_feature_correlations, CorrelationType="Pearson").exists()

@pytest.mark.parametrize("filename_format", ["{CorrelationType}_correlation_matrix.csv", "{CorrelationType}_correlation_matrix.xlsx"])
def test_save_with_different_filename_formats(corr_writer, random_feature_correlations, filename_format):
"""Test saving with different filename formats."""
corr_writer.filename_format = filename_format
out_path = corr_writer.save(random_feature_correlations, CorrelationType="Pearson")
assert out_path.exists()
File renamed without changes.
Loading