
Commit 9bbd859

authored
Add the doc of mlp block (#13)
* add the doc of mlp Signed-off-by: Connor1996 <zbk602423539@gmail.com>
* update index Signed-off-by: Connor1996 <zbk602423539@gmail.com>
---------
Signed-off-by: Connor1996 <zbk602423539@gmail.com>
1 parent e4aba84 commit 9bbd859

File tree: 9 files changed, +136 -15 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ You may join skyzh's Discord server and study with the tiny-llm community.
 | 1.1 | Attention ||||
 | 1.2 | RoPE ||||
 | 1.3 | Grouped Query Attention ||||
-| 1.4 | RMSNorm and MLP || 🚧 | 🚧 |
+| 1.4 | RMSNorm and MLP || | |
 | 1.5 | Transformer Block || 🚧 | 🚧 |
 | 1.6 | Load the Model || 🚧 | 🚧 |
 | 1.7 | Generate Responses (aka Decoding) ||| 🚧 |

book/src/glossary.md

Lines changed: 5 additions & 0 deletions
@@ -4,5 +4,10 @@
 - [Multi Head Attention](./week1-01-attention.md)
 - [Linear](./week1-01-attention.md)
 - [Rotary Positional Encoding](./week1-02-positional-encodings.md)
+- [Grouped Query Attention](./week1-03-gqa.md)
+- [RMSNorm](./week1-04-rmsnorm-and-mlp.md)
+- [SiLU](./week1-04-rmsnorm-and-mlp.md)
+- [SwiGLU](./week1-04-rmsnorm-and-mlp.md)
+- [MLP](./week1-04-rmsnorm-and-mlp.md)
 
 {{#include copyright.md}}

book/src/week1-01-attention.md

Lines changed: 3 additions & 3 deletions
@@ -100,7 +100,7 @@ src/tiny_llm/attention.py
 Implement `MultiHeadAttention`. The layer takes a batch of vectors, maps it through the K, V, Q weight matrices, and uses the attention function we implemented in task 1 to compute the result. The output needs to be mapped using the O
 weight matrix.
 
-You will also need to implement the `linear` function first. For `linear`, it takes a tensor of the shape `N.. x I`, a weight matrix of the shape `O x I`, and a bias vector of the shape `O`. The output is of the shape `N.. x O`. `I` is the input dimension and `O` is the output dimension.
+You will also need to implement the `linear` function in `basics.py` first. For `linear`, it takes a tensor of the shape `N.. x I`, a weight matrix of the shape `O x I`, and a bias vector of the shape `O`. The output is of the shape `N.. x O`. `I` is the input dimension and `O` is the output dimension.
 
 For the `MultiHeadAttention` layer, the input tensors `query`, `key`, `value` have the shape `N x L x E`, where `E` is the dimension of the
 embedding for a given token in the sequence. The `K/Q/V` weight matrices will map the tensor into key, value, and query
@@ -123,9 +123,9 @@ H is num_heads
 D is head_dim
 L is seq_len, in PyTorch API it's S (source len)
 
-W_q/W_k/W_v: E x (H x D)
+w_q/w_k/w_v: E x (H x D)
 output/input: N x L x E
-W_o: (H x D) x E
+w_o: (H x D) x E
 ```
 
 At the end of the day, you should be able to pass the following tests:
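
For illustration, here is a minimal sketch of the `linear` contract described in the hunk above, assuming `mlx.core` as the array library (the same one the tests in this commit use); the actual `basics.py` implementation may differ:

```python
from typing import Optional

import mlx.core as mx


def linear(x: mx.array, w: mx.array, bias: Optional[mx.array] = None) -> mx.array:
    # x: N.. x I, w: O x I, bias: O  ->  output: N.. x O
    y = x @ w.T  # contract the trailing I dimension of x with the I dimension of w
    if bias is not None:
        y = y + bias
    return y
```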

book/src/week1-04-rmsnorm-and-mlp.md

Lines changed: 60 additions & 2 deletions
@@ -12,7 +12,7 @@ In day 4, we will implement two crucial components of the Qwen2 Transformer arch
 
 ## Task 1: Implement `RMSNorm`
 
-You will need to implement the `RMSNorm` layer in:
+In this task, we will implement the `RMSNorm` layer.
 
 ```
 src/tiny_llm/layer_norm.py
@@ -55,6 +55,64 @@ pdm run test -k week_1_day_4_task_1 -v
 
 ## Task 2: Implement the MLP Block
 
-TBD...
+In this task, we will implement the MLP block named `Qwen2MLP`.
+
+```
+src/tiny_llm/qwen2_week1.py
+```
+
+The original Transformer model utilized a simple Feed-Forward Network (FFN) within each block. This FFN typically consisted of two linear transformations with a ReLU activation in between, applied position-wise.
+
+Modern Transformer architectures, including Qwen2, often employ more advanced FFN variants for improved performance. Qwen2 uses a specific type of Gated Linear Unit (GLU) called SwiGLU.
+
+**📚 Readings**
+* [Attention is All You Need (Transformer Paper, Section 3.3 "Position-wise Feed-Forward Networks")](https://arxiv.org/abs/1706.03762)
+* [GLU Paper (Language Modeling with Gated Convolutional Networks)](https://arxiv.org/pdf/1612.08083)
+* [SiLU (Swish) activation function](https://arxiv.org/pdf/1710.05941)
+* [SwiGLU Paper (GLU Variants Improve Transformer)](https://arxiv.org/abs/2002.05202v1)
+* [PyTorch SiLU documentation](https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html)
+* [Qwen2 layers implementation in mlx-lm (includes MLP)](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py)
+
+Essentially, SwiGLU is a combination of GLU and the SiLU (Sigmoid Linear Unit) activation function:
+- GLU is a gating mechanism that allows the model to learn which parts of the input to focus on. It typically involves an element-wise product of two linear projections of the input, one of which might be passed through an activation function. Compared to the ReLU used in the original FFN, GLU can help the model learn more complex relationships in the data, deciding which features to keep and which to discard.
+- SiLU (Sigmoid Linear Unit) is a smooth, non-monotonic activation function that has been shown to perform well in various deep learning tasks. Compared to the ReLU and sigmoid used in GLU, it is fully differentiable without zero-gradient "dead zones" and retains a non-zero output even for negative inputs.
+
+You need to implement the `silu` function in `basics.py` first. For `silu`, it takes a tensor of the shape `N.. x I` and returns a tensor of the same shape.
+The `silu` function is defined as:
+$$
+\text{SiLU}(x) = x * \text{sigmoid}(x) = \frac{x}{1 + e^{-x}}
+$$
+
+
+Then implement `Qwen2MLP`. The structure for Qwen2's MLP block is:
+* A gate linear projection ($W_{gate}$).
+* An up linear projection ($W_{up}$).
+* A SiLU activation function applied to the output of $W_{gate}$.
+* An element-wise multiplication of the SiLU-activated $W_{gate}$ output and the $W_{up}$ output. This forms the "gated" part.
+* A final down linear projection ($W_{down}$).
+
+This can be expressed as:
+$$
+\text{MLP}(x) = W_{down}(\text{SiLU}(W_{gate}(x)) \odot W_{up}(x))
+$$
+where $\odot$ denotes element-wise multiplication. All linear projections in Qwen2's MLP are typically implemented without bias.
+
+```
+N.. is zero or more dimensions for batches
+E is hidden_size (embedding dimension of the model)
+I is intermediate_size (dimension of the hidden layer in MLP)
+L is the sequence length
+
+input: N.. x L x E
+w_gate: I x E
+w_up: I x E
+w_down: E x I
+output: N.. x L x E
+```
+
+You can test your implementation by running:
+```bash
+pdm run test -k week_1_day_4_task_2 -v
+```
 
 {{#include copyright.md}}
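
To make the `silu` definition in the doc above concrete, here is a minimal sketch, assuming `mlx.core` as the array library (as used by the tests in this commit); the reference `basics.py` implementation may differ:

```python
import mlx.core as mx


def silu(x: mx.array) -> mx.array:
    # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)); the input shape N.. x I is preserved
    return x * mx.sigmoid(x)
```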

src/tiny_llm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,6 @@
 from .layer_norm import *
 from .positional_encoding import *
 from .quantize import *
-from .qwen2_week1 import Qwen2ModelWeek1
+from .qwen2_week1 import *
 from .generate import *
 from .qwen2_week2 import Qwen2ModelWeek2

src/tiny_llm_ref/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,5 +6,5 @@
 from .quantize import *
 from .generate import *
 from .kv_cache import *
-from .qwen2_week1 import Qwen2ModelWeek1
+from .qwen2_week1 import *
 from .qwen2_week2 import Qwen2ModelWeek2

tests/test_layer_norm.py

Lines changed: 3 additions & 1 deletion
@@ -9,7 +9,9 @@
 @pytest.mark.parametrize("target", ["torch", "mlx"])
 @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
 @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
-def test_rms_norm_week_1_day_4_task_1(stream: mx.Stream, precision: np.dtype, target: str):
+def test_rms_norm_week_1_day_4_task_1(
+    stream: mx.Stream, precision: np.dtype, target: str
+):
     SIZE = 100
     SIZE_Y = 111
     with mx.stream(stream):
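
For context on what this test exercises, here is a minimal sketch of an RMSNorm layer: normalize by the root mean square over the last dimension, then scale by a learned weight. The constructor signature and the float32 accumulation are assumptions for illustration only; the layer in `src/tiny_llm/layer_norm.py` may differ:

```python
import mlx.core as mx


class RMSNorm:
    def __init__(self, weight: mx.array, eps: float = 1e-6):
        self.weight = weight  # one scale per feature in the last dimension
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        # y = x / sqrt(mean(x^2, axis=-1) + eps) * weight
        orig_dtype = x.dtype
        xf = x.astype(mx.float32)  # accumulate in float32 for numerical stability
        inv_rms = mx.rsqrt(mx.mean(mx.square(xf), axis=-1, keepdims=True) + self.eps)
        return (xf * inv_rms).astype(orig_dtype) * self.weight
```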

tests/test_qwen2_mlp.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import mlx.core as mx
+import pytest
+from mlx_lm.models import qwen2
+import numpy as np
+
+from .tiny_llm_base import *
+from .utils import *
+
+# Define different dimension parameters for testing
+DIM_PARAMS = [
+    {"batch_size": 1, "seq_len": 5, "dim": 4, "hidden_dim": 8, "id": "small_dims"},
+    {"batch_size": 2, "seq_len": 16, "dim": 32, "hidden_dim": 64, "id": "large_dims"},
+    {
+        "batch_size": 1,
+        "seq_len": 1,
+        "dim": 128,
+        "hidden_dim": 256,
+        "id": "single_token",
+    },
+]
+DIM_PARAMS_IDS = [d["id"] for d in DIM_PARAMS]
+
+
+@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
+@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
+@pytest.mark.parametrize("dims", DIM_PARAMS, ids=DIM_PARAMS_IDS)
+def test_qwen2_mlp_week_1_day_4_task_2(
+    stream: mx.Stream, precision: np.dtype, dims: dict
+):
+    BATCH_SIZE, SEQ_LEN, DIM, HIDDEN_DIM = (
+        dims["batch_size"],
+        dims["seq_len"],
+        dims["dim"],
+        dims["hidden_dim"],
+    )
+
+    with mx.stream(stream):
+        mx_precision = np_type_to_mx_type(precision)
+        x = mx.random.uniform(shape=(BATCH_SIZE, SEQ_LEN, DIM)).astype(mx_precision)
+        w_gate = mx.random.uniform(shape=(HIDDEN_DIM, DIM)).astype(mx_precision)
+        w_up = mx.random.uniform(shape=(HIDDEN_DIM, DIM)).astype(mx_precision)
+        w_down = mx.random.uniform(shape=(DIM, HIDDEN_DIM)).astype(mx_precision)
+
+        user_mlp = Qwen2MLP(
+            dim=DIM, hidden_dim=HIDDEN_DIM, w_gate=w_gate, w_up=w_up, w_down=w_down
+        )
+        user_output = user_mlp(x)
+
+        reference_mlp = qwen2.MLP(dim=DIM, hidden_dim=HIDDEN_DIM)
+        reference_mlp.gate_proj.weight = w_gate
+        reference_mlp.up_proj.weight = w_up
+        reference_mlp.down_proj.weight = w_down
+        reference_output = reference_mlp(x)
+
+        assert_allclose(user_output, reference_output, precision)
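
The test above fixes the `Qwen2MLP` interface (`dim`, `hidden_dim`, `w_gate`, `w_up`, `w_down`, no biases). Here is a minimal sketch consistent with that interface and with the SwiGLU formula from the doc; the reference implementation in `src/tiny_llm/qwen2_week1.py` may structure this differently, for example by reusing the `linear` and `silu` helpers from `basics.py`:

```python
import mlx.core as mx


class Qwen2MLP:
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        w_gate: mx.array,  # hidden_dim x dim (I x E)
        w_up: mx.array,    # hidden_dim x dim (I x E)
        w_down: mx.array,  # dim x hidden_dim (E x I)
    ):
        self.w_gate = w_gate
        self.w_up = w_up
        self.w_down = w_down

    def __call__(self, x: mx.array) -> mx.array:
        # MLP(x) = W_down(SiLU(W_gate(x)) * W_up(x)), element-wise product, no biases
        gate = x @ self.w_gate.T              # N.. x L x I
        up = x @ self.w_up.T                  # N.. x L x I
        gated = gate * mx.sigmoid(gate) * up  # SiLU(gate) gated by up
        return gated @ self.w_down.T          # back to N.. x L x E
```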

tests/utils.py

Lines changed: 7 additions & 6 deletions
@@ -31,12 +31,13 @@ def assert_allclose(
     atol = atol or 1.0e-5
     assert a.shape == b.shape, f"shape mismatch: {a.shape} vs {b.shape}"
     if not np.allclose(a, b, rtol=rtol, atol=atol):
-        print("a=", a)
-        print("b=", b)
-        diff = np.invert(np.isclose(a, b, rtol=rtol, atol=atol))
-        print("diff_a=", a * diff)
-        print("diff_b=", b * diff)
-        assert False, f"result mismatch"
+        with np.printoptions(precision=3, suppress=True):
+            print("a=", a)
+            print("b=", b)
+            diff = np.invert(np.isclose(a, b, rtol=rtol, atol=atol))
+            print("diff_a=", a * diff)
+            print("diff_b=", b * diff)
+            assert False, f"result mismatch"
 
 
 def np_type_to_mx_type(np_type: np.dtype) -> mx.Dtype:

0 commit comments
