1
+ from collections .abc import Callable
1
2
from logging import warning
2
- from typing import Callable , Literal , Optional , Union
3
3
4
4
import networkx as nx
5
5
import numpy as np
6
6
from anndata import AnnData
7
7
from matplotlib import pyplot as plt
8
8
from matplotlib .colors import Colormap
9
- from numpy .typing import ArrayLike , NDArray
9
+ from numpy .typing import NDArray
10
10
11
- from ._utils import calculate_clustering_score , calculate_transition_matrix , order_unique_clusters
11
+ from ._utils import order_unique_clusters , transition_matrix
12
12
13
13
14
14
def clustree (
15
15
adata : AnnData ,
16
16
cluster_keys : list [str ],
17
- title : Optional [ str ] = None ,
18
- scatter_reference : Optional [ str ] = None ,
19
- node_colormap : Union [ list [Colormap ], Colormap , str ] = "tab20" ,
20
- node_color_gene : Optional [ str ] = None ,
17
+ title : str | None = None ,
18
+ scatter_reference : str | None = None ,
19
+ node_colormap : list [Colormap ] | Colormap | str = "tab20" ,
20
+ node_color_gene : str | None = None ,
21
21
node_color_gene_use_raw : bool = True ,
22
- node_color_gene_transformer : Optional [ Callable ] = None ,
22
+ node_color_gene_transformer : Callable | None = None ,
23
23
node_size_range : tuple [float , float ] = (100 , 1000 ),
24
24
edge_width_range : tuple [float , float ] = (0.5 , 5.0 ),
25
25
edge_weight_threshold : float = 0.0 ,
@@ -29,14 +29,7 @@ def clustree(
29
29
show_colorbar : bool = False ,
30
30
show_fraction : bool = False ,
31
31
show_cluster_keys : bool = True ,
32
- score_clustering : Optional [
33
- Union [
34
- Literal ["silhouette" , "davies_bouldin" , "calinski_harabasz" ],
35
- Callable [[ArrayLike , ArrayLike ], float ],
36
- ]
37
- ] = None ,
38
- score_basis : Literal ["X" , "raw" , "pca" ] = "pca" ,
39
- graph_plot_kwargs : Optional [dict ] = None ,
32
+ graph_plot_kwargs : dict | None = None ,
40
33
) -> plt .Figure :
41
34
"""Create a hierarchical clustering tree visualization to compare different clustering resolutions.
42
35
@@ -125,49 +118,58 @@ def clustree(
125
118
node_color_gene in adata .obs .columns or node_color_gene in adata .var_names
126
119
), "The provided gene should be present in the adata.var_names/adata.raw.var_names or adata.obs."
127
120
128
- if scatter_reference is not None and score_clustering is not None :
129
- raise ValueError ("Cluster scoring is not supported for scatter plotting." )
130
-
131
121
if isinstance (node_colormap , str ):
132
122
node_colormap = plt .get_cmap (node_colormap )
133
123
134
124
df_cluster_assignments = adata .obs [cluster_keys ]
135
125
136
126
transition_matrices = [
137
- calculate_transition_matrix (
127
+ transition_matrix (
138
128
df_cluster_assignments [cluster_keys [i ]],
139
129
df_cluster_assignments [cluster_keys [i + 1 ]],
140
130
)
141
131
for i in range (len (cluster_keys ) - 1 )
142
132
]
143
133
134
+ def cluster_sort_key (cluster ):
135
+ try :
136
+ return int (cluster )
137
+ except (ValueError , TypeError ):
138
+ return str (cluster )
139
+
144
140
unique_clusters = [np .unique (df_cluster_assignments [key ]).tolist () for key in cluster_keys ]
145
- unique_clusters_sorted = [sorted (unique_clusters_level ) for unique_clusters_level in unique_clusters ]
141
+ unique_clusters = [sorted (unique_clusters_level , key = cluster_sort_key ) for unique_clusters_level in unique_clusters ]
146
142
147
143
if order_clusters :
148
144
unique_clusters = order_unique_clusters (unique_clusters , transition_matrices )
149
145
150
146
# Create the Graph
151
147
G : nx .Graph = nx .DiGraph ()
152
148
149
+ layers : dict [str , list [str ]] = {}
150
+
153
151
# Add the nodes and store cluster info directly in the graph
154
152
for i , key in enumerate (cluster_keys ):
155
- for cluster in reversed ( unique_clusters [i ]) :
153
+ for cluster in unique_clusters [i ]:
156
154
node_name = f"{ key } _{ cluster } "
157
155
cluster_cells = df_cluster_assignments [key ] == cluster
158
156
node_size = np .sum (cluster_cells )
159
157
160
158
# Store info directly in the node
161
159
G .add_node (
162
160
node_name ,
163
- layer = len (cluster_keys ) - i ,
164
161
level = i ,
165
162
key = key ,
166
163
cluster = cluster ,
167
164
size = node_size ,
168
165
cells = cluster_cells ,
169
166
)
170
167
168
+ layer_key = str (len (cluster_keys ) - i )
169
+ if layer_key not in layers :
170
+ layers [layer_key ] = []
171
+ layers [layer_key ].append (node_name )
172
+
171
173
# Compute node sizes scaled to the desired range
172
174
max_size = max (G .nodes [node ]["size" ] for node in G .nodes )
173
175
for node in G .nodes :
@@ -177,13 +179,13 @@ def clustree(
177
179
)
178
180
179
181
# Add edges between each level and the next level
180
- for i , transition_matrix in enumerate (transition_matrices ):
181
- for parent_cluster in transition_matrix .index :
182
+ for i , transition in enumerate (transition_matrices ):
183
+ for parent_cluster in transition .index :
182
184
parent_node = f"{ cluster_keys [i ]} _{ parent_cluster } "
183
185
184
- for child_cluster in transition_matrix .columns :
186
+ for child_cluster in transition .columns :
185
187
child_node = f"{ cluster_keys [i + 1 ]} _{ child_cluster } "
186
- weight = transition_matrix .loc [parent_cluster , child_cluster ]
188
+ weight = transition .loc [parent_cluster , child_cluster ]
187
189
188
190
if weight > edge_weight_threshold :
189
191
G .add_edge (parent_node , child_node , weight = weight )
@@ -206,7 +208,7 @@ def clustree(
206
208
plt .cm .get_cmap (node_colormap [level ]) if isinstance (node_colormap [level ], str ) else node_colormap [level ]
207
209
)
208
210
norm_level = plt .Normalize (vmin = 0 , vmax = len (unique_clusters [level ]) - 1 )
209
- color_idx = unique_clusters_sorted [level ].index (cluster )
211
+ color_idx = unique_clusters [level ].index (cluster )
210
212
G .nodes [node ]["color" ] = cmap_level (norm_level (color_idx ))
211
213
212
214
elif node_color_gene is not None :
@@ -260,7 +262,12 @@ def clustree(
260
262
cells = G .nodes [node ]["cells" ]
261
263
node_positions [node ] = (np .median (x_positions [cells ]), np .median (y_positions [cells ]))
262
264
else :
263
- node_positions = nx .multipartite_layout (G , align = "horizontal" , subset_key = "layer" , scale = 1.0 )
265
+ node_positions = nx .multipartite_layout (
266
+ G ,
267
+ align = "horizontal" ,
268
+ subset_key = layers , # type: ignore
269
+ scale = 1.0 ,
270
+ )
264
271
265
272
# Prepare edge widths scaled to desired range
266
273
edge_weights = [G .edges [edge ]["weight" ] for edge in G .edges ]
@@ -358,8 +365,8 @@ def clustree(
358
365
warning ("Colorbars are not supported when providing a list of colormaps. Ignoring the argument." )
359
366
360
367
# Calculate positions for cluster keys and scores
361
- need_level_positions = score_clustering is not None or ( show_cluster_keys and scatter_reference is not None )
362
- y_positions_levels : Union [ NDArray [np .float64 ], list [float ] ]
368
+ need_level_positions = show_cluster_keys and scatter_reference is not None
369
+ y_positions_levels : NDArray [np .float64 ] | list [float ]
363
370
if need_level_positions :
364
371
y_positions_levels = np .linspace (
365
372
ax .get_ylim ()[1 ],
@@ -372,7 +379,7 @@ def clustree(
372
379
x_min = ax .get_xlim ()[0 ] if scatter_reference is None else ax .get_xlim ()[1 ] + 2
373
380
374
381
# Determine y positions and colors
375
- facecolor : Union [ list [str ], list [tuple [float , float , float , float ] ]]
382
+ facecolor : list [str ] | list [tuple [float , float , float , float ]]
376
383
if scatter_reference is None :
377
384
# Use node positions for y-coordinates in hierarchical layout
378
385
level_nodes : dict [int , list [str ]] = {}
@@ -416,46 +423,4 @@ def clustree(
416
423
if title is not None :
417
424
ax .set_title (title , fontsize = 16 , fontweight = "bold" , pad = 10 )
418
425
419
- # Add clustering scores
420
- if score_clustering is not None :
421
- # Calculate scores
422
- scores = [calculate_clustering_score (adata , key , score_clustering , score_basis ) for key in cluster_keys ]
423
-
424
- # Map score names for display
425
- score_name_map = {
426
- "silhouette" : "Silhouette score" ,
427
- "calinski_harabasz" : "Calinski and Harabasz score" ,
428
- "davies_bouldin" : "Davies-Bouldin score" ,
429
- }
430
-
431
- # Add score title
432
- x_max = ax .get_xlim ()[1 ] + 0.5
433
- score_title = score_name_map [score_clustering ] if isinstance (score_clustering , str ) else "Clustering score"
434
-
435
- ax .text (
436
- x_max ,
437
- y = ax .get_ylim ()[1 ],
438
- s = score_title ,
439
- fontsize = 12 ,
440
- color = "black" ,
441
- ha = "center" ,
442
- va = "center" ,
443
- )
444
-
445
- # Display scores
446
- for i , score in enumerate (scores ):
447
- ax .text (
448
- x_max ,
449
- y = y_positions_levels [i ],
450
- s = f"{ score :.2f} " ,
451
- fontsize = 12 ,
452
- ha = "center" ,
453
- va = "center" ,
454
- bbox = {
455
- "boxstyle" : "round" ,
456
- "facecolor" : facecolor [i ],
457
- "edgecolor" : "black" ,
458
- },
459
- )
460
-
461
426
return fig
0 commit comments