Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a7cd9e2
feat: add functions for correlation matrix calculations and subsetting
strixy16 Dec 16, 2024
bceafa4
style: change feature name if statements to be multi-line
strixy16 Dec 16, 2024
dc60e19
build: updated pixi lock file
strixy16 Dec 16, 2024
c426006
feat: add readii logger for errors and exceptions
strixy16 Dec 16, 2024
b0bd53c
test: add testing for getFeatureCorrelations
strixy16 Dec 16, 2024
9713455
build: add seaborn for plotting
strixy16 Dec 16, 2024
28ac67b
build: add pandas dependency
strixy16 Dec 16, 2024
f690a46
feat: add function to plot a correlation matrix as a seaborn heatmap
strixy16 Dec 16, 2024
6370597
feat: add function for plotting correlation distribution as a histogram
strixy16 Dec 16, 2024
0f9f331
feat: add correlation plot functions to init
strixy16 Dec 16, 2024
689eebf
feat: made writer class for plot figures
strixy16 Dec 16, 2024
23c87e4
fix: removing what I think is a typo import
strixy16 Dec 16, 2024
de87bef
style: add space in import statement for plot correlation
strixy16 Dec 16, 2024
6706b9f
feat: add function return types
strixy16 Dec 16, 2024
382de73
feat: specify PlotWriter save object as matplotlib Figure
strixy16 Dec 16, 2024
b07e5a6
Merge remote-tracking branch 'origin/main' into katys/add-analysis
strixy16 Dec 16, 2024
916c42b
docs: correct docstring in getFeatureCorrelations for default feature…
strixy16 Dec 16, 2024
24d41dd
docs: make function docstring oneliners imperative form
strixy16 Dec 16, 2024
d917b7a
docs: make docstring oneliner imperative
strixy16 Dec 16, 2024
293f40f
feat: add possible file extensions for plot figure, remove logger.exc…
strixy16 Dec 16, 2024
7bd422a
refactor: replace print statement with logger
strixy16 Dec 16, 2024
f564992
docs: update triangle parameter description in plotCorrelationHeatmap
strixy16 Dec 16, 2024
9b1424a
feat: add helper function to check if the subsetting of a dataframe i…
strixy16 Dec 16, 2024
724ee50
Merge remote-tracking branch 'origin/main' into katys/add-analysis fo…
strixy16 Dec 16, 2024
9af1289
docs/refactor: add parameter descriptions for PlotWriter save, change…
strixy16 Dec 17, 2024
5406928
style: remove blank line after function docstring
strixy16 Dec 17, 2024
573f8a9
refactor: remove unused import
strixy16 Dec 17, 2024
d3f713e
refactor: remove unused error variables in loadFeatureFilesFromImageT…
strixy16 Dec 17, 2024
0012f6e
refactor: change fstring to regular string in error msgs
strixy16 Dec 17, 2024
16c989b
style: remove whitespace around docstrings
strixy16 Dec 17, 2024
657b72a
feat: add all io, data, analyze functions to ruff config
strixy16 Dec 17, 2024
6dee126
style: sorted imports
strixy16 Dec 17, 2024
cd12dd2
refactor: replaced matplotlib import to specifically import Figure
strixy16 Dec 17, 2024
ce56c78
style: sort imports, remove whitespace in docstring
strixy16 Dec 17, 2024
b62b735
style: sort imports
strixy16 Dec 17, 2024
64052c0
feat: add error handling
strixy16 Dec 17, 2024
4cf18a4
refactor: correct help message for overwrite
strixy16 Dec 17, 2024
227e02a
refactor: replace matplotlib import with Figure import
strixy16 Dec 17, 2024
d583a84
style: sort imports
strixy16 Dec 17, 2024
173b212
refactor: remove io readers and data directory for now, will add in s…
strixy16 Dec 17, 2024
5be1831
feat: add check for empty dataframes in getFeatureCorrelations
strixy16 Dec 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,674 changes: 701 additions & 973 deletions pixi.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ dependencies = [
"pydicom>=2.3.1",
"pyradiomics-bhklab>=3.1.4,<4",
"orcestra-downloader>=0.9.0,<1",
"numpy==1.26.4.*",
"seaborn>=0.13.2,<0.14",
"pandas>=2.2.3,<3"
]
requires-python = ">=3.10, <3.13"

