Commit 24b1872

[models] Implement KV-cache to have own global for each layer, K, V. (#505)
This change splits the KV cache into a separate global for each layer's K and V. Per-layer globals allow cleaner indexing, which leads to simpler analysis and better performance down the line. One analysis this enables is verifying that reads and writes to the KV cache are independent, which we need in order to remove redundant copies of the cache. The same analysis could be done on the integrated big-slab layout, but it would be much harder, since it would require reasoning about dynamic striding and more complex index ranges.
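
To make the layout change concrete, here is a minimal, self-contained sketch (plain PyTorch with made-up sizes; the names mirror the diff below, but this is not code from the commit) contrasting the old single-slab cache with the new per-layer K/V globals:

import torch

# Illustrative sizes only (assumed for this sketch); the real values come
# from the model config (e.g. mod.config.max_position_embeddings - 1).
NUM_LAYERS, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM = 2, 1, 8, 4, 16

# Old layout: one slab holds K and V for every layer, interleaved on dim 0,
# so layer l's K lives at slab[2 * l] and its V at slab[2 * l + 1].
slab = torch.zeros(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)
k_layer1_old = slab[2 * 1]

# New layout: one tensor per layer per K/V, collected in a tree of globals,
# so each cache is addressed directly without striding through a shared slab.
template = torch.zeros(BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)
k_caches = {"layer_idx": [template.clone() for _ in range(NUM_LAYERS)]}
v_caches = {"layer_idx": [template.clone() for _ in range(NUM_LAYERS)]}
k_layer1_new = k_caches["layer_idx"][1]

assert k_layer1_old.shape == k_layer1_new.shape

With the per-layer form, every cache access is a static list index rather than a dynamic offset into a shared tensor, which is what makes the independence analysis tractable.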
1 parent 2dc96a8 commit 24b1872

models/turbine_models/custom_models/stateless_llama.py

Lines changed: 117 additions & 32 deletions
@@ -80,18 +80,30 @@ def generate_schema(num_layers):
     return json.dumps(schema)
 
 
-def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
+def slice_up_to_step(k_caches, v_caches, seq_step, heads, hidden_dim, num_layers):
     all_pkv_tensors = []
     for i in range(num_layers * 2):
         # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
         # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
-        sliced = IREE.tensor_slice(
-            global_pkv, i, 0, (0, seq_step), (0, heads), (0, hidden_dim)
-        )  # sequence context dim
+        if i % 2 == 0:
+            sliced = IREE.tensor_slice(
+                k_caches["layer_idx"][i // 2],
+                0,
+                (0, seq_step),
+                (0, heads),
+                (0, hidden_dim),
+            )  # sequence context dim
+        else:
+            sliced = IREE.tensor_slice(
+                v_caches["layer_idx"][i // 2],
+                0,
+                (0, seq_step),
+                (0, heads),
+                (0, hidden_dim),
+            )  # sequence context dim
         all_pkv_tensors.append(
             IREE.tensor_reshape(sliced, 1, seq_step, heads, hidden_dim)
         )
-
     return all_pkv_tensors
 
 
@@ -139,9 +151,12 @@ def export_transformer_model(
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
-        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
+        size=(BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=dtype,
     )
+    kv_cache_structure = {
+        "layer_idx": [abstractify(global_pkv) for _ in range(NUM_LAYERS)],
+    }
 
     mapper = {}
     if external_weights is not None:
@@ -163,29 +178,44 @@ class StateUpdateModule(CompiledModule):
             )
         else:
             params = export_parameters(mod)
-        global_state = export_global(
-            abstractify(global_pkv), uninitialized=True, mutable=True
-        )
         global_seq_step = export_global(AbstractIndex, mutable=True)
+        global_k_caches = export_global_tree(
+            kv_cache_structure, uninitialized=True, mutable=True
+        )
+        global_v_caches = export_global_tree(
+            kv_cache_structure, uninitialized=True, mutable=True
+        )
 
         def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
             init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
             token, *state = self.initialize(x, constraints=init_const)
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
                 slice_of_state = IREE.tensor_reshape(
-                    state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
+                    state[i * 2], 1, self.global_seq_step, HEADS, HIDDEN_DIM
                 )
-                self.global_state = IREE.tensor_update(
-                    self.global_state, slice_of_state, i, 0, 0, 0, 0
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0
+                )
+            for i in range(NUM_LAYERS):
+                slice_of_state = IREE.tensor_reshape(
+                    state[i * 2 + 1], 1, self.global_seq_step, HEADS, HIDDEN_DIM
+                )
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0
                 )
             return token
 
         def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
+                self.global_k_caches,
+                self.global_v_caches,
+                self.global_seq_step,
+                HEADS,
+                HIDDEN_DIM,
+                NUM_LAYERS,
             )
             forw_const = (
                 [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
@@ -196,20 +226,33 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
             )
             token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
                 update = IREE.tensor_reshape(
-                    state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
+                    state_update[i * 2], 1, 1, HEADS, HIDDEN_DIM
                 )
-                self.global_state = IREE.tensor_update(
-                    self.global_state, update, i, 0, self.global_seq_step, 0, 0
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i],
+                    update,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
+                )
+            for i in range(NUM_LAYERS):
+                update = IREE.tensor_reshape(
+                    state_update[i * 2 + 1], 1, 1, HEADS, HIDDEN_DIM
+                )
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i],
+                    update,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
                 )
-
             self.global_seq_step = self.global_seq_step + 1
             return token
 
-        def get_global_state(self):
-            return self.global_state
-
         def get_seq_step(self):
             return self.global_seq_step
 
@@ -239,7 +282,12 @@ def run_cached_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
+                self.global_k_caches,
+                self.global_v_caches,
+                self.global_seq_step,
+                HEADS,
+                HIDDEN_DIM,
+                NUM_LAYERS,
             )
             forw_const = (
                 [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -256,12 +304,29 @@ def run_cached_initialize(
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
+                slice_of_state = IREE.tensor_reshape(
+                    state[i * 2], 1, len_of_new_tokens, HEADS, HIDDEN_DIM
+                )
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i],
+                    slice_of_state,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
+                )
+            for i in range(NUM_LAYERS):
                 slice_of_state = IREE.tensor_reshape(
-                    state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
+                    state[i * 2 + 1], 1, len_of_new_tokens, HEADS, HIDDEN_DIM
                 )
-                self.global_state = IREE.tensor_update(
-                    self.global_state, slice_of_state, i, 0, self.global_seq_step, 0, 0
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i],
+                    slice_of_state,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
                 )
             self.global_seq_step = self.global_seq_step + len_of_new_tokens
             return token
@@ -291,17 +356,37 @@ def evict_kvcache_space(self):
             sink_size = 4
             window_size = 252
             most_recent_window = self.global_seq_step + (-window_size)
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
                 update_window_state = IREE.tensor_slice(
-                    self.global_state,
-                    i,
+                    self.global_k_caches["layer_idx"][i],
                     0,
                     (most_recent_window, window_size),
                     (0, HEADS),
                     (0, HIDDEN_DIM),
                 )  # sequence context dim
-                self.global_state = IREE.tensor_update(
-                    self.global_state, update_window_state, i, 0, sink_size, 0, 0
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i],
+                    update_window_state,
+                    0,
+                    sink_size,
+                    0,
+                    0,
+                )
+            for i in range(NUM_LAYERS):
+                update_window_state = IREE.tensor_slice(
+                    self.global_v_caches["layer_idx"][i],
+                    0,
+                    (most_recent_window, window_size),
+                    (0, HEADS),
+                    (0, HIDDEN_DIM),
+                )  # sequence context dim
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i],
+                    update_window_state,
+                    0,
+                    sink_size,
+                    0,
+                    0,
                 )
             self.global_seq_step.set(window_size + sink_size)
             return self.global_seq_step
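
One detail worth noting about the diff: the flat state list returned by the model still interleaves K and V per layer, and slice_up_to_step recovers the per-layer globals from a flat index with i % 2 (K vs. V) and i // 2 (layer). A minimal sketch of that mapping, using hypothetical stand-in dictionaries rather than the exported globals:

# Stand-in trees shaped like the exported globals; values are labels only.
k_caches = {"layer_idx": ["K0", "K1"]}
v_caches = {"layer_idx": ["V0", "V1"]}

def cache_for_flat_index(i):
    # Even flat indices hold K, odd ones hold V; i // 2 recovers the layer.
    tree = k_caches if i % 2 == 0 else v_caches
    return tree["layer_idx"][i // 2]

assert [cache_for_flat_index(i) for i in range(4)] == ["K0", "V0", "K1", "V1"]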
