-
Notifications
You must be signed in to change notification settings - Fork 5
Open
Labels
enhancement — New feature or request
Description
β
Feature Request: Add `plot_silhouette_analysis` to the `unsupervised` module for visualizing clustering quality
Summary
Add a new function `plot_silhouette_analysis` to the `unsupervised` module of the DataScienceUtils package. This utility enables comprehensive silhouette analysis of clustering results, allowing users to visually assess cluster cohesion and separation using silhouette scores.
Why This Feature?
Clustering is inherently unsupervised — and without true labels, users often struggle to determine whether clusters make sense. Silhouette analysis provides a well-established visual metric that shows:
- How close each sample is to its own cluster vs others (cohesion vs separation)
- Whether the current number of clusters is appropriate
- Which clusters are well-formed and which are problematic
This function will speed up exploratory clustering analysis and guide hyperparameter tuning (for example, choosing the number of clusters k).
Target Module
Place the new function under:
`datascienceutils/unsupervised/plot_silhouette_analysis`
Full Implementation
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_samples, silhouette_score
from typing import Optional, Union, Tuple
import warnings
def plot_silhouette_analysis(
    X: Union[np.ndarray, 'pd.DataFrame'],
    cluster_labels: np.ndarray,
    metric: str = 'euclidean',
    figsize: Tuple[int, int] = (12, 8),
    colors: Optional[list] = None,
    show_avg_line: bool = True,
    show_cluster_centers: bool = False,  # accepted for API consistency, but ignored in this issue
    cluster_centers: Optional[np.ndarray] = None,  # ignored
    title: Optional[str] = None,
    return_scores: bool = False
) -> Union[plt.Figure, Tuple[plt.Figure, dict]]:
    """
    Create a silhouette analysis plot for clustering results.

    Each cluster is drawn as a horizontal band of its samples' silhouette
    coefficients (sorted ascending), with a 10-unit vertical gap between
    bands. An optional dashed red line marks the average score, and a text
    box in the top-left corner gives a qualitative rating of that average.

    Parameters
    ----------
    X : array-like or DataFrame of shape (n_samples, n_features)
        Input data used for clustering (converted with ``np.asarray``).
    cluster_labels : array-like of shape (n_samples,)
        Cluster labels for each sample.
    metric : str, default 'euclidean'
        Distance metric passed to ``silhouette_samples``.
    figsize : tuple, default (12, 8)
        Figure size.
    colors : list, optional
        List of colors, one per cluster (in sorted label order). If None, or
        if fewer colors than clusters are given (a warning is emitted in that
        case), an evenly spaced sample of the ``Set3`` colormap is used.
    show_avg_line : bool, default True
        Show a vertical line (and legend entry) for the average silhouette score.
    show_cluster_centers : bool
        Ignored in this version. Future version will support it via `_plot_cluster_scatter`.
    cluster_centers : array-like, optional
        Ignored in this version.
    title : str, optional
        Custom title for the plot. Defaults to the cluster count and average score.
    return_scores : bool, default False
        If True, also return a dict of silhouette scores and stats.

    Returns
    -------
    matplotlib.figure.Figure or tuple
        The figure, or ``(figure, scores_dict)`` if ``return_scores=True``.
        ``scores_dict`` has keys ``'avg_score'``, ``'sample_scores'``,
        ``'cluster_stats'`` (per label: ``size``, ``avg_score``,
        ``min_score``, ``max_score`` and the ``y_range`` of its band),
        ``'n_clusters'`` and ``'interpretation'``.

    Raises
    ------
    ValueError
        If ``X`` and ``cluster_labels`` have different numbers of samples, or
        if fewer than 2 distinct labels are present. scikit-learn's
        ``silhouette_samples`` may raise further ``ValueError``s for inputs it
        rejects (e.g. an invalid number of labels relative to the samples).
    """
    X = np.asarray(X)
    cluster_labels = np.asarray(cluster_labels)
    if X.shape[0] != len(cluster_labels):
        raise ValueError("X and cluster_labels must have the same number of samples")

    # Compute the distinct labels once; reused for validation and plotting.
    unique_labels = np.unique(cluster_labels)
    n_clusters = len(unique_labels)
    if n_clusters < 2:
        raise ValueError("Need at least 2 clusters for silhouette analysis")

    sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)
    # Without subsampling, silhouette_score is exactly the mean of the
    # per-sample coefficients; reusing them avoids a second full
    # pairwise-distance computation.
    avg_silhouette_score = float(np.mean(sample_silhouette_values))

    if colors is None:
        colors = plt.cm.Set3(np.linspace(0, 1, n_clusters))
    elif len(colors) < n_clusters:
        warnings.warn(
            f"Not enough colors provided ({len(colors)}), need at least {n_clusters}",
            stacklevel=2,
        )
        colors = plt.cm.Set3(np.linspace(0, 1, n_clusters))

    fig, ax1 = plt.subplots(1, 1, figsize=figsize)
    gap = 10  # vertical padding below the first band and between bands
    y_lower = gap
    cluster_stats = {}
    for i, cluster_label in enumerate(unique_labels):
        cluster_silhouette_values = np.sort(
            sample_silhouette_values[cluster_labels == cluster_label]
        )
        cluster_size = len(cluster_silhouette_values)
        y_upper = y_lower + cluster_size
        cluster_stats[cluster_label] = {
            'size': cluster_size,
            'avg_score': np.mean(cluster_silhouette_values),
            'min_score': np.min(cluster_silhouette_values),
            'max_score': np.max(cluster_silhouette_values),
            'y_range': (y_lower, y_upper)
        }
        color = colors[i]
        ax1.fill_betweenx(np.arange(y_lower, y_upper),
                          0, cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.7)
        ax1.text(-0.05, y_lower + 0.5 * cluster_size,
                 f'Cluster {cluster_label}\n(n={cluster_size})',
                 va='center',
                 bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))
        y_lower = y_upper + gap

    ax1.set_xlabel('Silhouette Coefficient Values', fontsize=12)
    ax1.set_ylabel('Cluster Index', fontsize=12)
    if show_avg_line:
        ax1.axvline(x=avg_silhouette_score, color="red", linestyle="--",
                    linewidth=2, label=f'Avg Score: {avg_silhouette_score:.3f}')
        ax1.legend(loc='upper right')

    if title is None:
        title = f'Silhouette Analysis for {n_clusters} Clusters\nAvg Score: {avg_silhouette_score:.3f}'
    ax1.set_title(title, fontsize=14)

    interpretation = _get_silhouette_interpretation(avg_silhouette_score)
    ax1.text(0.02, 0.98, interpretation, transform=ax1.transAxes,
             va='top', fontsize=10,
             bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.7))

    # Silhouette coefficients lie in [-1, 1]. Extend the left edge past the
    # most negative sample so misassigned points are not clipped; keep the
    # original -0.1 margin when all values are near or above zero.
    x_min = min(-0.1, float(np.min(sample_silhouette_values)) - 0.05)
    ax1.set_xlim([x_min, 1])
    ax1.set_ylim([0, len(cluster_labels) + (n_clusters + 1) * gap])
    ax1.set_yticks([])
    ax1.grid(True, alpha=0.3, axis='x')
    # Lay out the figure we created, not whatever pyplot considers "current".
    fig.tight_layout()

    if return_scores:
        return fig, {
            'avg_score': avg_silhouette_score,
            'sample_scores': sample_silhouette_values,
            'cluster_stats': cluster_stats,
            'n_clusters': n_clusters,
            'interpretation': interpretation
        }
    return fig
def _get_silhouette_interpretation(avg_score: float) -> str:
"""Interpret silhouette average score with qualitative rating."""
if avg_score >= 0.7:
return "Excellent clustering\n(Score β₯ 0.7)"
elif avg_score >= 0.5:
return "Good clustering\n(0.5 β€ Score < 0.7)"
elif avg_score >= 0.25:
return "Fair clustering\n(0.25 β€ Score < 0.5)"
else:
return "Poor clustering\n(Score < 0.25)"
To Do
- Add `plot_silhouette_analysis` to the `unsupervised` module
- Include unit tests for:
  - Fewer than 2 clusters
  - Mismatched input dimensions
  - `return_scores=True`
- Update documentation (`README.md`, docstrings)
Metadata
Assignees
Labels
enhancement — New feature or request