
Commit 828a47c

[turbine-llm] Add paged KV cache and llama variants to use it. (#486)
This is still just a reference implementation that linearizes in and out of the cache. I expect that the compiler will do some of this well enough with fusion that it will be fine, but we can replace hot areas with custom ops.
1 parent a616229 commit 828a47c
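
As a rough mental model of what "linearizes in and out of the cache" means here, the sketch below gathers a sequence's cache blocks into one contiguous tensor before use and scatters it back afterwards. It is purely illustrative: the names, shapes, and helper functions are assumptions made for this note, not the turbine_llm API added by this commit.

# Illustrative paged KV cache sketch (hypothetical names; not the committed code).
import torch

block_seq_stride = 16   # tokens held by each physical cache block (assumed)
head_dim = 8            # width of a K or V row (assumed)
num_blocks = 128        # size of the shared block pool, as in model.cache.allocate(128, ...)

# One flat pool of blocks; each sequence owns a small table of physical block ids.
cache_pool = torch.zeros(num_blocks, block_seq_stride, head_dim)
seq_block_ids = torch.tensor([127, 126, 2, 3])  # logical block i -> physical block id

def read_linear(pool, block_ids):
    # "Linearize out": gather the owned blocks into a contiguous [seq_len, head_dim] view.
    return pool[block_ids].reshape(-1, pool.shape[-1])

def write_linear(pool, block_ids, linear):
    # "Linearize in": scatter the contiguous tensor back into the owning blocks.
    pool[block_ids] = linear.reshape(len(block_ids), pool.shape[1], pool.shape[2])

k_linear = read_linear(cache_pool, seq_block_ids)   # [64, 8], ready for plain attention
k_linear[:12] = 1.0                                  # pretend prefill produced 12 new rows
write_linear(cache_pool, seq_block_ids, k_linear)    # persist them back into the pages

The gather/scatter pair is the part the commit message expects fusion to handle well enough for now, with custom ops reserved for the hot paths.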

13 files changed: +1425 −98 lines


llm/scripts/validate_llama_model.py renamed to llm/scripts/validate_llama_ref_model.py

Lines changed: 9 additions & 1 deletion
@@ -12,9 +12,11 @@
 
 import sys
 
+import torch
+
 from turbine_llm.config import *
 from turbine_llm.data import *
-from turbine_llm.models.llama import *
+from turbine_llm.models.llama_ref import *
 
 
 def main(args: list[str]):
@@ -32,6 +34,12 @@ def main(args: list[str]):
     )
     print(f" : tokens = {tokens}")
 
+    # Decode a step.
+    print("Decoding...")
+    print(tokens.shape, tokens)
+    decode_token = model.forward(tokens, start_index=12, local_kv_cache=kv_cache)
+    print(f" : decode tokens = {decode_token}")
+
 
 if __name__ == "__main__":
     sys.exit(main(sys.argv[1:]))
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+
+import torch
+
+from turbine_llm.config import *
+from turbine_llm.data import *
+from turbine_llm.models.llama import *
+
+
+def main(args: list[str]):
+    torch.no_grad().__enter__()
+    config = load_gguf_file(args[0])
+    hp = LlamaHParams.from_gguf_props(config.properties)
+    model = PagedLlamaModelV1(config.root_theta, hp)
+    cache_state = model.cache.allocate(128, torch.float32)
+    start_index = 0
+    next_batch = torch.tensor(
+        [
+            [
+                1,
+                1059,
+                31871,
+                1217,
+                322,
+                266,
+                3682,
+                6075,
+                31902,
+                13,
+                31849,
+                31871,
+                0,
+                0,
+                0,
+                0,
+            ]
+            + 48 * [0],
+            [
+                1,
+                1059,
+                31871,
+                1217,
+                322,
+                31871,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+            ]
+            + 48 * [0],
+            64 * [0],
+            64 * [0],
+        ]
+    )
+    assert next_batch.shape[1] % model.cache.block_seq_stride == 0
+    seq_block_ids = torch.tensor(
+        [
+            [127, 0, 0, 0],
+            [126, 0, 0, 0],
+            [0, 0, 0, 0],
+            [0, 0, 0, 0],
+        ]
+    )
+
+    # Important: Do not use a sequence length of 0 for empty batch slots
+    # as it will cause softmax to nan due to a mask of all -inf. This then
+    # propagates and causes badness.
+    seq_lens = torch.tensor([12, 6, 1, 1])
+
+    attention_mask = model.attention_mask(
+        model.input_mask(seq_lens, next_batch.shape[1]),
+        dtype=torch.float32,
+    )
+
+    print(f"Step {start_index}")
+    logits = model.prefill(
+        next_batch,
+        attention_mask=attention_mask,
+        seq_block_ids=seq_block_ids,
+        cache_state=cache_state,
+    )
+    # TODO: Normalize the output of extract_tokens_from_logits into
+    # tensor [bs, 1].
+    tokens = torch.tensor(model.extract_tokens_from_logits(logits, seq_lens)).unsqueeze(
+        1
+    )
+    print(f" : tokens = {tokens}")
+    print(f" : cache[127] = {cache_state[0][127]}")
+    print(f" : cache[126] = {cache_state[0][126]}")
+    print(f" : cache[0] = {cache_state[0][0]}")
+    print(f" : cache[1] = {cache_state[0][1]}")
+
+    # Decode a step.
+    print("Decoding...")
+    print(tokens.shape, tokens)
+    start_positions = torch.tensor([12, 6, 0, 0])
+    seq_lens = seq_lens + 1
+    decode_attention_mask = model.decode_attention_mask(
+        model.input_mask(
+            seq_lens,
+            seq_block_ids.shape[1] * model.cache.block_seq_stride,
+        ),
+        dtype=torch.float32,
+    )
+    logits = model.decode(
+        tokens,
+        attention_mask=decode_attention_mask,
+        start_positions=start_positions,
+        seq_block_ids=seq_block_ids,
+        read_cache_state=cache_state,
+        write_cache_state=cache_state,
+    )
+    tokens = torch.tensor(
+        model.extract_tokens_from_logits(logits, [1, 1, 1, 1])
+    ).unsqueeze(1)
+    print(f" : tokens = {tokens}")
+    print(f" : cache[127] = {cache_state[0][127]}")
+    print(f" : cache[126] = {cache_state[0][126]}")
+    print(f" : cache[0] = {cache_state[0][0]}")
+    print(f" : cache[1] = {cache_state[0][1]}")
+
+    # from turbine_llm.models import llama
+    # print(f"+++PREFILL XK = {llama.DEBUG_PREFILL_XK.shape}\n{llama.DEBUG_PREFILL_XK}")
+    # print(f"+++DECODE XK = {llama.DEBUG_DECODE_XK.shape}\n{llama.DEBUG_DECODE_XK}")
+    # torch.testing.assert_close(llama.DEBUG_PREFILL_XK, llama.DEBUG_DECODE_XK)
+
+def save_prefill_module(model):
+    from shark_turbine.importers.fx_importer import FxImporter
+    from iree.compiler.ir import AsmState
+
+    importer = FxImporter()
+    # asm_state = AsmState(importer.module_op)
+
+    print("Generating FX graph")
+
+    class InferenceModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.add_module("prefill", model)
+
+        def forward(self, next_batch, attention_mask, seq_block_ids, *cache_state):
+            return self.prefill.prefill(
+                next_batch,
+                attention_mask=attention_mask,
+                seq_block_ids=seq_block_ids,
+                cache_state=list(cache_state),
+            )
+
+    infmod = InferenceModule()
+    prog = torch.export.export(
+        infmod, (next_batch, attention_mask, seq_block_ids) + tuple(cache_state)
+    )
+
+    print(f"FX prog:", prog)
+    importer.import_program(prog, func_name="prefill")
+    output_file = "/tmp/prefill.mlirbc"
+    print("Saving to:", output_file)
+    with open(output_file, "wb") as f:
+        importer.module_op.write_bytecode(f)
+
+# save_prefill_module()
+
+
+if __name__ == "__main__":
+    sys.exit(main(sys.argv[1:]))
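
The "Important" comment above about empty batch slots is easy to reproduce in isolation: a row whose positions are all masked to -inf makes softmax compute 0/0 and return NaN, which then propagates through the attention output. A standalone PyTorch check (not part of this commit):

import torch

# Every position masked: softmax normalizes over exp(-inf) == 0 and yields NaN.
fully_masked = torch.full((1, 4), float("-inf"))
print(torch.softmax(fully_masked, dim=-1))   # tensor([[nan, nan, nan, nan]])

# Keeping at least one valid position (seq_len >= 1, as the script does) stays finite.
partly_masked = torch.tensor([[0.0, float("-inf"), float("-inf"), float("-inf")]])
print(torch.softmax(partly_masked, dim=-1))  # tensor([[1., 0., 0., 0.]])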

llm/turbine_llm/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

llm/turbine_llm/data/gguf/base.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def _wrap_tensor(
     raise ValueError(f"Unsupported gguf tensor type: {type_name}")
 
 
-def load_gguf_file(gguf_path: Union[str, os.PathLike]):
+def load_gguf_file(gguf_path: Union[str, os.PathLike]) -> Dataset:
     reader = GGUFReader(gguf_path)
     logger.info(
         "Loading gguf file %s (%d fields, %d tensors)",
