@@ -80,18 +80,30 @@ def generate_schema(num_layers):
     return json.dumps(schema)


-def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim, num_layers):
+def slice_up_to_step(k_caches, v_caches, seq_step, heads, hidden_dim, num_layers):
     all_pkv_tensors = []
     for i in range(num_layers * 2):
         # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim]
         # Generates tensor<1 x 1 x seq_step x heads x hidden_dim>
-        sliced = IREE.tensor_slice(
-            global_pkv, i, 0, (0, seq_step), (0, heads), (0, hidden_dim)
-        )  # sequence context dim
+        if i % 2 == 0:
+            sliced = IREE.tensor_slice(
+                k_caches["layer_idx"][i // 2],
+                0,
+                (0, seq_step),
+                (0, heads),
+                (0, hidden_dim),
+            )  # sequence context dim
+        else:
+            sliced = IREE.tensor_slice(
+                v_caches["layer_idx"][i // 2],
+                0,
+                (0, seq_step),
+                (0, heads),
+                (0, hidden_dim),
+            )  # sequence context dim
         all_pkv_tensors.append(
             IREE.tensor_reshape(sliced, 1, seq_step, heads, hidden_dim)
         )
-
     return all_pkv_tensors

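Note: slice_up_to_step still returns one flat list of num_layers * 2 tensors, but it now
draws even entries from the per-layer K caches and odd entries from the V caches. (The
inline "Numpy semantic" comment above still describes the old fused global_pkv layout;
the per-layer slice drops the leading layer dimension.) A minimal numpy sketch of the
intended slicing, assuming illustrative shapes — this is the equivalent indexing, not
the Turbine IREE.* API:

    import numpy as np

    num_layers, max_seq, heads, hidden_dim, seq_step = 2, 16, 4, 8, 5
    k_caches = {"layer_idx": [np.zeros((1, max_seq, heads, hidden_dim)) for _ in range(num_layers)]}
    v_caches = {"layer_idx": [np.zeros((1, max_seq, heads, hidden_dim)) for _ in range(num_layers)]}

    all_pkv = []
    for i in range(num_layers * 2):
        caches = k_caches if i % 2 == 0 else v_caches  # even -> K, odd -> V
        sliced = caches["layer_idx"][i // 2][0:1, 0:seq_step, 0:heads, 0:hidden_dim]
        all_pkv.append(sliced.reshape(1, seq_step, heads, hidden_dim))

    assert all(t.shape == (1, seq_step, heads, hidden_dim) for t in all_pkv)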
@@ -139,9 +151,12 @@ def export_transformer_model(
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
-        size=(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
+        size=(BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=dtype,
     )
+    kv_cache_structure = {
+        "layer_idx": [abstractify(global_pkv) for _ in range(NUM_LAYERS)],
+    }

     mapper = {}
     if external_weights is not None:
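Note: the single fused KV buffer of shape (NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ,
HEADS, HIDDEN_DIM) becomes a pytree of NUM_LAYERS per-layer leaves, exported twice
below (once for K, once for V). Total storage is unchanged; a quick torch sketch with
hypothetical small dimensions:

    import torch

    NUM_LAYERS, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM = 2, 1, 16, 4, 8
    fused = torch.zeros(NUM_LAYERS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)
    per_layer_k = [torch.zeros(BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM) for _ in range(NUM_LAYERS)]
    per_layer_v = [torch.zeros(BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM) for _ in range(NUM_LAYERS)]
    assert fused.numel() == sum(t.numel() for t in per_layer_k + per_layer_v)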
@@ -163,29 +178,44 @@ class StateUpdateModule(CompiledModule):
             )
         else:
             params = export_parameters(mod)
-        global_state = export_global(
-            abstractify(global_pkv), uninitialized=True, mutable=True
-        )
         global_seq_step = export_global(AbstractIndex, mutable=True)
+        global_k_caches = export_global_tree(
+            kv_cache_structure, uninitialized=True, mutable=True
+        )
+        global_v_caches = export_global_tree(
+            kv_cache_structure, uninitialized=True, mutable=True
+        )

         def run_initialize(self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)):
             init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
             token, *state = self.initialize(x, constraints=init_const)
             self.global_seq_step = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
                 slice_of_state = IREE.tensor_reshape(
-                    state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
+                    state[i * 2], 1, self.global_seq_step, HEADS, HIDDEN_DIM
                 )
-                self.global_state = IREE.tensor_update(
-                    self.global_state, slice_of_state, i, 0, 0, 0, 0
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0
+                )
+            for i in range(NUM_LAYERS):
+                slice_of_state = IREE.tensor_reshape(
+                    state[i * 2 + 1], 1, self.global_seq_step, HEADS, HIDDEN_DIM
+                )
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i], slice_of_state, 0, 0, 0, 0
                 )
             return token

         def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
+                self.global_k_caches,
+                self.global_v_caches,
+                self.global_seq_step,
+                HEADS,
+                HIDDEN_DIM,
+                NUM_LAYERS,
             )
             forw_const = (
                 [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ]
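Note: run_initialize receives the model's state as a flat interleaved list
[k0, v0, k1, v1, ...], which is why layer i's K is state[i * 2] and its V is
state[i * 2 + 1]. The same de-interleave in plain Python:

    state = ["k0", "v0", "k1", "v1"]  # flat interleaved KV list, 2 layers
    NUM_LAYERS = len(state) // 2
    k_states = [state[i * 2] for i in range(NUM_LAYERS)]
    v_states = [state[i * 2 + 1] for i in range(NUM_LAYERS)]
    assert k_states == ["k0", "k1"] and v_states == ["v0", "v1"]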
@@ -196,20 +226,33 @@ def run_forward(self, x=AbstractTensor(1, 1, dtype=torch.int64)):
                 + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]]
             )
             token, *state_update = self.forward(x, *state_arg, constraints=forw_const)
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
                 update = IREE.tensor_reshape(
-                    state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
+                    state_update[i * 2], 1, 1, HEADS, HIDDEN_DIM
                 )
-                self.global_state = IREE.tensor_update(
-                    self.global_state, update, i, 0, self.global_seq_step, 0, 0
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i],
+                    update,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
+                )
+            for i in range(NUM_LAYERS):
+                update = IREE.tensor_reshape(
+                    state_update[i * 2 + 1], 1, 1, HEADS, HIDDEN_DIM
+                )
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i],
+                    update,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
                 )
-
             self.global_seq_step = self.global_seq_step + 1
             return token

-        def get_global_state(self):
-            return self.global_state
-
         def get_seq_step(self):
             return self.global_seq_step

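Note: each decode step in run_forward appends a single-token K/V slice at offset
global_seq_step along the sequence dimension of every per-layer cache. Assuming
IREE.tensor_update writes the update tensor into the destination at the given
per-dimension offsets, the numpy equivalent is:

    import numpy as np

    cache = np.zeros((1, 16, 4, 8))   # [BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM]
    update = np.ones((1, 1, 4, 8))    # one new token's K (or V) for this layer
    seq_step = 5
    cache[0:1, seq_step:seq_step + 1] = update  # tensor_update(..., 0, seq_step, 0, 0)
    seq_step += 1                               # self.global_seq_step + 1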
@@ -239,7 +282,12 @@ def run_cached_initialize(
             self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
         ):
             state_arg = slice_up_to_step(
-                self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM, NUM_LAYERS
+                self.global_k_caches,
+                self.global_v_caches,
+                self.global_seq_step,
+                HEADS,
+                HIDDEN_DIM,
+                NUM_LAYERS,
             )
             forw_const = (
                 [x.dynamic_dim(1) < MAX_STEP_SEQ]
@@ -256,12 +304,29 @@ def run_cached_initialize(
             len_of_new_tokens = IREE.tensor_dim(
                 state[0], 1
             )  # ? dimension of arbitrarily 0th kv tensor
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
+                slice_of_state = IREE.tensor_reshape(
+                    state[i * 2], 1, len_of_new_tokens, HEADS, HIDDEN_DIM
+                )
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i],
+                    slice_of_state,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
+                )
+            for i in range(NUM_LAYERS):
                 slice_of_state = IREE.tensor_reshape(
-                    state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM
+                    state[i * 2 + 1], 1, len_of_new_tokens, HEADS, HIDDEN_DIM
                 )
-                self.global_state = IREE.tensor_update(
-                    self.global_state, slice_of_state, i, 0, self.global_seq_step, 0, 0
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i],
+                    slice_of_state,
+                    0,
+                    self.global_seq_step,
+                    0,
+                    0,
                 )
             self.global_seq_step = self.global_seq_step + len_of_new_tokens
             return token
@@ -291,17 +356,37 @@ def evict_kvcache_space(self):
             sink_size = 4
             window_size = 252
             most_recent_window = self.global_seq_step + (-window_size)
-            for i in range(NUM_LAYERS * 2):
+            for i in range(NUM_LAYERS):
                 update_window_state = IREE.tensor_slice(
-                    self.global_state,
-                    i,
+                    self.global_k_caches["layer_idx"][i],
                     0,
                     (most_recent_window, window_size),
                     (0, HEADS),
                     (0, HIDDEN_DIM),
                 )  # sequence context dim
-                self.global_state = IREE.tensor_update(
-                    self.global_state, update_window_state, i, 0, sink_size, 0, 0
+                self.global_k_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_k_caches["layer_idx"][i],
+                    update_window_state,
+                    0,
+                    sink_size,
+                    0,
+                    0,
+                )
+            for i in range(NUM_LAYERS):
+                update_window_state = IREE.tensor_slice(
+                    self.global_v_caches["layer_idx"][i],
+                    0,
+                    (most_recent_window, window_size),
+                    (0, HEADS),
+                    (0, HIDDEN_DIM),
+                )  # sequence context dim
+                self.global_v_caches["layer_idx"][i] = IREE.tensor_update(
+                    self.global_v_caches["layer_idx"][i],
+                    update_window_state,
+                    0,
+                    sink_size,
+                    0,
+                    0,
                 )
             self.global_seq_step.set(window_size + sink_size)
             return self.global_seq_step
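Note: evict_kvcache_space implements a sliding window with a preserved prefix (a
StreamingLLM-style attention sink): the first sink_size positions stay in place, the
most recent window_size positions are copied down to start right after the sink, and
the step counter resets to window_size + sink_size. A numpy sketch of the move applied
to each per-layer cache, assuming the same slice/update offset semantics as above:

    import numpy as np

    sink_size, window_size, seq_step = 4, 252, 300
    cache = np.random.rand(1, 512, 4, 8)  # [batch, max_seq, heads, hidden_dim]
    recent = cache[0:1, seq_step - window_size:seq_step].copy()  # last window_size tokens
    cache[0:1, sink_size:sink_size + window_size] = recent       # keep sink, shift window down
    seq_step = window_size + sink_size                           # 256 positions remain live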