Skip to content

Commit 11edb28

Browse files
authored
Add mistral 7b support (#443)
* add mistral model name and alias
* add code for converting mistral config to hooked transformer config
* add function for converting mistral weights
* add GroupedQueryAttention
* add abstract base class for attention
* adapt keyvaluecache if grouped query attention is used
* fix fold_value_biases when using grouped query attention
* Add unit test for grouped query attention
* Add demo notebook for Mistral
* fix formatting
* add documentation for grouped query attention
* update lock file
* use Union instead of | for union types
* hardcode mistral config so building docs works with older versions of transformers
* don't set final_rms in Mistral config
* make Mistral-7b's alias name consistent with other models
* update main demo notebook
* require transformers>=4.34
* improve docstrings and clarify test name for grouped query attention
* remove Mistral demo
* fix docstring format
1 parent a5147ba commit 11edb28

File tree

9 files changed

+1398
-724
lines changed

9 files changed

+1398
-724
lines changed

demos/Main_Demo.ipynb

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
},
4646
{
4747
"cell_type": "code",
48-
"execution_count": 292,
48+
"execution_count": 11,
4949
"metadata": {},
5050
"outputs": [],
5151
"source": [
@@ -80,7 +80,7 @@
8080
},
8181
{
8282
"cell_type": "code",
83-
"execution_count": 293,
83+
"execution_count": 12,
8484
"metadata": {},
8585
"outputs": [
8686
{
@@ -103,32 +103,28 @@
103103
},
104104
{
105105
"cell_type": "code",
106-
"execution_count": 294,
106+
"execution_count": 13,
107107
"metadata": {},
108108
"outputs": [
109109
{
110110
"data": {
111111
"text/html": [
112-
"<div id=\"circuits-vis-7e4c8a75-1335\" style=\"margin: 15px 0;\"/>\n",
112+
"<div id=\"circuits-vis-1f2a8687-9cd7\" style=\"margin: 15px 0;\"/>\n",
113113
" <script crossorigin type=\"module\">\n",
114-
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.0/dist/cdn/esm.js\";\n",
114+
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.2/dist/cdn/esm.js\";\n",
115115
" render(\n",
116-
" \"circuits-vis-7e4c8a75-1335\",\n",
116+
" \"circuits-vis-1f2a8687-9cd7\",\n",
117117
" Hello,\n",
118118
" {\"name\": \"Neel\"}\n",
119119
" )\n",
120120
" </script>"
121121
],
122122
"text/plain": [
123-
"<circuitsvis.utils.render.RenderedHTML at 0xffff10cc9f10>"
123+
"<circuitsvis.utils.render.RenderedHTML at 0x7f21437f1c30>"
124124
]
125125
},
126-
"execution_count": 294,
127-
"metadata": {
128-
"text/html": {
129-
"Content-Type": "text/html"
130-
}
131-
},
126+
"execution_count": 13,
127+
"metadata": {},
132128
"output_type": "execute_result"
133129
}
134130
],
@@ -140,7 +136,7 @@
140136
},
141137
{
142138
"cell_type": "code",
143-
"execution_count": 295,
139+
"execution_count": 14,
144140
"metadata": {},
145141
"outputs": [],
146142
"source": [
@@ -158,7 +154,7 @@
158154
},
159155
{
160156
"cell_type": "code",
161-
"execution_count": 296,
157+
"execution_count": 15,
162158
"metadata": {},
163159
"outputs": [],
164160
"source": [
@@ -179,16 +175,16 @@
179175
},
180176
{
181177
"cell_type": "code",
182-
"execution_count": 297,
178+
"execution_count": 16,
183179
"metadata": {},
184180
"outputs": [
185181
{
186182
"data": {
187183
"text/plain": [
188-
"<torch.autograd.grad_mode.set_grad_enabled at 0xffff425948e0>"
184+
"<torch.autograd.grad_mode.set_grad_enabled at 0x7f213de735e0>"
189185
]
190186
},
191-
"execution_count": 297,
187+
"execution_count": 16,
192188
"metadata": {},
193189
"output_type": "execute_result"
194190
}
@@ -254,7 +250,7 @@
254250
},
255251
{
256252
"cell_type": "code",
257-
"execution_count": 299,
253+
"execution_count": 17,
258254
"metadata": {},
259255
"outputs": [],
260256
"source": [
@@ -263,7 +259,7 @@
263259
},
264260
{
265261
"cell_type": "code",
266-
"execution_count": 300,
262+
"execution_count": 18,
267263
"metadata": {},
268264
"outputs": [
269265
{
@@ -1210,21 +1206,21 @@
12101206
},
12111207
{
12121208
"cell_type": "code",
1213-
"execution_count": 314,
1209+
"execution_count": 19,
12141210
"metadata": {},
12151211
"outputs": [
12161212
{
12171213
"name": "stdout",
12181214
"output_type": "stream",
12191215
"text": [
12201216
"blocks.0.attn.W_Q torch.Size([12, 768, 64])\n",
1221-
"blocks.0.attn.W_K torch.Size([12, 768, 64])\n",
1222-
"blocks.0.attn.W_V torch.Size([12, 768, 64])\n",
12231217
"blocks.0.attn.W_O torch.Size([12, 64, 768])\n",
12241218
"blocks.0.attn.b_Q torch.Size([12, 64])\n",
1219+
"blocks.0.attn.b_O torch.Size([768])\n",
1220+
"blocks.0.attn.W_K torch.Size([12, 768, 64])\n",
1221+
"blocks.0.attn.W_V torch.Size([12, 768, 64])\n",
12251222
"blocks.0.attn.b_K torch.Size([12, 64])\n",
12261223
"blocks.0.attn.b_V torch.Size([12, 64])\n",
1227-
"blocks.0.attn.b_O torch.Size([768])\n",
12281224
"blocks.0.mlp.W_in torch.Size([768, 3072])\n",
12291225
"blocks.0.mlp.b_in torch.Size([3072])\n",
12301226
"blocks.0.mlp.W_out torch.Size([3072, 768])\n",
@@ -1247,7 +1243,7 @@
12471243
},
12481244
{
12491245
"cell_type": "code",
1250-
"execution_count": 315,
1246+
"execution_count": 20,
12511247
"metadata": {},
12521248
"outputs": [
12531249
{

0 commit comments

Comments
 (0)