
Commit f306381

Merge pull request #6 from csinva/divyanshuaggarwal-da_min_ngram
Min frequency ngrams from @divyanshuaggarwal
2 parents 51754dd + 18a1fbe commit f306381

File tree

8 files changed: +245 -78 lines changed


docs/auggam/auggam.html

Lines changed: 107 additions & 33 deletions
Large diffs are not rendered by default.

docs/auggam/embed.html

Lines changed: 102 additions & 35 deletions
@@ -29,9 +29,14 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
 <pre><code class="python">from transformers import BertModel, DistilBertModel
 from transformers import AutoModelForCausalLM
 from os.path import join as oj
+from datasets import Dataset
+from tqdm import tqdm
 import torch
+import numpy as np
+from torch.utils.data import DataLoader
 import imodelsx.util

+
 def get_model(checkpoint):
     if 'distilbert' in checkpoint.lower():
         model = DistilBertModel.from_pretrained(checkpoint)
@@ -79,7 +84,8 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
     checkpoint: str,
     dataset_key_text: str = None,
     layer: str = 'last_hidden_state',
-    padding: bool = True,
+    padding: str = "max_length",
+    batch_size: int = 8,
     parsing: str = '',
     nlp_chunks=None,
     all_ngrams: bool = False,
@@ -134,27 +140,57 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
         tokenizer_embeddings.pad_token = tokenizer_embeddings.eos_token
         tokens = tokenizer_embeddings(seqs, padding=padding,
                                       truncation=True, return_tensors="pt")
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-        if layer == 'pooler_output':
-            embs = output['pooler_output'].cpu().detach().numpy()
-        elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
-            embs = output['last_hidden_state'].cpu().detach().numpy()
-            embs = embs.mean(axis=1)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            if layer == 'pooler_output':
+                emb = output['pooler_output'].cpu().detach().numpy()
+            elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
+                emb = output['last_hidden_state'].cpu().detach().numpy()
+                emb = emb.mean(axis=1)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif 'gpt' in checkpoint.lower():
         tokens = preprocess_gpt_token_batch(seqs, tokenizer_embeddings)
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-
-        # tuple of (layer x (batch_size, seq_len, hidden_size))
-        h = output['hidden_states']
-        # (batch_size, seq_len, hidden_size)
-        embs = h[0].cpu().detach().numpy()
-        embs = embs.mean(axis=1) # (batch_size, hidden_size)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            # tuple of (layer x (batch_size, seq_len, hidden_size))
+            h = output['hidden_states']
+            # (batch_size, seq_len, hidden_size)
+            emb = h[0].cpu().detach().numpy()
+            emb = emb.mean(axis=1) # (batch_size, hidden_size)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif checkpoint.startswith('hkunlp/instructor'):
         if instructor_prompt is None:
             instructor_prompt = "Represent the short phrase for sentiment classification: "
-        embs = model.encode([[instructor_prompt, x_i] for x_i in seqs], batch_size=32)
+        embs = model.encode([[instructor_prompt, x_i]
+                             for x_i in seqs], batch_size=batch_size)

     # sum over the embeddings
     embs = embs.sum(axis=0).reshape(1, -1)
@@ -172,7 +208,7 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
 <h2 class="section-title" id="header-functions">Functions</h2>
 <dl>
 <dt id="imodelsx.auggam.embed.embed_and_sum_function"><code class="name flex">
-<span>def <span class="ident">embed_and_sum_function</span></span>(<span>example, model, ngrams: int, tokenizer_embeddings, tokenizer_ngrams, checkpoint: str, dataset_key_text: str = None, layer: str = 'last_hidden_state', padding: bool = True, parsing: str = '', nlp_chunks=None, all_ngrams: bool = False, fit_with_ngram_decomposition: bool = True, instructor_prompt: str = None)</span>
+<span>def <span class="ident">embed_and_sum_function</span></span>(<span>example, model, ngrams: int, tokenizer_embeddings, tokenizer_ngrams, checkpoint: str, dataset_key_text: str = None, layer: str = 'last_hidden_state', padding: str = 'max_length', batch_size: int = 8, parsing: str = '', nlp_chunks=None, all_ngrams: bool = False, fit_with_ngram_decomposition: bool = True, instructor_prompt: str = None)</span>
 </code></dt>
 <dd>
 <div class="desc"><p>Get summed embeddings for a single example</p>
@@ -206,7 +242,8 @@ <h2 id="params">Params</h2>
     checkpoint: str,
     dataset_key_text: str = None,
     layer: str = 'last_hidden_state',
-    padding: bool = True,
+    padding: str = "max_length",
+    batch_size: int = 8,
     parsing: str = '',
     nlp_chunks=None,
     all_ngrams: bool = False,
@@ -261,27 +298,57 @@ <h2 id="params">Params</h2>
         tokenizer_embeddings.pad_token = tokenizer_embeddings.eos_token
         tokens = tokenizer_embeddings(seqs, padding=padding,
                                       truncation=True, return_tensors="pt")
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-        if layer == 'pooler_output':
-            embs = output['pooler_output'].cpu().detach().numpy()
-        elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
-            embs = output['last_hidden_state'].cpu().detach().numpy()
-            embs = embs.mean(axis=1)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            if layer == 'pooler_output':
+                emb = output['pooler_output'].cpu().detach().numpy()
+            elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
+                emb = output['last_hidden_state'].cpu().detach().numpy()
+                emb = emb.mean(axis=1)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif 'gpt' in checkpoint.lower():
         tokens = preprocess_gpt_token_batch(seqs, tokenizer_embeddings)
-        tokens = tokens.to(model.device)
-        output = model(**tokens)
-
-        # tuple of (layer x (batch_size, seq_len, hidden_size))
-        h = output['hidden_states']
-        # (batch_size, seq_len, hidden_size)
-        embs = h[0].cpu().detach().numpy()
-        embs = embs.mean(axis=1) # (batch_size, hidden_size)
+
+        embs = []
+
+        ds = Dataset.from_dict(tokens).with_format("torch")
+
+        for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
+            batch = {k: v.to(model.device) for k, v in batch.items()}
+
+            with torch.no_grad():
+                output = model(**batch)
+            torch.cuda.empty_cache()
+
+            # tuple of (layer x (batch_size, seq_len, hidden_size))
+            h = output['hidden_states']
+            # (batch_size, seq_len, hidden_size)
+            emb = h[0].cpu().detach().numpy()
+            emb = emb.mean(axis=1) # (batch_size, hidden_size)
+
+            embs.append(emb)
+
+        embs = np.concatenate(embs)
+
     elif checkpoint.startswith('hkunlp/instructor'):
         if instructor_prompt is None:
             instructor_prompt = "Represent the short phrase for sentiment classification: "
-        embs = model.encode([[instructor_prompt, x_i] for x_i in seqs], batch_size=32)
+        embs = model.encode([[instructor_prompt, x_i]
+                             for x_i in seqs], batch_size=batch_size)

     # sum over the embeddings
     embs = embs.sum(axis=0).reshape(1, -1)
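In short, this file's change replaces the single full-batch forward pass with mini-batched, gradient-free inference: the tokenized inputs are wrapped in a datasets.Dataset, iterated with a DataLoader, embedded under torch.no_grad(), and the per-batch embeddings are concatenated at the end. A minimal standalone sketch of that pattern follows; the checkpoint, sentences, and mean-pooling choice here are illustrative placeholders, not taken from this commit.

# Sketch of the batched, no-grad embedding loop introduced above (illustrative).
# The checkpoint and sentences are placeholders, not part of this commit.
import numpy as np
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

seqs = ["the quick brown fox", "jumps over the lazy dog", "a very lazy dog"]
tokens = tokenizer(seqs, padding="max_length", truncation=True, return_tensors="pt")

# Wrap the tokenized tensors in a Dataset so a DataLoader can serve fixed-size batches.
ds = Dataset.from_dict(tokens).with_format("torch")

embs = []
for batch in DataLoader(ds, batch_size=8, shuffle=False):
    batch = {k: v.to(model.device) for k, v in batch.items()}
    with torch.no_grad():          # no gradients are needed for embedding extraction
        output = model(**batch)
    # mean-pool the last hidden state -> (batch_size, hidden_size)
    emb = output["last_hidden_state"].cpu().numpy().mean(axis=1)
    embs.append(emb)

embs = np.concatenate(embs)        # (num_seqs, hidden_size)
print(embs.shape)

Batching plus no_grad keeps peak GPU memory roughly proportional to batch_size rather than to the number of sequences, which is what makes the batch_size argument added in this commit useful.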

docs/index.html

Lines changed: 2 additions & 7 deletions
@@ -83,17 +83,12 @@
 <td>Black-box model</td>
 <td>Finetune a single linear layer<br/>on top of LLM embeddings</td>
 </tr>
-<tr>
-<td style="text-align: left;">(Coming soon!)</td>
-<td></td>
-<td></td>
-<td>We plan to support other interpretable models like <a href="https://arxiv.org/abs/2205.12548">RLPrompt</a>, <a href="https://arxiv.org/abs/2007.04612">CBMs</a>, <a href="https://proceedings.neurips.cc/paper/2021/hash/251bd0442dfcc53b5a761e050f8022b8-Abstract.html">NAMs</a>, and <a href="https://arxiv.org/abs/2004.00221">NBDT</a></td>
-</tr>
 </tbody>
 </table>
 <p align="center">
-Demo notebooks <a href="https://github.com/csinva/imodelsX/tree/master/demo_notebooks">📖</a>, Doc <a href="https://csinva.io/imodelsX/">🗂️</a>, Reference code implementation 🔗, Research paper 📄
+<a href="https://github.com/csinva/imodelsX/tree/master/demo_notebooks">📖</a>Demo notebooks &emsp; <a href="https://csinva.io/imodelsX/">🗂️</a> Doc &emsp; 🔗 Reference code &emsp; 📄 Research paper
 </br>
+⌛ We plan to support other interpretable algorithms like <a href="https://arxiv.org/abs/2205.12548">RLPrompt</a>, <a href="https://arxiv.org/abs/2007.04612">CBMs</a>, and <a href="https://arxiv.org/abs/2004.00221">NBDT</a>. If you want to contribute an algorithm, feel free to open a PR 😄
 </p>
 <h1 id="quickstart">Quickstart</h1>
 <p><strong>Installation</strong>: <code>pip install <a title="imodelsx" href="#imodelsx">imodelsx</a></code> (or, for more control, clone and install from source)</p>

docs/util.html

Lines changed: 19 additions & 2 deletions
@@ -32,6 +32,7 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
 from transformers import pipeline
 import datasets
 import numpy as np
+from collections import Counter


 def generate_ngrams_list(
@@ -42,6 +43,7 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
     parsing: str = '',
     nlp_chunks=None,
     pad_starting_ngrams=False,
+    min_frequency=1,
 ):
     """Get list of ngrams from sentence using a tokenizer

@@ -55,6 +57,8 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
         if all_ngrams=False, then pad starting ngrams with shorter length ngrams
         so that length of ngrams_list is the same as the initial sequence
         e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", ...]
+    min_frequency: int
+        minimum frequency to be considered for the ngrams_list
     """

     seqs = []
@@ -96,6 +100,10 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
         assert all_ngrams is False, "pad_starting_ngrams only works when all_ngrams=False"
         seqs_init = [' '.join(unigrams_list[:ngram_length]) for ngram_length in range(1, ngrams)]
         seqs = seqs_init + seqs
+
+    freqs = Counter(seqs)
+
+    seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]

     return seqs

@@ -177,7 +185,7 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
 <h2 class="section-title" id="header-functions">Functions</h2>
 <dl>
 <dt id="imodelsx.util.generate_ngrams_list"><code class="name flex">
-<span>def <span class="ident">generate_ngrams_list</span></span>(<span>sentence: str, ngrams: int, tokenizer_ngrams=None, all_ngrams=False, parsing: str = '', nlp_chunks=None, pad_starting_ngrams=False)</span>
+<span>def <span class="ident">generate_ngrams_list</span></span>(<span>sentence: str, ngrams: int, tokenizer_ngrams=None, all_ngrams=False, parsing: str = '', nlp_chunks=None, pad_starting_ngrams=False, min_frequency=1)</span>
 </code></dt>
 <dd>
 <div class="desc"><p>Get list of ngrams from sentence using a tokenizer</p>
@@ -189,7 +197,9 @@ <h2 id="params">Params</h2>
 pad_starting_ngrams: bool
     if all_ngrams=False, then pad starting ngrams with shorter length ngrams
     so that length of ngrams_list is the same as the initial sequence
-    e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", &hellip;]</p></div>
+    e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", &hellip;]
+min_frequency: int
+    minimum frequency to be considered for the ngrams_list</p></div>
 <details class="source">
 <summary>
 <span>Expand source code</span>
@@ -202,6 +212,7 @@ <h2 id="params">Params</h2>
     parsing: str = '',
     nlp_chunks=None,
     pad_starting_ngrams=False,
+    min_frequency=1,
 ):
     """Get list of ngrams from sentence using a tokenizer

@@ -215,6 +226,8 @@ <h2 id="params">Params</h2>
         if all_ngrams=False, then pad starting ngrams with shorter length ngrams
         so that length of ngrams_list is the same as the initial sequence
         e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", ...]
+    min_frequency: int
+        minimum frequency to be considered for the ngrams_list
     """

     seqs = []
@@ -256,6 +269,10 @@ <h2 id="params">Params</h2>
         assert all_ngrams is False, "pad_starting_ngrams only works when all_ngrams=False"
         seqs_init = [' '.join(unigrams_list[:ngram_length]) for ngram_length in range(1, ngrams)]
         seqs = seqs_init + seqs
+
+    freqs = Counter(seqs)
+
+    seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]

     return seqs</code></pre>
 </details>
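The substance of this change (mirrored in imodelsx/util.py below) is a Counter over the extracted ngram strings: each distinct ngram is kept only if it occurs at least min_frequency times in the sentence. A self-contained sketch of that logic, using a made-up ngram list:

# Standalone sketch of the min_frequency filter added above (toy data, illustrative only).
from collections import Counter

# Suppose these unigrams/bigrams were extracted from one document.
seqs = ["the", "cat", "sat", "the", "cat", "the cat", "cat sat", "the cat"]

min_frequency = 2
freqs = Counter(seqs)

# Keep each distinct ngram only if it appears at least min_frequency times.
seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]
print(seqs)  # ['the', 'cat', 'the cat']

Note that, as in the diff, the filtered list contains each distinct ngram once, since Counter.items() iterates over unique keys; with the default min_frequency=1 the only effect is this deduplication.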

imodelsx/auggam/auggam.py

Lines changed: 5 additions & 0 deletions
@@ -36,6 +36,7 @@ def __init__(
         layer: str = 'last_hidden_state',
         ngrams: int = 2,
         all_ngrams: bool = False,
+        min_frequency: int = 1,
         tokenizer_ngrams=None,
         random_state=None,
         normalize_embs=False,
@@ -54,6 +55,8 @@ def __init__(
             Order of ngrams to extract. 1 for unigrams, 2 for bigrams, etc.
         all_ngrams
             Whether to use all order ngrams <= ngrams argument
+        min_frequency
+            minimum frequency of ngrams to be kept in the ngrams list.
         tokenizer_ngrams
             if None, defaults to spacy English tokenizer
         random_state
@@ -76,6 +79,7 @@ def __init__(
         self.layer = layer
         self.random_state = random_state
         self.all_ngrams = all_ngrams
+        self.min_frequency = min_frequency
         self.normalize_embs = normalize_embs
         self.fit_with_ngram_decomposition = fit_with_ngram_decomposition
         self.instructor_prompt = instructor_prompt
@@ -284,6 +288,7 @@ def _get_ngrams_list(self, X):
                 ngrams=self.ngrams,
                 tokenizer_ngrams=self.tokenizer_ngrams,
                 all_ngrams=self.all_ngrams,
+                min_frequency=self.min_frequency
             )
             all_ngrams |= set(seqs)
         return sorted(list(all_ngrams))
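For reference, a hypothetical end-to-end use of the new constructor argument. The class name AugGAMClassifier, the dataset, and the subset size are assumptions made for illustration; they mirror tests/test_pipeline.py below but are not shown in this diff.

# Hypothetical usage of the new min_frequency argument (assumptions noted above).
import datasets
from imodelsx import AugGAMClassifier  # assumed export name; not shown in this diff

dset = datasets.load_dataset('rotten_tomatoes')['train'].select(range(64))

m = AugGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,
    all_ngrams=True,   # also use lower-order ngrams
    min_frequency=1,   # keep every extracted ngram; raise to prune rarely repeated ones
)
m.fit(dset['text'], dset['label'], batch_size=8)
print(m.predict(dset['text'][:2]))

Because generate_ngrams_list is called once per document in _get_ngrams_list, min_frequency counts occurrences within a single document: a value above 1 keeps only ngrams that repeat inside the same text.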

imodelsx/util.py

Lines changed: 8 additions & 0 deletions
@@ -6,6 +6,7 @@
 from transformers import pipeline
 import datasets
 import numpy as np
+from collections import Counter


 def generate_ngrams_list(
@@ -16,6 +17,7 @@ def generate_ngrams_list(
     parsing: str = '',
     nlp_chunks=None,
     pad_starting_ngrams=False,
+    min_frequency=1,
 ):
     """Get list of ngrams from sentence using a tokenizer

@@ -29,6 +31,8 @@ def generate_ngrams_list(
         if all_ngrams=False, then pad starting ngrams with shorter length ngrams
         so that length of ngrams_list is the same as the initial sequence
         e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", ...]
+    min_frequency: int
+        minimum frequency to be considered for the ngrams_list
     """

     seqs = []
@@ -70,6 +74,10 @@ def generate_ngrams_list(
         assert all_ngrams is False, "pad_starting_ngrams only works when all_ngrams=False"
         seqs_init = [' '.join(unigrams_list[:ngram_length]) for ngram_length in range(1, ngrams)]
         seqs = seqs_init + seqs
+
+    freqs = Counter(seqs)
+
+    seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]

     return seqs

setup.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@

 setuptools.setup(
     name="imodelsx",
-    version="0.20",
+    version="0.21",
     author="Chandan Singh, John X. Morris, Armin Askari",
     author_email="chansingh@microsoft.com",
     description="Library to explain a dataset in natural language.",

tests/test_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
     ngrams=2,
     all_ngrams=True, # also use lower-order ngrams
+    min_frequency=1
 )
 m.fit(dset['text'], dset['label'], batch_size=8)
