
Commit 47c0360

Llama3 from scratch improvements (#621)
* Llama3 from scratch improvements
* restore
1 parent 1b242d0 commit 47c0360


ch05/07_gpt_to_llama/README.md

Lines changed: 38 additions & 11 deletions
@@ -17,13 +17,16 @@ This folder contains code for converting the GPT implementation from chapter 4 a
For an easy way to use the Llama 3.2 1B and 3B models, you can also use the `llms-from-scratch` PyPI package based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).


-##### 1) Installation
+#### 1) Installation

```bash
pip install llms_from_scratch blobfile
```
+
+(Note that `blobfile` is needed to load the tokenizer.)
+

-##### 2) Model and text generation settings
+#### 2) Model and text generation settings

Specify which model to use:

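The settings block itself is unchanged by this commit, so the diff only shows its heading. As a rough, illustrative sketch of what such a block can look like (placeholder names and values, not the README's actual defaults, apart from `TOP_K = 1`, which appears in the next hunk's context):

```python
# Illustrative placeholders only -- see the README for the real options
MODEL_FILE = "llama3.2-1B-instruct.pth"  # hypothetical file name; "instruct" in the name matters for the tokenizer step
MODEL_CONTEXT_LENGTH = 8192              # later used to shrink LLAMA32_CONFIG["context_length"]
TEMPERATURE = 0.0                        # placeholder sampling temperature
TOP_K = 1                                # matches the value visible in the next hunk's context
```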
@@ -51,7 +54,7 @@ TOP_K = 1
```


-##### 3) Weight download and loading
+#### 3) Weight download and loading

This automatically downloads the weight file based on the model choice above:

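The download code itself is outside the changed lines. If a stand-in helps for orientation, a hedged sketch using `huggingface_hub` (an assumption; the README may use a different mechanism, and the repo id and file name below are placeholders) could look like:

```python
from huggingface_hub import hf_hub_download  # assumed helper, not confirmed by this diff

weights_path = hf_hub_download(
    repo_id="some-user/llama3.2-weights",  # hypothetical repo id
    filename="llama3.2-1B-instruct.pth",   # hypothetical file name (MODEL_FILE)
    local_dir=".",
)
```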
@@ -82,7 +85,7 @@ else:
    LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH

model = Llama3Model(LLAMA32_CONFIG)
-model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))
+model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu"))

device = (
    torch.device("cuda") if torch.cuda.is_available() else
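A brief aside on the `map_location="cpu"` addition above (this rationale is not stated in the commit): loading the checkpoint onto the CPU first means the call also works on machines without a GPU, even if the weights were saved from a CUDA device, and the tensors are only moved once the existing `model.to(device)` call runs. A minimal sketch with a placeholder path:

```python
import torch

ckpt_path = "model.pth"  # placeholder; the README derives this from MODEL_FILE

# Load all tensors onto the CPU regardless of the device they were saved on;
# the model is moved to the target device afterwards, in one step.
state_dict = torch.load(ckpt_path, weights_only=True, map_location="cpu")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.load_state_dict(state_dict) and model.to(device) then follow, as in the diff
```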
@@ -93,7 +96,7 @@ model.to(device)
```


-##### 4) Initialize tokenizer
+#### 4) Initialize tokenizer

The following code downloads and initializes the tokenizer:

@@ -115,14 +118,14 @@ if "instruct" in MODEL_FILE:
```


-##### 5) Generating text
+#### 5) Generating text

Lastly, we can generate text via the following code:

```python
import time

-from llms_from_scratch.ch05 import (
+from ch05 import (
    generate,
    text_to_token_ids,
    token_ids_to_text
@@ -141,7 +144,9 @@ token_ids = generate(
    temperature=TEMPERATURE
)

-print(f"Time: {time.time() - start:.2f} sec")
+total_time = time.time() - start
+print(f"Time: {total_time:.2f} sec")
+print(f"{int(len(token_ids[0])/total_time)} tokens/sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
@@ -159,7 +164,8 @@ print("\n\nOutput text:\n\n", output_text)
When using the Llama 3.2 1B Instruct model, the output should look similar to the one shown below:

```
-Time: 4.12 sec
+Time: 3.17 sec
+50 tokens/sec
Max memory allocated: 2.91 GB

@@ -176,7 +182,22 @@ It's worth noting that the specific diet of llamas can vary depending on factors
```


-**Pro tip**
+#### Pro tip 1: speed up inference with FlashAttention
+
+Instead of using `Llama3Model`, you can use `Llama3ModelFast` as a drop-in replacement. For more information, I encourage you to inspect the [pkg/llms_from_scratch/llama3.py](../../pkg/llms_from_scratch/llama3.py) code.
+
+The `Llama3ModelFast` replaces my from-scratch scaled dot-product code in the `GroupedQueryAttention` module with PyTorch's `scaled_dot_product_attention` function, which uses FlashAttention on Ampere GPUs or newer.
+
+The following table shows a performance comparison on an A100:
+
+|                 | Tokens/sec | Memory  |
+| --------------- | ---------- | ------- |
+| Llama3Model     | 50         | 2.91 GB |
+| Llama3ModelFast | 58         | 2.85 GB |
+
+
+#### Pro tip 2: speed up inference with compilation
+

For up to a 4× speed-up, replace

@@ -191,5 +212,11 @@ model = torch.compile(model)
model.to(device)
```

-Note: the speed-up takes effect after the first `generate` call.
+Note: There is a significant multi-minute upfront cost when compiling, and the speed-up takes effect after the first `generate` call.
+
+The following table shows a performance comparison on an A100 for subsequent `generate` calls:

+|                 | Tokens/sec | Memory  |
+| --------------- | ---------- | ------- |
+| Llama3Model     | 156        | 3.12 GB |
+| Llama3ModelFast | 159        | 2.84 GB |
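For readers curious about Pro tip 1 above: the diff does not show how `Llama3ModelFast` calls into PyTorch, so here is a minimal, hypothetical grouped-query attention sketch built on `torch.nn.functional.scaled_dot_product_attention`. It omits RoPE, KV caching, and the other details of the actual `GroupedQueryAttention` module in `pkg/llms_from_scratch/llama3.py`, and the class name is made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQuerySDPA(nn.Module):
    """Hypothetical minimal GQA block that delegates attention to
    F.scaled_dot_product_attention (FlashAttention-capable on supported GPUs)."""
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Expand the shared K/V heads so every query head has a matching K/V head
        rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
        # Fused attention; dispatches to FlashAttention kernels when available
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(out.transpose(1, 2).reshape(b, t, -1))

# Quick shape check
x = torch.randn(1, 8, 64)
print(GroupedQuerySDPA(d_model=64, n_heads=8, n_kv_heads=2)(x).shape)  # torch.Size([1, 8, 64])
```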