Expand Down
14 changes: 14 additions & 0 deletions src/readii/analyze/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
""" Module to perform analysis on READII outputs """

from .correlation import getFeatureCorrelations, getVerticalSelfCorrelations, getHorizontalSelfCorrelations, getCrossCorrelationMatrix
from.plot_correlation import plotCorrelationHeatmap, plotCorrelationHistogram


__all__ = [
'getFeatureCorrelations',
'getVerticalSelfCorrelations',
'getHorizontalSelfCorrelations',
'getCrossCorrelationMatrix',
'plotCorrelationHeatmap',
'plotCorrelationHistogram'
]
167 changes: 167 additions & 0 deletions src/readii/analyze/correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import pandas as pd
from readii.utils import logger

def getFeatureCorrelations(vertical_features:pd.DataFrame,
horizontal_features:pd.DataFrame,
method:str = "pearson",
vertical_feature_name:str = '_vertical',
horizontal_feature_name:str = '_horizontal'):
""" Function to calculate correlation between two sets of features.

Parameters
----------
vertical_features : pd.DataFrame
Dataframe containing features to calculate correlations with. Index must be the same as the index of the horizontal_features dataframe.
horizontal_features : pd.DataFrame
Dataframe containing features to calculate correlations with. Index must be the same as the index of the vertical_features dataframe.
method : str
Method to use for calculating correlations. Default is "pearson".
vertical_feature_name : str
Name of the vertical features to use as suffix in correlation dataframe. Default is blank "".
horizontal_feature_name : str
Name of the horizontal features to use as suffix in correlation dataframe. Default is blank "".

Returns
-------
correlation_matrix : pd.DataFrame
Dataframe containing correlation values.
"""
# Check that features are dataframes
if not isinstance(vertical_features, pd.DataFrame):
msg = "vertical_features must be a pandas DataFrame"
logger.exception(msg)
raise TypeError()
if not isinstance(horizontal_features, pd.DataFrame):
msg = "horizontal_features must be a pandas DataFrame"
logger.exception(msg)
raise TypeError()

if method not in ["pearson", "spearman", "kendall"]:
msg = "Correlation method must be one of 'pearson', 'spearman', or 'kendall'."
logger.exception(msg)
raise ValueError()

if not vertical_features.index.equals(horizontal_features.index):
msg = "Vertical and horizontal features must have the same index to calculate correlation. Set the index to the intersection of patient IDs."
logger.exception(msg)
raise ValueError()

# Add _ to beginnging of feature names if they don't start with _ so they can be used as suffixes
if not vertical_feature_name.startswith("_"):
vertical_feature_name = f"_{vertical_feature_name}"
if not horizontal_feature_name.startswith("_"):
horizontal_feature_name = f"_{horizontal_feature_name}"

# Join the features into one dataframe
# Use inner join to keep only the rows that have a value in both vertical and horizontal features
features_to_correlate = vertical_features.join(horizontal_features,
how='inner',
lsuffix=vertical_feature_name,
rsuffix=horizontal_feature_name)

try:
# Calculate correlation between vertical features and horizontal features
correlation_matrix = features_to_correlate.corr(method=method)
except Exception as e:
msg = f"Error calculating correlation matrix: {e}"
logger.exception(msg)
raise e

return correlation_matrix



def getVerticalSelfCorrelations(correlation_matrix:pd.DataFrame,
num_vertical_features:int):
""" Function to get the vertical (y-axis) self correlations from a correlation matrix. Gets the top left quadrant of the correlation matrix.

Parameters
----------
correlation_matrix : pd.DataFrame
Dataframe containing the correlation matrix to get the vertical self correlations from.
num_vertical_features : int
Number of vertical features in the correlation matrix.

Returns
-------
pd.DataFrame
Dataframe containing the vertical self correlations from the correlation matrix.
"""
if num_vertical_features > correlation_matrix.shape[0]:
msg = f"Number of vertical features ({num_vertical_features}) is greater than the number of rows in the correlation matrix ({correlation_matrix.shape[0]})."
logger.exception(msg)
raise ValueError()

