Commit e65fafb

Merge pull request #874 from TransformerLensOrg/dev
Release v2.15.0
2 parents 5e328e9 + 555f355, commit e65fafb

17 files changed: +1091 −149 lines
demos/BERT.ipynb
Lines changed: 120 additions & 27 deletions
@@ -15,7 +15,7 @@
   "metadata": {},
   "source": [
    "# BERT in TransformerLens\n",
-   "This demo shows how to use BERT in TransformerLens for the Masked Language Modelling task."
+   "This demo shows how to use BERT in TransformerLens for the Masked Language Modelling and Next Sentence Prediction tasks."
   ]
  },
  {
@@ -29,16 +29,14 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 15,
+  "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Running as a Jupyter notebook - intended for development only!\n",
-     "The autoreload extension is already loaded. To reload it, use:\n",
-     "  %reload_ext autoreload\n"
+     "Running as a Jupyter notebook - intended for development only!\n"
     ]
    },
    {
@@ -92,7 +90,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
@@ -116,7 +114,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 4,
+  "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
@@ -136,7 +134,7 @@
     "<circuitsvis.utils.render.RenderedHTML at 0x13a9760d0>"
    ]
   },
-  "execution_count": 4,
+  "execution_count": 3,
   "metadata": {},
   "output_type": "execute_result"
  }
@@ -150,7 +148,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -159,12 +157,12 @@
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
-   "from transformer_lens import HookedEncoder"
+   "from transformer_lens import HookedEncoder, BertNextSentencePrediction"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
@@ -173,7 +171,7 @@
     "<torch.autograd.grad_mode.set_grad_enabled at 0x2a285a790>"
    ]
   },
-  "execution_count": 6,
+  "execution_count": 5,
   "metadata": {},
   "output_type": "execute_result"
  }
@@ -189,12 +187,12 @@
   "source": [
    "# BERT\n",
    "\n",
-   "In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling task"
+   "In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling and Next Sentence Prediction tasks."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 14,
+  "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
@@ -225,37 +223,132 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "Use the \"[MASK]\" token to mask any tokens which you would like the model to predict."
+   "## Masked Language Modelling\n",
+   "Use the \"[MASK]\" token to mask any tokens which you would like the model to predict.\n",
+   "When return_type=\"predictions\" is specified, the model's prediction is returned; otherwise (and by default) the function returns logits.\n",
+   "You can also pass return_type=None, in which case nothing is returned."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 11,
+  "execution_count": 7,
   "metadata": {},
-  "outputs": [],
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Prompt: The [MASK] is bright today.\n",
+     "Prediction: \"sun\"\n"
+    ]
+   }
+  ],
+  "source": [
+   "prompt = \"The [MASK] is bright today.\"\n",
+   "\n",
+   "prediction = bert(prompt, return_type=\"predictions\")\n",
+   "\n",
+   "print(f\"Prompt: {prompt}\")\n",
+   "print(f'Prediction: \"{prediction}\"')"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "You can also input a list of prompts:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 8,
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Prompt: ['The [MASK] is bright today.', 'She [MASK] to the store.', 'The dog [MASK] the ball.']\n",
+     "Prediction: \"['Prediction 0: sun', 'Prediction 1: went', 'Prediction 2: caught']\"\n"
+    ]
+   }
+  ],
+  "source": [
+   "prompts = [\"The [MASK] is bright today.\", \"She [MASK] to the store.\", \"The dog [MASK] the ball.\"]\n",
+   "\n",
+   "predictions = bert(prompts, return_type=\"predictions\")\n",
+   "\n",
+   "print(f\"Prompt: {prompts}\")\n",
+   "print(f'Prediction: \"{predictions}\"')"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## Next Sentence Prediction\n",
+   "To carry out Next Sentence Prediction, use the BertNextSentencePrediction class, passing a HookedEncoder to its constructor.\n",
+   "Then pass a list containing the two sentences to the forward function.\n",
+   "The model predicts the probability that the sentence at position 1 follows (i.e. is the next sentence after) the sentence at position 0."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 9,
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Sentence A: A man walked into a grocery store.\n",
+     "Sentence B: He bought an apple.\n",
+     "Prediction: \"The sentences are sequential\"\n"
+    ]
+   }
+  ],
   "source": [
-   "prompt = \"BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\"\n",
+   "nsp = BertNextSentencePrediction(bert)\n",
+   "sentence_a = \"A man walked into a grocery store.\"\n",
+   "sentence_b = \"He bought an apple.\"\n",
    "\n",
-   "input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
-   "mask_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()"
+   "input = [sentence_a, sentence_b]\n",
+   "\n",
+   "predictions = nsp(input, return_type=\"predictions\")\n",
+   "\n",
+   "print(f\"Sentence A: {sentence_a}\")\n",
+   "print(f\"Sentence B: {sentence_b}\")\n",
+   "print(f'Prediction: \"{predictions}\"')"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "# Inputting tokens directly\n",
+   "You can also input tokens instead of a string or a list of strings into the model, which looks something like this:"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 12,
+  "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Prompt: BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\n",
-     "Prediction: \"Systems\"\n"
+     "Prompt: The [MASK] is bright today.\n",
+     "Prediction: \"sun\"\n"
     ]
    }
   ],
   "source": [
-   "logprobs = bert(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
+   "prompt = \"The [MASK] is bright today.\"\n",
+   "\n",
+   "tokens = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
+   "logits = bert(tokens)  # no return_type specified, so we get the logits\n",
+   "logprobs = logits[tokens == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
    "prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
    "\n",
    "print(f\"Prompt: {prompt}\")\n",
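The direct-token cell above selects the logits at the [MASK] position with a boolean mask, applies log_softmax, and decodes the argmax. The indexing mechanics can be sketched with plain torch on hypothetical small tensors (no model or tokenizer needed; the token ids and 10-word vocabulary below are made up for illustration):

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a 1 x 6 batch of token ids and random logits over a
# tiny vocabulary. 103 plays the role of tokenizer.mask_token_id (an assumption).
mask_token_id = 103
tokens = torch.tensor([[101, 1996, 103, 2003, 4408, 102]])
logits = torch.randn(1, 6, 10)  # (batch, seq, vocab)

# Boolean-mask indexing keeps only the positions equal to mask_token_id,
# mirroring `logits[tokens == tokenizer.mask_token_id]` in the notebook.
mask_logits = logits[tokens == mask_token_id]   # one row per [MASK] position
logprobs = mask_logits.log_softmax(dim=-1)      # normalized over the vocab
predicted_id = logprobs.argmax(dim=-1).item()   # id of the most likely filler
```

In the real notebook the final step is `tokenizer.decode(predicted_id)`; here the id itself is the result, since there is no tokenizer in this sketch.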
@@ -267,13 +360,13 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "Better luck next time, BERT."
+   "Well done, BERT!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": ".venv",
+  "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
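The NSP cell in the diff prints "The sentences are sequential"; under the hood that string comes from a 2-way classification head. The decision step can be sketched with a hypothetical head output (the logit values and the index-0-means-sequential convention are assumptions for illustration, not taken from the diff):

```python
import torch

# Hypothetical 2-way NSP head output, shape (batch, 2):
# index 0 = "is the next sentence", index 1 = "is not".
nsp_logits = torch.tensor([[2.5, -1.0]])

# Softmax turns the logits into the probability that sentence B follows A,
# then argmax picks the label string to report.
probs = nsp_logits.softmax(dim=-1)
labels = ["The sentences are sequential", "The sentences are not sequential"]
prediction = labels[probs.argmax(dim=-1).item()]
print(prediction)  # prints "The sentences are sequential"
```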
@@ -287,7 +380,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.11.8"
+  "version": "3.10.15"
  },
  "orig_nbformat": 4
 },
