
Commit 0c1ae4b

fix attention mask; chunked prefill working now
Signed-off-by: Alex Chi <iskyzh@gmail.com>
1 parent 7139856 · commit 0c1ae4b

6 files changed: +33 −9 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ You may join skyzh's Discord server and study with the tiny-llm community.
 | 2.4 | Flash Attention 2 - CPU | ✅ | 🚧 | 🚧 |
 | 2.5 | Flash Attention 2 - GPU | ✅ | 🚧 | 🚧 |
 | 2.6 | Continuous Batching | ✅ | 🚧 | 🚧 |
-| 2.7 | Chunked Prefill | 🚧 | 🚧 | 🚧 |
+| 2.7 | Chunked Prefill | ✅ | 🚧 | 🚧 |
 | 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 |
 | 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 |
 | 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 |

batch-main.py

Lines changed: 2 additions & 0 deletions

@@ -20,6 +20,8 @@
 Shanghai[a] is a direct-administered municipality and the most populous urban area in China. The city is located on the Chinese shoreline on the southern estuary of the Yangtze River, with the Huangpu River flowing through it. The population of the city proper is the second largest in the world after Chongqing, with around 24.87 million inhabitants in 2023, while the urban area is the most populous in China, with 29.87 million residents. As of 2022, the Greater Shanghai metropolitan area was estimated to produce a gross metropolitan product (nominal) of nearly 13 trillion RMB ($1.9 trillion).[13] Shanghai is one of the world's major centers for finance, business and economics, research, science and technology, manufacturing, transportation, tourism, and culture. The Port of Shanghai is the world's busiest container port.
 """
 
+shanghai_wikipedia += "Based on the previous information, "
+
 prompts = [
     shanghai_wikipedia + "Where is Shanghai?",
     shanghai_wikipedia + "How much is the population of Shanghai?",

book/src/week2-overview.md

Lines changed: 1 addition & 0 deletions

@@ -20,3 +20,4 @@ https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/sdpa_vector
 
 attention mask why
 https://www.shashankshekhar.com/blog/apple-metal-vs-nvidia-cuda
+https://arxiv.org/pdf/2308.16369

src/tiny_llm_ref/attention.py

Lines changed: 2 additions & 2 deletions

@@ -40,8 +40,8 @@ def scaled_dot_product_attention_grouped(
     scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor
     if mask is not None:
         if mask == "causal":
-            mask = 1 - mx.tril(mx.ones((L, S)))
-            mask = mx.where(mask, mx.array(-mx.inf), mx.array(0))
+            mask = mx.tril(mx.ones((L, S)), k=S - L)
+            mask = mx.where(mask, mx.array(0), mx.array(-mx.inf))
             scores = scores + mask
         else:
             mask = mask.reshape(-1, H, n_repeats, mask.shape[-2], mask.shape[-1])
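
Why this fixes chunked prefill: the old construction `1 - mx.tril(mx.ones((L, S)))` anchors the causal diagonal at the top-left corner, which is only correct when the L queries start at key position 0, i.e. when L == S. During chunked prefill, a chunk of L new queries attends to S ≥ L keys, and the first S − L keys belong to the already-cached prefix and must all be visible, so the allowed region's diagonal shifts right by `k = S - L`. A standalone sketch of the corrected construction (only `mlx.core` is assumed; the helper name is illustrative, not part of the repo):

import mlx.core as mx

def causal_mask(L: int, S: int) -> mx.array:
    # Query i in the chunk sits at absolute position (S - L) + i, so it may
    # attend to key positions 0 through (S - L) + i inclusive.
    allowed = mx.tril(mx.ones((L, S)), k=S - L)
    return mx.where(allowed, mx.array(0.0), mx.array(-mx.inf))

# L == S gives the usual square causal mask; L < S (a later prefill chunk)
# additionally exposes all previously cached keys:
print(causal_mask(2, 4))
# [[0, 0, 0, -inf],
#  [0, 0, 0, 0]]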

src/tiny_llm_ref/generate.py

Lines changed: 23 additions & 5 deletions

@@ -47,6 +47,16 @@ def _step(model, y, offset, kv_cache):
     # prefill with the prompt
     tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False))
     offset = 0
+    prefill_max = 64
+    total_tokens = tokens.size
+    while tokens.size > prefill_max:
+        token, _ = _step(model, tokens[:prefill_max], offset, kv_cache)
+        for i in kv_cache:
+            mx.eval(i.key_values[0])
+            mx.eval(i.key_values[1])
+        offset += prefill_max
+        tokens = tokens[prefill_max:]
+        print(f"Prefill progress: {offset}/{total_tokens}", flush=True)
     detokenizer = tokenizer.detokenizer
     detokenizer.reset()
     # generate/decode
@@ -72,7 +82,7 @@ def _step(model, y, offsets, kv_cache):
 
 class PrefillRequest:
     def __init__(
-        self, model: any, tokenizer: TokenizerWrapper, prompt: str, max_step: int = 64
+        self, model: any, tokenizer: TokenizerWrapper, prompt: str, max_step: int = 16
     ):
         self.prompt = prompt
         self.kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)]
@@ -92,16 +102,19 @@ def prefill(self):
             [self.offset],
             self.kv_cache,
         )
-        mx.eval(token)
         self.offset += tokens_to_prefill
+        for i in self.kv_cache:
+            mx.eval(i.key_values[0])
+            mx.eval(i.key_values[1])
         if self.offset == self.prefill_tokens.size:
+            mx.eval(token)
             return token, self.kv_cache, self.offset
         else:
             return None
 
 
 def batch_generate(
-    model: any, tokenizer: TokenizerWrapper, prompts: list[str], max_seq_len=64
+    model: any, tokenizer: TokenizerWrapper, prompts: list[str], max_seq_len=512
 ):
     MAX_REQUESTS = 5
     is_idle = [True] * MAX_REQUESTS
@@ -110,7 +123,7 @@ def batch_generate(
     offsets = mx.array([0] * MAX_REQUESTS)
     detokenizers = [None] * MAX_REQUESTS
     kv_cache = [
-        BatchingKvCache(max_active_requests=MAX_REQUESTS, max_seq_len=64)
+        BatchingKvCache(max_active_requests=MAX_REQUESTS, max_seq_len=max_seq_len)
         for _ in range(model.num_hidden_layers)
     ]
     result = []
@@ -166,7 +179,10 @@ def batch_generate(
                 offsets[i] = offset
                 break
             else:
-                print("Still prefilling the request", flush=True)
+                print(
+                    f"Still prefilling the request: {pending_prefill_requests.offset}/{pending_prefill_requests.prefill_tokens.size}",
+                    flush=True,
+                )
 
         if not all(is_idle):
             next_tokens = mx.array(next_tokens)
@@ -194,4 +210,6 @@ def batch_generate(
                     f"(In Progress) {prompt_idx[i]}: " + detokenizers[i].text,
                     flush=True,
                 )
+        else:
+            print("No requests in progress", flush=True)
     return result
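
Two things are going on in this file. First, the single-request path now prefills the prompt in fixed `prefill_max`-sized chunks, advancing `offset` so the RoPE positions and KV cache stay consistent, and it forces evaluation of every layer's cached keys and values after each chunk: MLX builds computation graphs lazily, so without the `mx.eval` calls the unevaluated graph would span the entire prompt. For the same reason, `mx.eval(token)` in `PrefillRequest.prefill` moves inside the completion branch, since the tokens predicted after intermediate chunks are discarded and only the final chunk's prediction needs materializing. A minimal standalone sketch of the loop, assuming a `model_step(chunk, offset, kv_cache)` callable that appends to the per-layer caches and returns the predicted next token (these names are illustrative, not the repo's exact API):

import mlx.core as mx

def chunked_prefill(model_step, tokens: mx.array, kv_cache, chunk_size: int = 64):
    # Feed the prompt through the model chunk by chunk; return the token
    # predicted after the final chunk, which seeds decoding.
    offset = 0
    token = None
    while offset < tokens.size:
        chunk = tokens[offset : offset + chunk_size]
        token = model_step(chunk, offset, kv_cache)
        # MLX evaluates lazily: materialize the new KV entries now so the
        # unevaluated graph does not accumulate across the whole prompt.
        for layer in kv_cache:
            mx.eval(layer.key_values[0], layer.key_values[1])
        offset += chunk.size
        print(f"Prefill progress: {offset}/{tokens.size}", flush=True)
    mx.eval(token)  # only the last chunk's prediction is actually used
    return token

Second, in the batching path `max_step` drops from 64 to 16, so a long prefill yields to the decode steps of other in-flight requests more often, trading some prefill throughput for lower inter-token latency; this is the scheduling argument of the chunked-prefill paper (arXiv 2308.16369) linked from the book notes above.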

src/tiny_llm_ref/qwen2_week2.py

Lines changed: 4 additions & 1 deletion

@@ -62,7 +62,10 @@ def __call__(
         projection_v = quantized_linear(x, self.wv, bias=self.bv).reshape(
             B, L, self.num_kv_heads, self.head_dim
         )
-        offset_slice = [slice(int(i), int(i + L)) for i in offsets]
+        if isinstance(offsets, int):
+            offset_slice = [slice(int(offsets), int(offsets + L))]
+        else:
+            offset_slice = [slice(int(i), int(i + L)) for i in offsets]
         projection_q = self.rope(projection_q, offset=offset_slice)
         projection_k = self.rope(projection_k, offset=offset_slice)
         projection_q = projection_q.transpose(0, 2, 1, 3)
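
The RoPE offset handling now accepts both a plain `int` (the single-request chunked prefill path passes one offset) and a per-request list (the batching path), normalizing both into one position slice per sequence. A sketch of just that normalization, with an illustrative helper name not taken from the repo:

def to_offset_slices(offsets: int | list[int], L: int) -> list[slice]:
    # Each sequence's L new tokens occupy absolute positions
    # [offset, offset + L) for rotary position embeddings.
    if isinstance(offsets, int):
        offsets = [offsets]
    return [slice(int(i), int(i + L)) for i in offsets]

assert to_offset_slices(64, 16) == [slice(64, 80)]
assert to_offset_slices([0, 64], 4) == [slice(0, 4), slice(64, 68)]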