if num_vertical_features > correlation_matrix.shape[1]:
msg = f"Number of vertical features ({num_vertical_features}) is greater than the number of columns in the correlation matrix ({correlation_matrix.shape[1]})."
logger.exception(msg)
raise ValueError()

# Get the correlation matrix for vertical vs vertical - this is the top left corner of the matrix
return correlation_matrix.iloc[0:num_vertical_features, 0:num_vertical_features]



def getHorizontalSelfCorrelations(correlation_matrix:pd.DataFrame,
num_horizontal_features:int):
""" Function to get the horizontal (x-axis) self correlations from a correlation matrix. Gets the bottom right quadrant of the correlation matrix.

Parameters
----------
correlation_matrix : pd.DataFrame
Dataframe containing the correlation matrix to get the horizontal self correlations from.
num_horizontal_features : int
Number of horizontal features in the correlation matrix.

Returns
-------
pd.DataFrame
Dataframe containing the horizontal self correlations from the correlation matrix.
"""

if num_horizontal_features > correlation_matrix.shape[0]:
msg = f"Number of horizontal features ({num_horizontal_features}) is greater than the number of rows in the correlation matrix ({correlation_matrix.shape[0]})."
logger.exception(msg)
raise ValueError()

if num_horizontal_features > correlation_matrix.shape[1]:
msg = f"Number of horizontal features ({num_horizontal_features}) is greater than the number of columns in the correlation matrix ({correlation_matrix.shape[1]})."
logger.exception(msg)
raise ValueError()

# Get the index of the start of the horizontal correlations
start_of_horizontal_correlations = len(correlation_matrix.columns) - num_horizontal_features

# Get the correlation matrix for horizontal vs horizontal - this is the bottom right corner of the matrix
return correlation_matrix.iloc[start_of_horizontal_correlations:, start_of_horizontal_correlations:]



def getCrossCorrelationMatrix(correlation_matrix:pd.DataFrame,
num_vertical_features:int):
""" Function to get the cross correlation matrix subsection for a correlation matrix. Gets the top right quadrant of the correlation matrix so vertical and horizontal features are correctly labeled.

Parameters
----------
correlation_matrix : pd.DataFrame
Dataframe containing the correlation matrix to get the cross correlation matrix subsection from.
num_vertical_features : int
Number of vertical features in the correlation matrix.

Returns
-------
pd.DataFrame
Dataframe containing the cross correlations from the correlation matrix.
"""

if num_vertical_features > correlation_matrix.shape[0]:
msg = f"Number of vertical features ({num_vertical_features}) is greater than the number of rows in the correlation matrix ({correlation_matrix.shape[0]})."
logger.exception(msg)
raise ValueError()

if num_vertical_features > correlation_matrix.shape[1]:
msg = f"Number of vertical features ({num_vertical_features}) is greater than the number of columns in the correlation matrix ({correlation_matrix.shape[1]})."
logger.exception(msg)
raise ValueError()

return correlation_matrix.iloc[0:num_vertical_features, num_vertical_features:]
170 changes: 170 additions & 0 deletions src/readii/analyze/plot_correlation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import Optional
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from scipy.linalg import issymmetric

from readii.utils import logger

def plotCorrelationHeatmap(correlation_matrix_df:pd.DataFrame,
diagonal:bool = False,
triangle:Optional[str] = "lower",
cmap:str = "nipy_spectral",
xlabel:str = "",
ylabel:Optional[str] = "",
title:Optional[str] = "",
subtitle:Optional[str] = "",
show_tick_labels:bool = False
):
"""Function to plot a correlation heatmap.

Parameters
----------
correlation_matrix_df : pd.DataFrame
Dataframe containing the correlation matrix to plot.
diagonal : bool, optional
Whether to only plot half of the matrix. The default is False.
triangle : str, optional
Which triangle half of the matrixto plot. The default is "lower".
xlabel : str, optional
Label for the x-axis. The default is "".
ylabel : str, optional
Label for the y-axis. The default is "".
title : str, optional
Title for the plot. The default is "".
subtitle : str, optional
Subtitle for the plot. The default is "".
show_tick_labels : bool, optional
Whether to show the tick labels on the x and y axes. These would be the feature names. The default is False.

Returns
-------
corr_fig : matplotlib.pyplot.figure
Figure object containing a Seaborn heatmap.
"""

if diagonal:
logger.debug(f"Creating {triangle} traingle mask for diagonal correlation plot.")
# Set up mask for hiding half the matrix in the plot
if triangle == "lower":
# Mask out the upper right triangle half of the matrix
mask = np.triu(correlation_matrix_df)
elif triangle == "upper":
# Mask out the lower left triangle half of the matrix
mask = np.tril(correlation_matrix_df)
else:
msg = f"If diagonal is True, triangle must be either 'lower' or 'upper'. Got {triangle}."
logger.exception(msg)
raise ValueError()
else:
logger.debug("Creating full square correlation matrix plot.")
# The entire correlation matrix will be visisble in the plot
mask = None

# Set a default title if one is not provided
if not title:
title = "Correlation Heatmap"

# Set up figure and axes for the plot
corr_fig, corr_ax = plt.subplots()

# Plot the correlation matrix
corr_ax = sns.heatmap(correlation_matrix_df,
mask = mask,
cmap=cmap,
vmin=-1.0,
vmax=1.0)

if not show_tick_labels:
# Remove the individual feature names from the axes
corr_ax.set_xticklabels(labels=[])
corr_ax.set_yticklabels(labels=[])

# Set axis labels
corr_ax.set_xlabel(xlabel)
corr_ax.set_ylabel(ylabel)

# Set title and subtitle
# Suptitle is the super title, which will be above the title
plt.title(subtitle, fontsize=12)
plt.suptitle(title, fontsize=14)

return corr_fig



def plotCorrelationHistogram(correlation_matrix:pd.DataFrame,
num_bins:int = 100,
xlabel:Optional[str] = "Correlations",
ylabel:Optional[str] = "Frequency",
y_lower_bound:int = 0,
y_upper_bound:Optional[int] = None,
title:Optional[str] = "Distribution of Correlations for Features",
subtitle:Optional[str] = "",
):
""" Function to plot a distribution of correlation values for a correlation matrix.

Parameters
----------
correlation_matrix : pd.DataFrame
Dataframe containing the correlation matrix to plot.
num_bins : int, optional
Number of bins to use for the distribution plot. The default is 100.
xlabel : str, optional
Label for the x-axis. The default is "Correlations".
ylabel : str, optional
Label for the y-axis. The default is "Frequency".
y_lower_bound : int, optional
Lower bound for the y-axis of the distribution plot. The default is 0.
y_upper_bound : int, optional
Upper bound for the y-axis of the distribution plot. The default is None.
title : str, optional
Title for the plot. The default is "Distribution of Correlations for Features".
subtitle : str, optional
Subtitle for the plot. The default is "".

Returns
-------
dist_fig : plt.Figure
Figure object containing the histogram of correlation values.
bin_values : np.ndarray or list of arrays
Numpy array containing the values in each bin for the histogram.
bin_edges : np.ndarray
Numpy array containing the bin edges for the histogram.
"""

# Convert to numpy to use histogram function
feature_correlation_arr = correlation_matrix.to_numpy()

# Check if matrix is symmetric
if issymmetric(feature_correlation_arr):
print("Correlation matrix is symmetric.")
# Get only the bottom left triangle of the correlation matrix since the matrix is symmetric
lower_half_idx = np.mask_indices(feature_correlation_arr.shape[0], np.tril)
# This is a 1D array for binning and plotting
correlation_vals = feature_correlation_arr[lower_half_idx]
else:
# Flatten the matrix to a 1D array for binning and plotting
correlation_vals = feature_correlation_arr.flatten()

# Set up figure and axes for the plot
dist_fig, dist_ax = plt.subplots()

# Plot the histogram of correlation values
bin_values, bin_edges, _ = dist_ax.hist(correlation_vals, bins=num_bins)

# Set up axis labels
dist_ax.set_xlabel(xlabel)
dist_ax.set_ylabel(ylabel)

# Set axis bounds
dist_ax.set_xbound(-1.0, 1.0)
dist_ax.set_ybound(y_lower_bound, y_upper_bound)

# Set title and subtitle
# Suptitle is the super title, which will be above the title
plt.suptitle(title, fontsize=14)
plt.title(subtitle, fontsize=10)

return dist_fig, bin_values, bin_edges
Loading