
Commit 44adb4c

minor cleanup
1 parent 13f6142 commit 44adb4c

File tree

3 files changed, +38 -34 lines

imodelsx/embgam/embed.py

Lines changed: 9 additions & 11 deletions

@@ -11,7 +11,7 @@ def generate_ngrams_list(
     parsing: str='',
     nlp_chunks=None,
 ):
-    """Get list of ngrams from sentence
+    """Get list of ngrams from sentence using a tokenizer

     Params
     ------
@@ -131,19 +131,17 @@ def embed_and_sum_function(
     sentence = example
     # seqs = sentence

-    if isinstance(sentence, str):
-        seqs = generate_ngrams_list(
-            sentence, ngrams=ngrams, tokenizer_ngrams=tokenizer_ngrams,
-            parsing=parsing, nlp_chunks=nlp_chunks, all_ngrams=all_ngrams,
-        )
-    elif isinstance(sentence, list):
-        raise Exception('batched mode not supported')
-        # seqs = list(map(generate_ngrams_list, sentence))
+    assert isinstance(sentence, str), 'sentence must be a string (batched mode not supported)'
+    seqs = generate_ngrams_list(
+        sentence, ngrams=ngrams, tokenizer_ngrams=tokenizer_ngrams,
+        parsing=parsing, nlp_chunks=nlp_chunks, all_ngrams=all_ngrams,
+    )
+    # seqs = list(map(generate_ngrams_list, sentence))
+

-    # maybe a smarter way to deal with pooling here?
     seq_len = len(seqs)
     if seq_len == 0:
-        seqs = ["dummy"]
+        seqs = ["dummy"]  # will multiply embedding by 0 so doesn't matter

     if 'bert' in checkpoint.lower(): # has up to two keys, 'last_hidden_state', 'pooler_output'
         if not hasattr(tokenizer_embeddings, 'pad_token') or tokenizer_embeddings.pad_token is None:
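For context on the embed.py change: the old if/elif input check collapses into a single assert, so list inputs fail fast and an empty ngram list falls back to a placeholder token. A minimal, self-contained Python sketch of that control flow (handle_input and the toy ngram generator are illustrative stand-ins, not imodelsx APIs):

# Hypothetical sketch mirroring the new input handling in embed_and_sum_function.
def handle_input(sentence, generate_ngrams_list):
    assert isinstance(sentence, str), 'sentence must be a string (batched mode not supported)'
    seqs = generate_ngrams_list(sentence)
    if len(seqs) == 0:
        seqs = ["dummy"]  # downstream code multiplies this embedding by 0, so the value is irrelevant
    return seqs

print(handle_input("a great movie", lambda s: s.split()))  # ['a', 'great', 'movie']
try:
    handle_input(["a", "batch", "of", "inputs"], lambda s: s.split())
except AssertionError as err:
    print(err)  # list inputs now fail fast instead of reaching an elif branch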

imodelsx/embgam/embgam.py

Lines changed: 28 additions & 22 deletions

@@ -168,18 +168,23 @@ def cache_linear_coefs(self, X: ArrayLike, model=None, tokenizer_embeddings=None
             print('\tNothing to update!')
             return

-        # compute embeddings
-        """
-        # Faster version that needs more memory
-        tokens = tokenizer(ngrams_list, padding=args.padding,
-                           truncation=True, return_tensors="pt")
-        tokens = tokens.to(device)
+        embs = self._get_embs(ngrams_list, model, tokenizer_embeddings)
+        if self.normalize_embs:
+            embs = self.normalizer.transform(embs)

-        output = model(**tokens) # this takes a while....
-        embs = output['pooler_output'].cpu().detach().numpy()
-        return embs
+        # save coefs
+        coef_embs = self.linear.coef_.squeeze().transpose()
+        linear_coef = embs @ coef_embs
+        self.coefs_dict_ = {
+            **coefs_dict_old,
+            **{ngrams_list[i]: linear_coef[i]
+               for i in range(len(ngrams_list))}
+        }
+        print('coefs_dict_ len', len(self.coefs_dict_))
+
+    def _get_embs(self, ngrams_list, model, tokenizer_embeddings):
+        """Get embeddings for a list of ngrams (not summed!)
         """
-        # Slower way to run things but won't run out of mem
         embs = []
         for i in tqdm(range(len(ngrams_list))):
             tokens = tokenizer_embeddings(
@@ -191,18 +196,19 @@ def cache_linear_coefs(self, X: ArrayLike, model=None, tokenizer_embeddings=None
                 emb = emb.mean(axis=1)
             embs.append(emb)
         embs = np.array(embs).squeeze()
-        if self.normalize_embs:
-            embs = self.normalizer.transform(embs)
+        return embs

-        # save coefs
-        coef_embs = self.linear.coef_.squeeze().transpose()
-        linear_coef = embs @ coef_embs
-        self.coefs_dict_ = {
-            **coefs_dict_old,
-            **{ngrams_list[i]: linear_coef[i]
-               for i in range(len(ngrams_list))}
-        }
-        print('coefs_dict_ len', len(self.coefs_dict_))
+        """
+        # Faster version that needs more memory
+        tokens = tokenizer(ngrams_list, padding=args.padding,
+                           truncation=True, return_tensors="pt")
+        tokens = tokens.to(device)
+
+        output = model(**tokens) # this takes a while....
+        embs = output['pooler_output'].cpu().detach().numpy()
+        return embs
+        """
+

     def _get_ngrams_list(self, X):
         all_ngrams = set()
@@ -251,7 +257,7 @@ def _predict_cached(self, X, warn):
         n_unseen_ngrams = 0
         for x in X:
             pred = 0
-            seqs = imodelsx.embgam.embed.generate_ngrams_list(
+            seqs = imodelsx.embgam.embed.generate_ngrams_list(
                 x,
                 ngrams=self.ngrams,
                 tokenizer_ngrams=self.tokenizer_ngrams,
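For context on the embgam.py refactor: the per-ngram embedding loop moves into a _get_embs helper, while cache_linear_coefs keeps the caching step that turns each ngram embedding into a scalar coefficient via a dot product with the fitted linear model's weights. A rough numpy sketch of that step (the arrays are random stand-ins for _get_embs output and self.linear.coef_, not real model values):

# Stand-in sketch of the coefficient caching in cache_linear_coefs (hypothetical values).
import numpy as np

ngrams_list = ["good movie", "bad plot", "great acting"]
embs = np.random.rand(len(ngrams_list), 768)  # placeholder for _get_embs(ngrams_list, model, tokenizer_embeddings)
coef_embs = np.random.rand(768)               # placeholder for self.linear.coef_.squeeze().transpose()

linear_coef = embs @ coef_embs                # one scalar coefficient per ngram
coefs_dict_ = {ngrams_list[i]: linear_coef[i] for i in range(len(ngrams_list))}
print('coefs_dict_ len', len(coefs_dict_))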

setup.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@

 setuptools.setup(
     name="imodelsx",
-    version="0.04",
+    version="0.05",
     author="Chandan Singh, John X. Morris",
     author_email="chansingh@microsoft.com",
     description="Library to explain a dataset in natural language.",
