Skip to content

Commit f510fa4

Browse files
committed
Code clean up
1 parent fadd04a commit f510fa4

File tree

9 files changed

+57
-132
lines changed

9 files changed

+57
-132
lines changed

.github/workflows/test.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ jobs:
2323
fail-fast: false
2424
matrix:
2525
include:
26-
- os: ubuntu-latest
27-
python: "3.9"
2826
- os: ubuntu-latest
2927
python: "3.10"
3028
- os: ubuntu-latest

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.10
1+
3.12

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Please refer to the [documentation][link-docs].
2020

2121
## Installation
2222

23-
You need to have Python 3.9 or newer installed on your system. If you don't have
23+
You need to have Python 3.10 or newer installed on your system. If you don't have
2424
Python installed, we recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge).
2525

2626
There are several alternative options to install pyclustree:

docs/source/example.ipynb

Lines changed: 5 additions & 42 deletions
Large diffs are not rendered by default.

docs/source/installation.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
:::{card}
44
:class-card: sd-bg-warning
55
:class-body: sd-bg-text-warning
6-
**pyclustree** only supports Python versions greater than or equal to **3.9**.
6+
**pyclustree** only supports Python versions greater than or equal to **3.10**.
77
:::
88

99
## Installation Options
@@ -35,7 +35,7 @@ Install `pyclustree` from source:
3535
# Clone repo
3636
git clone --depth 1 https://github.com/complextissue/pyclustree.git
3737
cd pyclustree
38-
pyenv global 3.9
38+
pyenv global 3.10
3939
make create-venv
4040
source .venv/bin/activate
4141
make install

docs/source/start.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Please refer to the [documentation][link-docs].
1919

2020
## Installation
2121

22-
You need to have Python 3.9 or newer installed on your system. If you don't have
22+
You need to have Python 3.10 or newer installed on your system. If you don't have
2323
Python installed, we recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge).
2424

2525
There are several alternative options to install pyclustree:

pyclustree/_clustree.py

Lines changed: 39 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1+
from collections.abc import Callable
12
from logging import warning
2-
from typing import Callable, Literal, Optional, Union
33

44
import networkx as nx
55
import numpy as np
66
from anndata import AnnData
77
from matplotlib import pyplot as plt
88
from matplotlib.colors import Colormap
9-
from numpy.typing import ArrayLike, NDArray
9+
from numpy.typing import NDArray
1010

11-
from ._utils import calculate_clustering_score, calculate_transition_matrix, order_unique_clusters
11+
from ._utils import order_unique_clusters, transition_matrix
1212

1313

