ch05/07_gpt_to_llama/README.md
38 additions & 11 deletions
@@ -17,13 +17,16 @@ This folder contains code for converting the GPT implementation from chapter 4 a
For an easy way to use the Llama 3.2 1B and 3B models, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).

-#####1) Installation
+#### 1) Installation

```bash
pip install llms_from_scratch blobfile
```
+
+(Note that `blobfile` is needed to load the tokenizer.)
+
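For illustration only, tokenizer loading might look roughly like the following; the `Llama3Tokenizer` class name and the `tokenizer.model` path are assumptions rather than lines quoted from this README:

```python
# Illustrative sketch -- the class name and file path are assumptions
from llms_from_scratch.llama3 import Llama3Tokenizer

tokenizer = Llama3Tokenizer("tokenizer.model")  # the underlying tiktoken loader reads this file via blobfile
```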
-#####2) Model and text generation settings
+#### 2) Model and text generation settings

Specify which model to use:
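The settings block itself is collapsed in the hunk below; as a rough sketch only, it might look like this (every name except `TOP_K` is an illustrative assumption, not the file's exact contents):

```python
# Rough sketch of the settings -- names other than TOP_K are illustrative assumptions
MODEL_FILE = "llama3.2-1B-instruct.pth"   # assumed filename; pick the 1B/3B, base/instruct variant you want

PROMPT = "What do llamas eat?"            # example prompt
MAX_NEW_TOKENS = 150                      # assumed cap on generated tokens
TEMPERATURE = 0.0                         # greedy-like decoding together with TOP_K = 1
TOP_K = 1                                 # matches the context line visible in the hunk below
```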
@@ -51,7 +54,7 @@ TOP_K = 1
```

-#####3) Weight download and loading
+#### 3) Weight download and loading

This automatically downloads the weight file based on the model choice above:
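The download-and-load code itself is not reproduced in this excerpt; the following is only a minimal sketch of the idea, with a placeholder URL and the assumed filename carried over from the settings sketch above:

```python
import os
import urllib.request

import torch

MODEL_FILE = "llama3.2-1B-instruct.pth"  # assumed filename, as in the settings sketch above
WEIGHTS_URL = f"https://example.com/weights/{MODEL_FILE}"  # placeholder URL, not the actual host

# Download the checkpoint once and cache it locally
if not os.path.exists(MODEL_FILE):
    urllib.request.urlretrieve(WEIGHTS_URL, MODEL_FILE)

# Load the tensors; the resulting state dict is what gets loaded into the model
state_dict = torch.load(MODEL_FILE, weights_only=True)
print(f"Loaded {len(state_dict)} tensors from {MODEL_FILE}")
```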
When using the Llama 3.2 1B Instruct model, the output should look similar to the one shown below:
```
-Time: 4.12 sec
+Time: 3.17 sec
+50 tokens/sec
Max memory allocated: 2.91 GB
@@ -176,7 +182,22 @@ It's worth noting that the specific diet of llamas can vary depending on factors
```
-**Pro tip**
+#### Pro tip 1: speed up inference with FlashAttention
+
+Instead of using `Llama3Model`, you can use `Llama3ModelFast` as a drop-in replacement. For more information, I encourage you to inspect the [pkg/llms_from_scratch/llama3.py](../../pkg/llms_from_scratch/llama3.py) code.
+
+The `Llama3ModelFast` class replaces my from-scratch scaled dot-product code in the `GroupedQueryAttention` module with PyTorch's `scaled_dot_product_attention` function, which uses `FlashAttention` on Ampere GPUs or newer.
+
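Since it is a drop-in replacement, the swap is essentially a one-line change; in the sketch below, `model_config` stands in for whatever configuration dict you already pass to `Llama3Model`:

```python
from llms_from_scratch.llama3 import Llama3ModelFast  # imported instead of Llama3Model

# model = Llama3Model(model_config)      # from-scratch scaled dot-product attention
model = Llama3ModelFast(model_config)    # PyTorch SDPA / FlashAttention-backed attention
```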
+The following table shows a performance comparison on an A100:
+
+|                 | Tokens/sec | Memory  |
+| --------------- | ---------- | ------- |
+| Llama3Model     | 50         | 2.91 GB |
+| Llama3ModelFast | 58         | 2.85 GB |
+
+#### Pro tip 2: speed up inference with compilation
+
For up to a 4× speed-up, replace
@@ -191,5 +212,11 @@ model = torch.compile(model)
model.to(device)
```
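Concretely, the change boils down to wrapping the model with `torch.compile` before moving it to the device; the snippet below uses a tiny stand-in module so it runs on its own, whereas in practice you would wrap the Llama 3 model instance:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 8)  # stand-in for the Llama 3 model instance

# Before: model.to(device)
# After: compile first, then move the compiled module to the device
model = torch.compile(model)
model.to(device)
```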

-Note: the speed-up takes effect after the first `generate` call.
+Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first `generate` call.
+
+The following table shows a performance comparison on an A100 for subsequent `generate` calls: