Skip to content

Feature: Add plot_silhouette_analysis for visual evaluation of clustering #65

@idanmoradarthas

Description

@idanmoradarthas

✅ Feature Request: Add plot_silhouette_analysis to unsupervised module for visualizing clustering quality

📌 Summary

Add a new method plot_silhouette_analysis to the unsupervised module of the DataScienceUtils package. This utility enables comprehensive silhouette analysis of clustering results, allowing users to visually assess cluster cohesion and separation using silhouette scores.


🎯 Why This Feature?

Clustering is inherently unsupervised — and without true labels, users often struggle to determine whether clusters make sense. Silhouette analysis provides a well-established visual metric that shows:

  • How close each sample is to its own cluster vs others (cohesion vs separation)
  • Whether the current number of clusters is appropriate
  • Which clusters are well-formed and which are problematic

This function will accelerate exploratory clustering analysis and guide better hyperparameter tuning.


🧩 Target Module

Place the new method under:

datascienceutils/unsupervised/plot_silhouette_analysis

πŸ› οΈ Full Implementation

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_samples, silhouette_score
from typing import Optional, Union, Tuple
import warnings

def plot_silhouette_analysis(
    X: Union[np.ndarray, 'pd.DataFrame'],
    cluster_labels: np.ndarray,
    metric: str = 'euclidean',
    figsize: Tuple[int, int] = (12, 8),
    colors: Optional[list] = None,
    show_avg_line: bool = True,
    show_cluster_centers: bool = False,  # accepted for API consistency, but ignored in this issue
    cluster_centers: Optional[np.ndarray] = None,  # ignored
    title: Optional[str] = None,
    return_scores: bool = False
) -> Union[plt.Figure, Tuple[plt.Figure, dict]]:
    """
    Create a silhouette analysis plot for clustering results.

    Each cluster is drawn as a horizontal band of its samples' silhouette
    coefficients, sorted in ascending order, with a label showing the
    cluster id and size. A qualitative interpretation of the average score
    is printed in the top-left corner of the axes.

    Parameters
    ----------
    X : array-like or DataFrame of shape (n_samples, n_features)
        Input data used for clustering. Converted with ``np.asarray``.
    cluster_labels : array-like of shape (n_samples,)
        Cluster labels for each sample.
    metric : str, default 'euclidean'
        Distance metric forwarded to ``silhouette_samples``.
    figsize : tuple, default (12, 8)
        Figure size.
    colors : list, optional
        One color per cluster, in sorted-label order. If None, or if fewer
        colors than clusters are given (a warning is emitted in that case),
        the ``Set3`` colormap is used instead.
    show_avg_line : bool, default True
        Show a vertical dashed line (with legend) at the average score.
    show_cluster_centers : bool
        Ignored in this version. Future version will support it via
        `_plot_cluster_scatter`.
    cluster_centers : array-like, optional
        Ignored in this version.
    title : str, optional
        Custom title for the plot. Defaults to a title that reports the
        number of clusters and the average score.
    return_scores : bool, default False
        If True, also return dict of silhouette scores and stats.

    Returns
    -------
    matplotlib.figure.Figure or tuple
        The figure only, or ``(figure, scores_dict)`` if
        ``return_scores=True``. ``scores_dict`` has the keys
        ``'avg_score'``, ``'sample_scores'``, ``'cluster_stats'``
        (per-label ``size``, ``avg_score``, ``min_score``, ``max_score``
        and ``y_range``), ``'n_clusters'`` and ``'interpretation'``.

    Raises
    ------
    ValueError
        If ``X`` and ``cluster_labels`` differ in number of samples, or if
        fewer than two distinct labels are present. Errors raised by
        ``silhouette_samples`` itself are not caught and propagate.
    """
    # Vertical gap (in sample rows) left between consecutive cluster bands.
    band_gap = 10

    X = np.asarray(X)
    cluster_labels = np.asarray(cluster_labels)

    if X.shape[0] != len(cluster_labels):
        raise ValueError("X and cluster_labels must have the same number of samples")

    unique_labels = np.unique(cluster_labels)
    n_clusters = len(unique_labels)
    if n_clusters < 2:
        raise ValueError("Need at least 2 clusters for silhouette analysis")

    sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)
    # The overall silhouette score is the mean of the per-sample values, so
    # derive it here rather than paying for a second pairwise-distance pass.
    avg_silhouette_score = float(np.mean(sample_silhouette_values))

    if colors is None:
        colors = plt.cm.Set3(np.linspace(0, 1, n_clusters))
    elif len(colors) < n_clusters:
        warnings.warn(f"Not enough colors provided ({len(colors)}), need at least {n_clusters}")
        colors = plt.cm.Set3(np.linspace(0, 1, n_clusters))

    fig, ax1 = plt.subplots(1, 1, figsize=figsize)

    y_lower = band_gap
    cluster_stats = {}

    for i, cluster_label in enumerate(unique_labels):
        cluster_silhouette_values = np.sort(
            sample_silhouette_values[cluster_labels == cluster_label]
        )
        cluster_size = len(cluster_silhouette_values)
        y_upper = y_lower + cluster_size

        cluster_stats[cluster_label] = {
            'size': cluster_size,
            'avg_score': np.mean(cluster_silhouette_values),
            # Values are sorted ascending, but min/max keep the intent explicit.
            'min_score': np.min(cluster_silhouette_values),
            'max_score': np.max(cluster_silhouette_values),
            'y_range': (y_lower, y_upper)
        }

        color = colors[i]

        ax1.fill_betweenx(np.arange(y_lower, y_upper),
                          0, cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.7)

        ax1.text(-0.05, y_lower + 0.5 * cluster_size,
                 f'Cluster {cluster_label}\n(n={cluster_size})',
                 va='center',
                 bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))

        y_lower = y_upper + band_gap

    ax1.set_xlabel('Silhouette Coefficient Values', fontsize=12)
    ax1.set_ylabel('Cluster Index', fontsize=12)

    if show_avg_line:
        ax1.axvline(x=avg_silhouette_score, color="red", linestyle="--",
                    linewidth=2, label=f'Avg Score: {avg_silhouette_score:.3f}')
        ax1.legend(loc='upper right')

    if title is None:
        title = f'Silhouette Analysis for {n_clusters} Clusters\nAvg Score: {avg_silhouette_score:.3f}'
    ax1.set_title(title, fontsize=14)

    interpretation = _get_silhouette_interpretation(avg_silhouette_score)
    ax1.text(0.02, 0.98, interpretation, transform=ax1.transAxes,
             va='top', fontsize=10,
             bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.7))

    # Keep the usual [-0.1, 1] window, but widen it to the left when some
    # samples score below -0.1 so that badly assigned points are not clipped.
    lowest_value = float(np.min(sample_silhouette_values))
    x_lower = -0.1 if lowest_value >= -0.1 else lowest_value - 0.05
    ax1.set_xlim([x_lower, 1])
    ax1.set_ylim([0, len(cluster_labels) + (n_clusters + 1) * band_gap])
    ax1.set_yticks([])
    ax1.grid(True, alpha=0.3, axis='x')
    # Lay out this figure explicitly instead of relying on pyplot's global
    # "current figure" state.
    fig.tight_layout()

    if return_scores:
        return fig, {
            'avg_score': avg_silhouette_score,
            'sample_scores': sample_silhouette_values,
            'cluster_stats': cluster_stats,
            'n_clusters': n_clusters,
            'interpretation': interpretation
        }

    return fig


def _get_silhouette_interpretation(avg_score: float) -> str:
    """Interpret silhouette average score with qualitative rating."""
    if avg_score >= 0.7:
        return "Excellent clustering\n(Score β‰₯ 0.7)"
    elif avg_score >= 0.5:
        return "Good clustering\n(0.5 ≀ Score < 0.7)"
    elif avg_score >= 0.25:
        return "Fair clustering\n(0.25 ≀ Score < 0.5)"
    else:
        return "Poor clustering\n(Score < 0.25)"

✅ To Do

  • Add plot_silhouette_analysis to unsupervised module
  • Include unit tests for:
    • Less than 2 clusters
    • Mismatched input dimensions
    • return_scores=True
  • Update documentation (README.md, docstrings)

Metadata

Metadata

Labels

enhancement: New feature or request

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions