Commit 4f7d965

[turbine-llm] Enable dynamic FX tracing of the paged llama model. (#516)
This was a little rough, but it got through. The key changes involve removing data-dependent views. As noted previously, I believe we still have a bug in the per-step cache management, and we will need to do a detailed per-step comparison of the cache against a reference.
1 parent 1263faf commit 4f7d965
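
For context, a minimal sketch (not part of this commit; the toy module and dimension names are illustrative only) of what "dynamic FX tracing" means here: torch.export traces the module once against symbolic dimensions, so the exported program stays valid for any sequence length instead of specializing on the example input.

import torch

class Toy(torch.nn.Module):
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Purely shape-polymorphic math traces fine; data-dependent views
        # (slicing by a value read out of a tensor) are what the commit
        # message above says had to be removed from the model.
        return tokens * 2

example = torch.ones(4, 64, dtype=torch.int64)
seq = torch.export.Dim("seq", max=2048)  # symbolic sequence-length dimension
ep = torch.export.export(Toy(), (example,), dynamic_shapes={"tokens": {1: seq}})
print(ep)  # the graph carries a symbolic dim, not the literal 64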

5 files changed: +171 -43 lines changed
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Inference support for the PagedLLMV1 protocol of models."""
+
+import math
+import sys
+
+import torch
+
+from shark_turbine.aot import (
+    FxProgramsBuilder,
+)
+
+from ..data.gguf import load_gguf_file
+from ..config.llm_configs import LlamaHParams
+
+# TODO: Should be using a base class with the protocol supported.
+from ..models.llama import PagedLlamaModelV1
+
+
+def main(args: list[str]):
+    try:
+        (gguf_path,) = args
+    except IndexError:
+        raise RuntimeError(f"Expected <gguf_path>")
+
+    dataset = load_gguf_file(gguf_path)
+
+    hp = LlamaHParams.from_gguf_props(dataset.properties)
+    model = PagedLlamaModelV1(dataset.root_theta, hp)
+
+    # Unrolling cache updates by batch row makes dynamo sad without an
+    # override. There may be a better way to do this.
+    import torch._dynamo.config as dynamo_config
+
+    dynamo_config.max_loop_unroll_nodes = 0
+
+    fxb = FxProgramsBuilder(model)
+
+    def generate_batch_prefill(bs: int):
+        tokens = torch.empty(bs, 64, dtype=torch.int64)
+        seq_lens = torch.empty(bs, dtype=torch.int64)
+        seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
+        cache_state = model.cache.allocate(128, torch.float32)
+        block_dim = torch.export.Dim("block", max=2047 // 16)
+        sl_dim = 16 * block_dim
+        page_dim = torch.export.Dim("page")
+        dynamic_shapes = {
+            "tokens": {1: sl_dim},
+            "seq_lens": {},
+            "seq_block_ids": {1: block_dim},
+            "cache_state": [{0: page_dim}],
+        }
+
+        @fxb.export_program(
+            name=f"prefill_bs{bs}",
+            args=(tokens, seq_lens, seq_block_ids, cache_state),
+            dynamic_shapes=dynamic_shapes,
+        )
+        def _(model, tokens, seq_lens, seq_block_ids, cache_state):
+            sl = tokens.shape[1]
+            input_mask = model.input_mask(seq_lens, sl)
+            attention_mask = model.attention_mask(input_mask, dtype=torch.float32)
+            logits = model.prefill(
+                tokens,
+                attention_mask=attention_mask,
+                seq_block_ids=seq_block_ids,
+                cache_state=cache_state,
+            )
+            return logits
+
+    def generate_batch_decode(bs: int):
+        tokens = torch.ones(bs, 1, dtype=torch.int64)
+        seq_lens = torch.ones(bs, dtype=torch.int64)
+        start_positions = torch.ones(bs, dtype=torch.int64)
+        seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
+        cache_state = model.cache.allocate(128, torch.float32)
+        block_dim = torch.export.Dim("block", max=2047 // 16)
+        page_dim = torch.export.Dim("page")
+        dynamic_shapes = {
+            "tokens": {},
+            "seq_lens": {},
+            "start_positions": {},
+            "seq_block_ids": {1: block_dim},
+            "cache_state": [{0: page_dim}],
+        }
+
+        @fxb.export_program(
+            name=f"decode_bs{bs}",
+            args=(
+                tokens,
+                seq_lens,
+                start_positions,
+                seq_block_ids,
+                cache_state,
+            ),
+            dynamic_shapes=dynamic_shapes,
+        )
+        def _(
+            model,
+            tokens,
+            seq_lens,
+            start_positions,
+            seq_block_ids,
+            cache_state,
+        ):
+            input_mask = model.input_mask(
+                seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
+            )
+            attention_mask = model.decode_attention_mask(
+                input_mask, dtype=torch.float32
+            )
+            logits = model.decode(
+                tokens,
+                attention_mask=attention_mask,
+                start_positions=start_positions,
+                seq_block_ids=seq_block_ids,
+                read_cache_state=cache_state,
+                write_cache_state=cache_state,
+            )
+            return logits

+    generate_batch_prefill(16)
+    generate_batch_decode(16)
+    print("GENERATED!")
+
+    for name, ep in fxb.programs.items():
+        print(f"EXPORT {name}:\n{ep}")
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
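
The `sl_dim = 16 * block_dim` line above is the load-bearing trick: the token sequence length is declared as a dimension derived from the number of block ids (16 being the block sequence stride), so the exporter can relate the two inputs. A hedged, standalone illustration of that derived-dim mechanism (toy module, not the real model):

import torch

block_dim = torch.export.Dim("block", max=2047 // 16)
sl_dim = 16 * block_dim  # derived dim: seq_len is always 16 * num_blocks

class Toy(torch.nn.Module):
    def forward(self, tokens, seq_block_ids):
        return tokens.sum() + seq_block_ids.sum()

ep = torch.export.export(
    Toy(),
    # 64 tokens and 4 block ids satisfy the 16x relationship declared above.
    (torch.zeros(1, 64, dtype=torch.int64), torch.zeros(1, 4, dtype=torch.int64)),
    dynamic_shapes={"tokens": {1: sl_dim}, "seq_block_ids": {1: block_dim}},
)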

llm/turbine_llm/layers/core.py

Lines changed: 4 additions & 3 deletions
@@ -34,9 +34,10 @@ def assert_not_nan(self, *ts: torch.Tensor):
         Must be enabled via a global switch as this kind of checking is not
         accelerator or compilation friendly.
         """
-        for t in ts:
-            if torch.isnan(t).any():
-                raise AssertionError(f"Tensor contains nans! {t}")
+        if debugging.flags.enable_nan_checks:
+            for t in ts:
+                if torch.isnan(t).any():
+                    raise AssertionError(f"Tensor contains nans! {t}")
 
 
 class ThetaLayer(BaseLayer):

llm/turbine_llm/layers/kv_cache.py

Lines changed: 10 additions & 4 deletions
@@ -181,13 +181,19 @@ def write_timestep(
         assert len(cache_partitions) == self.cache_partition_count
         for i in range(bs):
             position = seq_positions[i]
-            page_id = page_ids[i, position // self.block_seq_stride]
+            # TODO: Let's clamp to the allowable range so that we don't need
+            # an assert.
+            page_id = page_ids[i, :].index_select(0, position // self.block_seq_stride)
             page_offset = position % self.block_seq_stride
             for partition_index in range(self.cache_partition_count):
                 cache_partition = cache_partitions[partition_index]
-                page_table[
-                    page_id, transformer_block_index, partition_index, page_offset
-                ] = cache_partition[i, 0]
+                indices = (
+                    page_id,
+                    torch.tensor([transformer_block_index]),
+                    torch.tensor([partition_index]),
+                    page_offset.unsqueeze(0),
+                )
+                page_table.index_put_(indices=indices, values=cache_partition[i, 0])
 
     def write(
         self,
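
The change above swaps Python-scalar indexing (`page_ids[i, position // stride]`, which forces a data-dependent scalar out of the tensor during tracing) for `index_select` and `index_put_`, which keep everything as tensor ops with static ranks. A hedged standalone sketch of the same write pattern, with made-up shapes (not the real KV-cache API):

import torch

# Toy page table: [pages, transformer_blocks, cache_partitions, block_seq_stride, head_dim]
page_table = torch.zeros(8, 2, 2, 16, 4)
block_seq_stride = 16

position = torch.tensor(21)                  # 0-d position tensor (runtime data)
page_ids = torch.tensor([3, 5])              # pages assigned to one batch row
value = torch.ones(4)                        # the new K or V vector to store

# Select the page holding this position with tensor ops only (no Python int).
block_index = (position // block_seq_stride).unsqueeze(0)  # 1-D index tensor
page_id = page_ids.index_select(0, block_index)            # shape [1]
page_offset = position % block_seq_stride

indices = (
    page_id,                                 # which page
    torch.tensor([0]),                       # transformer block index
    torch.tensor([1]),                       # cache partition (e.g. V)
    page_offset.unsqueeze(0),                # offset within the page
)
page_table.index_put_(indices=indices, values=value)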

llm/turbine_llm/models/llama.py

Lines changed: 18 additions & 36 deletions
@@ -95,7 +95,6 @@ def __init__(self, theta: Theta, hp: LlamaHParams):
                     theta("blk", n),
                     block_index=n,
                     cache=self.cache,
-                    embedding=self.attention_embedding,
                     head_count=hp.attention_head_count,
                     head_dim=attn_head_dim,
                     head_count_kv=hp.attention_head_count_kv,
@@ -125,6 +124,7 @@ def prefill(
             self.trace_tensor(f"llama.attn_block.{block_idx}.input", h)
             h = block(
                 h,
+                embedding=self.attention_embedding,
                 start_index=0,
                 attention_mask=attention_mask,
                 write_cache_state=cache_state,
@@ -189,6 +189,7 @@ def decode(
             h = block(
                 h,
                 start_positions=start_positions,
+                embedding=self.attention_embedding,
                 embedding_batch_mask=embedding_batch_mask,
                 attention_mask=attention_mask,
                 read_cache_state=read_cache_state,
@@ -222,7 +223,6 @@
         head_count: int,
         head_dim: int,
         head_count_kv: int,
-        embedding: RotaryEmbeddingLayer,
         rms_epsilon: float,
     ):
         super().__init__(theta)
@@ -242,7 +242,6 @@
 
         self.block_index = block_index
         self.cache = cache
-        self.embedding = embedding
         self.head_count = head_count
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
@@ -251,6 +250,7 @@ def forward(
         self,
         h: torch.Tensor,
         *,
+        embedding: RotaryEmbeddingLayer,
         # [bs, batch_seq_len // block_seq_stride]
         seq_block_ids: torch.Tensor,
         start_index: Optional[int] = None,
@@ -280,9 +280,9 @@ def forward(
         # Fast path to start_index based embedding lookup if available.
         # Falls back to a slower position based index lookup.
         if start_index is not None:
-            xq, xk = self.embedding.forward(xq=xq, xk=xk, start_index=start_index)
+            xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index)
         else:
-            xq, xk = self.embedding.apply_batched_mask(
+            xq, xk = embedding.apply_batched_mask(
                 xq=xq, xk=xk, mask=embedding_batch_mask
             )
 
@@ -321,6 +321,19 @@ def forward(
 
         kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
 
+        if write_cache_state:
+            # Write our one updated cache row into the cache.
+            self.cache.write_timestep(
+                write_cache_state,
+                cache_partitions=[
+                    xk_cache_update,
+                    xv_cache_update,
+                ],
+                transformer_block_index=self.block_index,
+                seq_positions=start_positions + 1,
+                page_ids=seq_block_ids,
+            )
+
         # Restore from the cache.
         self.cache.read(
             read_cache_state,
@@ -332,18 +345,6 @@ def forward(
             page_ids=seq_block_ids,
         )
 
-        # self.trace_tensor("DECODE.KV.ACTUAL", xk_temp)
-        # Now restore the newly computed position into the xk/xv view we
-        # are operating on. This will also be done later when updating the
-        # cache, but we separate it here to avoid creating a data
-        # dependency. Since the batch size is static, we do a static loop
-        # in order to simplify indexing and keeps us from needing to
-        # deal with masking.
-        for i in range(bs):
-            row_start_pos = start_positions[i]
-            xk_temp[i, row_start_pos : row_start_pos + 1, :, :] = xk[i, ...]
-            xv_temp[i, row_start_pos : row_start_pos + 1, :, :] = xv[i, ...]
-
         # For computation, we create a subview of the xk/xv tensors to have
         # a sequence length covering the blocked size. This must include
         # the newly added row (the caller is responsible for ensuring that
@@ -352,25 +353,6 @@ def forward(
         xk = xk_temp[:, 0:kv_seq_len, ...]
         xv = xv_temp[:, 0:kv_seq_len, ...]
 
-        if write_cache_state:
-            # Write our one updated cache row. We currently do this apart
-            # from the linearization step because it lets us have aliased
-            # cache states. We may need to revisit this if we can support
-            # a cache write-read in the same sequence.
-            # In that case, this would go prior to the read.
-            # self.trace_tensor("decode.xk_cache_update", xk_cache_update)
-            # self.trace_tensor("decode.xv_cache_update", xv_cache_update)
-            self.cache.write_timestep(
-                write_cache_state,
-                cache_partitions=[
-                    xk_cache_update,
-                    xv_cache_update,
-                ],
-                transformer_block_index=self.block_index,
-                seq_positions=start_positions + 1,
-                page_ids=seq_block_ids,
-            )
-
         # Tranpose into [bs, heads, sl, dim]
         xq = xq.transpose(1, 2)
         keys = xk.transpose(1, 2)
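
Two things happen in this file: the rotary embedding layer is now passed into each attention block's forward() instead of being stored on the block, and the decode path writes the new timestep into the paged cache before reading the whole blocked K/V window back, which removes the per-row data-dependent slice loop deleted above. A hedged toy of that write-then-read ordering (plain tensors standing in for the paged cache, not the real KV-cache API):

import torch

bs, stride, head_dim = 2, 16, 4
cache = torch.zeros(bs, stride, head_dim)        # stand-in for one cache page per row
start_positions = torch.tensor([3, 7])           # current decode position per batch row
new_kv_row = torch.ones(bs, 1, head_dim)         # the freshly computed K (or V) row

# 1) Write the single new timestep using tensor indexing only (no Python-int
#    slices whose extent depends on runtime data).
batch_idx = torch.arange(bs)
cache[batch_idx, start_positions] = new_kv_row[:, 0]

# 2) Read the full, statically shaped window back for attention; the graph
#    never needs a view whose bounds come from start_positions.
keys = cache[:, 0:stride, :]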

llm/turbine_llm/utils/debugging.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@
 @dataclass
 class DebugFlags:
     enable_tensor_trace: bool = False
+    enable_nan_checks: bool = False
 
     def set(self, part: str):
         m = re.match(SETTING_PART_PATTERN, part)
@@ -38,6 +39,8 @@ def set(self, part: str):
 
         if name == "tensor_trace":
             self.enable_tensor_trace = logical_sense
+        elif name == "enable_nan_checks":
+            self.enable_nan_checks = logical_sense
 
         else:
             logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)
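
Usage sketch for the new flag, assuming the package layout implied by the file paths above (the import path is an assumption, not something this commit states): once the flag is set, BaseLayer.assert_not_nan() in layers/core.py actually runs its torch.isnan check.

# Hedged usage sketch; the import path is assumed from the repository layout.
from turbine_llm.utils import debugging

debugging.flags.enable_nan_checks = True  # turn on the NaN assertions
# ... build and run the model; assert_not_nan() now raises on any NaN tensor.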
