Skip to content

Commit 9d4b5be

Browse files
feat(buzzwords): [0.3.1] silhouette score (#7)
1 parent 70df9c4 commit 9d4b5be

File tree

5 files changed

+34
-3
lines changed

5 files changed

+34
-3
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ _site/*
66
buzzwords.egg-info/*
77
Gemfile.lock
88

9-
.DS_Store
9+
.DS_Store
10+
__pycache__

buzzwords/buzzwords.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from cuml.cluster import approximate_predict
1111
from sentence_transformers import SentenceTransformer
1212
from sklearn.metrics.pairwise import cosine_similarity
13+
from sklearn.metrics import silhouette_score
1314
from tqdm import tqdm
1415

1516
from .models.clip_encoder import CLIPEncoder
@@ -158,6 +159,10 @@ def __init__(self,
158159
self.topic_descriptions = None
159160
self.topic_alterations = {}
160161

162+
silhouette_params = self.model_parameters['Silhouette']
163+
self.silhouette_random_state = silhouette_params['random_state']
164+
self.run_silhouette_score = silhouette_params['run_silhouette_score']
165+
161166
def fit(self, docs: List[str], recursions: int = 1) -> None:
162167
"""
163168
Fit model based on given data
@@ -241,6 +246,11 @@ def fit_transform(self, docs: List[str], recursions: int = 1) -> List[int]:
241246
min_cluster_size=int(self.model_parameters['HDBSCAN']['min_cluster_size'])
242247
)
243248

249+
# Optionally compute the silhouette score, a measure of how well separated the clusters are
250+
if self.run_silhouette_score:
251+
self.silhouette_score = self.get_silhouette_score(embeddings,topics)
252+
print(f"Silhouette score: {self.silhouette_score}")
253+
244254
# Lemmatise words to avoid similar words in top n keywords
245255
if self.lemmatise:
246256
docs = [
@@ -594,3 +604,20 @@ def load(self, destination: str) -> None:
594604

595605
with open(destination, 'rb') as file:
596606
self.__dict__ = pickle.load(file)
607+
608+
def get_silhouette_score(self, X: np.ndarray, labels: np.ndarray) -> float:
    """
    Compute the silhouette score of a clustering, ignoring samples labelled -1

    The silhouette score measures how well separated the clusters are:
        1: clusters are well apart from each other and clearly distinguished
        0: clusters overlap, i.e. the distance between clusters is not significant
        -1: samples have been assigned to the wrong clusters

    Parameters
    ----------
    X : np.ndarray
        Embeddings array, with one row per entry in `labels`
    labels : np.ndarray
        Labels as predicted by the model. Entries equal to -1 (presumably
        the clusterer's noise/outlier label - confirm against the caller)
        are excluded from the score

    Returns
    -------
    float
        Silhouette score of the remaining samples, or NaN when the score is
        undefined because those samples do not form between 2 and
        (number of samples - 1) distinct clusters
    """
    # Accept plain lists as well as arrays so the boolean mask below works
    X = np.asarray(X)
    labels = np.asarray(labels)

    # Build the mask once and reuse it for both the embeddings and the labels
    keep = labels != -1
    kept_labels = labels[keep]

    # scikit-learn raises ValueError outside this range; return NaN instead so
    # an optional diagnostic cannot abort model fitting
    n_clusters = np.unique(kept_labels).size
    if n_clusters < 2 or n_clusters >= kept_labels.size:
        return float('nan')

    # NOTE: per the scikit-learn docs, random_state only has an effect when
    # sample_size is given; it is passed through to keep the call unchanged
    return silhouette_score(
        X[keep],
        kept_labels,
        random_state=self.silhouette_random_state
    )
623+

buzzwords/model_parameters.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Default parameters for Buzzwords model
22
Embedding:
33
model_name_or_path: 'all-mpnet-base-v2'
4+
Silhouette:
5+
random_state: 42
6+
run_silhouette_score: False
47
UMAP:
58
n_neighbors: 10
69
n_components: 5

install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ conda create -y -n $env_name \
1010
source activate $env_name;
1111

1212
pip3 install \
13-
sentence-transformers==2.1.0 \
13+
sentence-transformers==2.2.2 \
1414
keybert==0.5.1 \
1515
pytest~=7.0.0 \
1616
clip-by-openai==1.1;

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33
setuptools.setup(
44
name='buzzwords',
5-
version='0.3.0',
5+
version='0.3.1',
66
packages=setuptools.find_packages()
77
)

0 commit comments

Comments
 (0)