Skip to content

Commit 91562a7

Browse files
Automatic pre-commit fixes
1 parent 50ca0e8 commit 91562a7

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

examples/clustering/sklearn_clustering_with_aeon_distances.ipynb

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@
119119
"from sklearn.cluster import AgglomerativeClustering\n",
120120
"\n",
121121
"# Perform Agglomerative Clustering\n",
122-
"agg_clustering = AgglomerativeClustering(n_clusters=2, metric=\"precomputed\", linkage=\"average\")\n",
122+
"agg_clustering = AgglomerativeClustering(\n",
123+
" n_clusters=2, metric=\"precomputed\", linkage=\"average\"\n",
124+
")\n",
123125
"labels = agg_clustering.fit_predict(distance_matrix)\n",
124126
"\n",
125127
"# Visualize the clustering results\n",
@@ -181,7 +183,7 @@
181183
"plt.figure(figsize=(10, 6))\n",
182184
"for label in np.unique(dbscan_labels):\n",
183185
" cluster_data = X[np.where(dbscan_labels == label)] # Fix indexing\n",
184-
" \n",
186+
"\n",
185187
" if label == -1:\n",
186188
" plt.plot(cluster_data.mean(axis=0), label=\"Noise\", linestyle=\"--\", linewidth=2)\n",
187189
" else:\n",
@@ -192,7 +194,7 @@
192194
"plt.ylabel(\"Mean Value\")\n",
193195
"plt.legend(loc=\"upper right\", fontsize=\"small\")\n",
194196
"plt.grid(True)\n",
195-
"plt.show()\n"
197+
"plt.show()"
196198
]
197199
},
198200
{
@@ -239,22 +241,27 @@
239241
"\n",
240242
" # Ensure correct shape for plotting\n",
241243
" cluster_data = np.squeeze(cluster_data)\n",
242-
" if cluster_data.ndim == 1: \n",
244+
" if cluster_data.ndim == 1:\n",
243245
" cluster_data = cluster_data[:, np.newaxis] # Convert to 2D if needed\n",
244246
"\n",
245247
" # Compute mean representation of each cluster\n",
246-
" cluster_mean = cluster_data.mean(axis=0) \n",
248+
" cluster_mean = cluster_data.mean(axis=0)\n",
247249
"\n",
248250
" # Plot noise separately\n",
249251
" if label == -1:\n",
250252
" plt.plot(cluster_mean, linestyle=\"--\", color=\"gray\", alpha=0.5, label=\"Noise\")\n",
251253
" else:\n",
252-
" plt.plot(cluster_mean, color=colors(label % colors.N), alpha=0.7, label=f\"Cluster {label}\")\n",
254+
" plt.plot(\n",
255+
" cluster_mean,\n",
256+
" color=colors(label % colors.N),\n",
257+
" alpha=0.7,\n",
258+
" label=f\"Cluster {label}\",\n",
259+
" )\n",
253260
"\n",
254261
"plt.title(\"OPTICS Clustering with DTW Distance\")\n",
255262
"plt.legend()\n",
256263
"plt.grid(True, linestyle=\"--\", alpha=0.5) # Light grid for better readability\n",
257-
"plt.show()\n"
264+
"plt.show()"
258265
]
259266
},
260267
{
@@ -287,15 +294,24 @@
287294
"from sklearn.cluster import SpectralClustering\n",
288295
"from sklearn.metrics import pairwise_distances\n",
289296
"\n",
290-
"X = np.vstack((np.random.normal(loc=[2, 2], scale=0.5, size=(50, 2)), \n",
291-
" np.random.normal(loc=[5, 5], scale=0.5, size=(50, 2))))\n",
292-
"distance_matrix = pairwise_distances(X, metric='euclidean')\n",
297+
"X = np.vstack(\n",
298+
" (\n",
299+
" np.random.normal(loc=[2, 2], scale=0.5, size=(50, 2)),\n",
300+
" np.random.normal(loc=[5, 5], scale=0.5, size=(50, 2)),\n",
301+
" )\n",
302+
")\n",
303+
"distance_matrix = pairwise_distances(X, metric=\"euclidean\")\n",
293304
"inverse_distance_matrix = 1 - (distance_matrix / distance_matrix.max())\n",
294305
"spectral = SpectralClustering(n_clusters=2, affinity=\"precomputed\", random_state=42)\n",
295306
"spectral_labels = spectral.fit_predict(inverse_distance_matrix)\n",
296307
"plt.figure(figsize=(10, 6))\n",
297308
"for label in np.unique(spectral_labels):\n",
298-
" plt.scatter(X[spectral_labels == label, 0], X[spectral_labels == label, 1], label=f\"Cluster {label}\", alpha=0.7)\n",
309+
" plt.scatter(\n",
310+
" X[spectral_labels == label, 0],\n",
311+
" X[spectral_labels == label, 1],\n",
312+
" label=f\"Cluster {label}\",\n",
313+
" alpha=0.7,\n",
314+
" )\n",
299315
"plt.title(\"Spectral Clustering with Normalized Similarity Matrix\")\n",
300316
"plt.legend()\n",
301317
"plt.show()"

0 commit comments

Comments
 (0)