10
10
from cuml .cluster import approximate_predict
11
11
from sentence_transformers import SentenceTransformer
12
12
from sklearn .metrics .pairwise import cosine_similarity
13
+ from sklearn .metrics import silhouette_score
13
14
from tqdm import tqdm
14
15
15
16
from .models .clip_encoder import CLIPEncoder
@@ -158,6 +159,10 @@ def __init__(self,
158
159
self .topic_descriptions = None
159
160
self .topic_alterations = {}
160
161
162
+ silhouette_params = self .model_parameters ['Silhouette' ]
163
+ self .silhouette_random_state = silhouette_params ['random_state' ]
164
+ self .run_silhouette_score = silhouette_params ['run_silhouette_score' ]
165
+
161
166
def fit (self , docs : List [str ], recursions : int = 1 ) -> None :
162
167
"""
163
168
Fit model based on given data
@@ -241,6 +246,11 @@ def fit_transform(self, docs: List[str], recursions: int = 1) -> List[int]:
241
246
min_cluster_size = int (self .model_parameters ['HDBSCAN' ]['min_cluster_size' ])
242
247
)
243
248
249
+ # Silhouette score is a metric used to calculate the goodness of a clustering technique
250
+ if self .run_silhouette_score :
251
+ self .silhouette_score = self .get_silhouette_score (embeddings ,topics )
252
+ print (f"Silhouette score: { self .silhouette_score } " )
253
+
244
254
# Lemmatise words to avoid similar words in top n keywords
245
255
if self .lemmatise :
246
256
docs = [
@@ -594,3 +604,20 @@ def load(self, destination: str) -> None:
594
604
595
605
with open (destination , 'rb' ) as file :
596
606
self .__dict__ = pickle .load (file )
607
+
608
+ def get_silhouette_score (self , X : np .ndarray , labels : np .ndarray ) -> float :
609
+ """
610
+ A Silhouette Coefficient or silhouette score is a metric used to calculate the goodness of a clustering technique
611
+ 1: Means clusters are well apart from each other and clearly distinguished
612
+ 0: Means clusters are indifferent, or we can say that the distance between clusters is not significant
613
+ -1: Means clusters are assigned in the wrong way
614
+
615
+ Parameters
616
+ ----------
617
+ X : np.ndarray
618
+ embeddings array
619
+ labels : np.ndarray
620
+ labels as predicted by the model
621
+ """
622
+ return silhouette_score (X [labels != - 1 ],labels [labels != - 1 ],random_state = self .silhouette_random_state )
623
+
0 commit comments