1414
def clustree(
1515
adata: AnnData,
1616
cluster_keys: list[str],
17-
title: Optional[str] = None,
18-
scatter_reference: Optional[str] = None,
19-
node_colormap: Union[list[Colormap], Colormap, str] = "tab20",
20-
node_color_gene: Optional[str] = None,
17+
title: str | None = None,
18+
scatter_reference: str | None = None,
19+
node_colormap: list[Colormap] | Colormap | str = "tab20",
20+
node_color_gene: str | None = None,
2121
node_color_gene_use_raw: bool = True,
22-
node_color_gene_transformer: Optional[Callable] = None,
22+
node_color_gene_transformer: Callable | None = None,
2323
node_size_range: tuple[float, float] = (100, 1000),
2424
edge_width_range: tuple[float, float] = (0.5, 5.0),
2525
edge_weight_threshold: float = 0.0,
@@ -29,14 +29,7 @@ def clustree(
2929
show_colorbar: bool = False,
3030
show_fraction: bool = False,
3131
show_cluster_keys: bool = True,
32-
score_clustering: Optional[
33-
Union[
34-
Literal["silhouette", "davies_bouldin", "calinski_harabasz"],
35-
Callable[[ArrayLike, ArrayLike], float],
36-
]
37-
] = None,
38-
score_basis: Literal["X", "raw", "pca"] = "pca",
39-
graph_plot_kwargs: Optional[dict] = None,
32+
graph_plot_kwargs: dict | None = None,
4033
) -> plt.Figure:
4134
"""Create a hierarchical clustering tree visualization to compare different clustering resolutions.
4235
@@ -125,49 +118,58 @@ def clustree(
125118
node_color_gene in adata.obs.columns or node_color_gene in adata.var_names
126119
), "The provided gene should be present in the adata.var_names/adata.raw.var_names or adata.obs."
127120

128-
if scatter_reference is not None and score_clustering is not None:
129-
raise ValueError("Cluster scoring is not supported for scatter plotting.")
130-
131121
if isinstance(node_colormap, str):
132122
node_colormap = plt.get_cmap(node_colormap)
133123

134124
df_cluster_assignments = adata.obs[cluster_keys]
135125

136126
transition_matrices = [
137-
calculate_transition_matrix(
127+
transition_matrix(
138128
df_cluster_assignments[cluster_keys[i]],
139129
df_cluster_assignments[cluster_keys[i + 1]],
140130
)
141131
for i in range(len(cluster_keys) - 1)
142132
]
143133

134+
def cluster_sort_key(cluster):
    """Return a sort key that orders cluster labels numerically where possible.

    Labels that can be converted with ``int()`` sort first, in numeric order
    (so "2" comes before "10"); all other labels sort after them, ordered by
    their string representation.

    The key is a type-tagged tuple rather than a bare ``int`` or ``str``:
    returning ``int`` for some labels and ``str`` for others makes ``sorted()``
    raise ``TypeError`` as soon as a single clustering level mixes numeric and
    non-numeric labels (e.g. "0", "1", "unassigned").

    Args:
        cluster: A single cluster label (typically a str or int).

    Returns:
        A ``(tag, number, text)`` tuple that is comparable with the key of any
        other label.
    """
    try:
        # Tag 0: numeric labels, ordered by their integer value.
        return (0, int(cluster), "")
    except (ValueError, TypeError, OverflowError):
        # Tag 1: everything else, ordered lexicographically after the numbers.
        # OverflowError covers int(float("inf")).
        return (1, 0, str(cluster))
139+
144140
unique_clusters = [np.unique(df_cluster_assignments[key]).tolist() for key in cluster_keys]
145-
unique_clusters_sorted = [sorted(unique_clusters_level) for unique_clusters_level in unique_clusters]
141+
unique_clusters = [sorted(unique_clusters_level, key=cluster_sort_key) for unique_clusters_level in unique_clusters]
146142

147143
if order_clusters:
148144
unique_clusters = order_unique_clusters(unique_clusters, transition_matrices)
149145

150146
# Create the Graph
151147
G: nx.Graph = nx.DiGraph()
152148

149+
layers: dict[str, list[str]] = {}
150+
153151
# Add the nodes and store cluster info directly in the graph
154152
for i, key in enumerate(cluster_keys):
155-
for cluster in reversed(unique_clusters[i]):
153+
for cluster in unique_clusters[i]:
156154
node_name = f"{key}_{cluster}"
157155
cluster_cells = df_cluster_assignments[key] == cluster
158156
node_size = np.sum(cluster_cells)
159157

160158
# Store info directly in the node
161159
G.add_node(
162160
node_name,
163-
layer=len(cluster_keys) - i,
164161
level=i,
165162
key=key,
166163
cluster=cluster,
167164
size=node_size,
168165
cells=cluster_cells,
169166
)
170167

168+
layer_key = str(len(cluster_keys) - i)
169+
if layer_key not in layers:
170+
layers[layer_key] = []
171+
layers[layer_key].append(node_name)
172+
171173
# Compute node sizes scaled to the desired range
172174
max_size = max(G.nodes[node]["size"] for node in G.nodes)
173175
for node in G.nodes:
@@ -177,13 +179,13 @@ def clustree(
177179
)
178180

179181
# Add edges between each level and the next level
180-
for i, transition_matrix in enumerate(transition_matrices):
181-
for parent_cluster in transition_matrix.index:
182+
for i, transition in enumerate(transition_matrices):
183+
for parent_cluster in transition.index:
182184
parent_node = f"{cluster_keys[i]}_{parent_cluster}"
183185

184-
for child_cluster in transition_matrix.columns:
186+
for child_cluster in transition.columns:
185187
child_node = f"{cluster_keys[i + 1]}_{child_cluster}"
186-
weight = transition_matrix.loc[parent_cluster, child_cluster]
188+
weight = transition.loc[parent_cluster, child_cluster]
187189

188190
if weight > edge_weight_threshold:
189191
G.add_edge(parent_node, child_node, weight=weight)
@@ -206,7 +208,7 @@ def clustree(
206208
plt.cm.get_cmap(node_colormap[level]) if isinstance(node_colormap[level], str) else node_colormap[level]
207209
)
208210
norm_level = plt.Normalize(vmin=0, vmax=len(unique_clusters[level]) - 1)
209-
color_idx = unique_clusters_sorted[level].index(cluster)
211+
color_idx = unique_clusters[level].index(cluster)
210212
G.nodes[node]["color"] = cmap_level(norm_level(color_idx))
211213

212214
elif node_color_gene is not None:
@@ -260,7 +262,12 @@ def clustree(
260262
cells = G.nodes[node]["cells"]
261263
node_positions[node] = (np.median(x_positions[cells]), np.median(y_positions[cells]))
262264
else:
263-
node_positions = nx.multipartite_layout(G, align="horizontal", subset_key="layer", scale=1.0)
265+
node_positions = nx.multipartite_layout(
266+
G,
267+
align="horizontal",
268+
subset_key=layers, # type: ignore
269+
scale=1.0,
270+
)
264271

265272
# Prepare edge widths scaled to desired range
266273
edge_weights = [G.edges[edge]["weight"] for edge in G.edges]
@@ -358,8 +365,8 @@ def clustree(
358365
warning("Colorbars are not supported when providing a list of colormaps. Ignoring the argument.")
359366

360367
# Calculate positions for cluster keys
361-
need_level_positions = score_clustering is not None or (show_cluster_keys and scatter_reference is not None)
362-
y_positions_levels: Union[NDArray[np.float64], list[float]]
368+
need_level_positions = show_cluster_keys and scatter_reference is not None
369+
y_positions_levels: NDArray[np.float64] | list[float]
363370
if need_level_positions:
364371
y_positions_levels = np.linspace(
365372
ax.get_ylim()[1],
@@ -372,7 +379,7 @@ def clustree(
372379
x_min = ax.get_xlim()[0] if scatter_reference is None else ax.get_xlim()[1] + 2
373380

374381
# Determine y positions and colors
375-
facecolor: Union[list[str], list[tuple[float, float, float, float]]]
382+
facecolor: list[str] | list[tuple[float, float, float, float]]
376383
if scatter_reference is None:
377384
# Use node positions for y-coordinates in hierarchical layout
378385
level_nodes: dict[int, list[str]] = {}
@@ -416,46 +423,4 @@ def clustree(
416423
if title is not None:
417424
ax.set_title(title, fontsize=16, fontweight="bold", pad=10)
418425

419-
# Add clustering scores
420-
if score_clustering is not None:
421-
# Calculate scores
422-
scores = [calculate_clustering_score(adata, key, score_clustering, score_basis) for key in cluster_keys]
423-
424-
# Map score names for display
425-
score_name_map = {
426-
"silhouette": "Silhouette score",
427-
"calinski_harabasz": "Calinski and Harabasz score",
428-
"davies_bouldin": "Davies-Bouldin score",
429-
}
430-
431-
# Add score title
432-
x_max = ax.get_xlim()[1] + 0.5
433-
score_title = score_name_map[score_clustering] if isinstance(score_clustering, str) else "Clustering score"
434-
435-
ax.text(
436-
x_max,
437-
y=ax.get_ylim()[1],
438-
s=score_title,
439-
fontsize=12,
440-
color="black",
441-
ha="center",
442-
va="center",
443-
)
444-
445-
# Display scores
446-
for i, score in enumerate(scores):
447-
ax.text(
448-
x_max,
449-
y=y_positions_levels[i],
450-
s=f"{score:.2f}",
451-
fontsize=12,
452-
ha="center",
453-
va="center",
454-
bbox={
455-
"boxstyle": "round",
456-
"facecolor": facecolor[i],
457-
"edgecolor": "black",
458-
},
459-
)
460-
461426
return fig

pyclustree/_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Callable, Literal, Optional, Union
1+
from collections.abc import Callable
2+
from typing import Literal
23

34
import pandas as pd
45
from anndata import AnnData
56
from numpy.typing import ArrayLike, NDArray
67

78

8-
def calculate_transition_matrix(
9+
def transition_matrix(
910
cluster_a: pd.Series,
1011
cluster_b: pd.Series,
1112
) -> pd.DataFrame:
@@ -104,10 +105,8 @@ def order_unique_clusters(
104105
def calculate_clustering_score(
105106
adata: AnnData,
106107
cluster_key: str,
107-
score_method: Union[
108-
Literal["silhouette", "davies_bouldin", "calinski_harabasz"],
109-
Callable[[ArrayLike, ArrayLike], float],
110-
],
108+
score_method: Literal["silhouette", "davies_bouldin", "calinski_harabasz"]
109+
| Callable[[ArrayLike, ArrayLike], float],
111110
score_basis: Literal["X", "raw", "pca"] = "pca",
112111
) -> float:
113112
"""Calculate clustering score using specified method and data basis.
@@ -127,7 +126,7 @@ def calculate_clustering_score(
127126
ValueError: If an invalid score method or basis is provided.
128127
"""
129128
# Assign basis for scoring
130-
basis: Optional[NDArray] = None
129+
basis: NDArray | None = None
131130

132131
if score_basis == "X":
133132
basis = adata.X.copy()

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ maintainers = [
1414
{ name = "Malte Kuehl", email = "malte.kuehl@clin.au.dk" },
1515
]
1616
authors = [ { name = "Malte Hellmig" }, { name = "Malte Kuehl" } ]
17-
requires-python = ">=3.9"
17+
requires-python = ">=3.10"
1818
classifiers = [
1919
"Development Status :: 3 - Alpha",
2020
"Intended Audience :: Healthcare Industry",
@@ -23,7 +23,6 @@ classifiers = [
2323
"Natural Language :: English",
2424
"Operating System :: OS Independent",
2525
"Programming Language :: Python :: 3 :: Only",
26-
"Programming Language :: Python :: 3.9",
2726
"Programming Language :: Python :: 3.10",
2827
"Programming Language :: Python :: 3.11",
2928
"Programming Language :: Python :: 3.12",
@@ -51,6 +50,7 @@ optional-dependencies.dev = [
5150
"myst-parser",
5251
"nbsphinx",
5352
"pandas",
53+
"pandas-stubs",
5454
"pandoc",
5555
"pre-commit",
5656
"pytest",

0 commit comments

Comments
 (0)