@@ -168,6 +168,10 @@ class StaticKVCache {
     }
   }
 
+  size_t size() {
+    return valid_len_;
+  }
+
  private:
   void init_ptrs() {
     input_ptrs_.resize(n_caches_);
@@ -339,26 +343,40 @@ template <
     typename MaskAllocatorT = std::allocator<MaskT>>
 class StaticAttentionIOManager {
  public:
-  StaticAttentionIOManager(
-      size_t n_caches,
-      size_t cache_len,
-      size_t head_dim,
-      size_t max_input_len,
-      size_t rope_freqs_cos_index,
-      size_t rope_freqs_sin_index,
-      RopeT* rope_freqs_cos,
-      RopeT* rope_freqs_sin,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
-      : cache_len_(cache_len),
-        head_dim_(head_dim),
-        style_(style),
-        kCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        vCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        rope_freqs_cos_index_(rope_freqs_cos_index),
-        rope_freqs_sin_index_(rope_freqs_sin_index),
-        rope_freqs_cos_(rope_freqs_cos),
-        rope_freqs_sin_(rope_freqs_sin) {}
+  struct StaticAttentionIOConfig {
+    size_t n_caches{};
+    size_t cache_len{};
+    size_t head_dim{};
+    size_t max_input_len{};
+    size_t attn_mask_input_index{};
+    size_t rope_freqs_cos_input_index{};
+    size_t rope_freqs_sin_input_index{};
+    std::vector<size_t> k_cache_input_indices;
+    std::vector<size_t> k_cache_output_indices;
+    std::vector<size_t> v_cache_input_indices;
+    std::vector<size_t> v_cache_output_indices;
+    RopeT* rope_freqs_cos;
+    RopeT* rope_freqs_sin;
+    StaticAttentionUpdateStyle style =
+        StaticAttentionUpdateStyle::SLIDING_CACHE;
+  };
+
+  StaticAttentionIOManager(StaticAttentionIOConfig config)
+      : config_(std::move(config)),
+        kCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style),
+        vCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style) {}
 
   /**
    * Create a new StaticAttentionMask that will be managed by this object.
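For context, a hedged sketch (not part of this diff) of how a caller might fill in the new StaticAttentionIOConfig and construct the manager. The template argument order and every concrete index or value below are illustrative assumptions, not the exported method's real I/O layout:

// Sketch only: template argument order and all indices/values are assumed.
using IOManager = example::StaticAttentionIOManager<float, float, float>;
IOManager::StaticAttentionIOConfig config;
config.n_caches = 32; // assumed: one K and one V cache per layer
config.cache_len = 1024; // assumed maximum context length
config.head_dim = 64;
config.max_input_len = 128;
config.attn_mask_input_index = 1; // assumed method input indices
config.rope_freqs_cos_input_index = 2;
config.rope_freqs_sin_input_index = 3;
config.k_cache_input_indices = {4, 5 /* ..., one per cache */};
config.v_cache_input_indices = {36, 37 /* ... */};
config.k_cache_output_indices = {1, 2 /* ..., method output indices */};
config.v_cache_output_indices = {33, 34 /* ... */};
config.rope_freqs_cos = precomputed_rope_cos; // assumed precomputed tables
config.rope_freqs_sin = precomputed_rope_sin;
IOManager io_mgr(std::move(config));

Grouping all of this into one aggregate also means future I/O layout changes only touch the struct, not every call site.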
@@ -369,36 +387,38 @@ class StaticAttentionIOManager {
         std::piecewise_construct,
         std::forward_as_tuple(input_len),
         std::forward_as_tuple(
-            cache_len_, input_len, head_dim_, zero_val, mask_val, style_));
+            config_.cache_len,
+            input_len,
+            config_.head_dim,
+            zero_val,
+            mask_val,
+            config_.style));
     return it.first->second;
   }
 
   /**
    * Retrieve a mask suitable for given input length.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
+  StaticAttentionMask<MaskT, MaskAllocatorT>& get_mask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
 
   /**
    * Set I/O pointers for KV cache and RoPE frequencies.
    */
-  void prepare(
-      torch::executor::Method& method,
-      const std::vector<size_t>& k_cache_input_indices,
-      const std::vector<size_t>& k_cache_output_indices,
-      const std::vector<size_t>& v_cache_input_indices,
-      const std::vector<size_t>& v_cache_output_indices) {
-    kCaches_.prepare(method, k_cache_input_indices, k_cache_output_indices);
-    vCaches_.prepare(method, v_cache_input_indices, v_cache_output_indices);
+  void prepare(torch::executor::Method& method) {
+    kCaches_.prepare(
+        method, config_.k_cache_input_indices, config_.k_cache_output_indices);
+    vCaches_.prepare(
+        method, config_.v_cache_input_indices, config_.v_cache_output_indices);
     set_input(
         method,
-        rope_freqs_cos_index_,
-        rope_freqs_cos_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_cos_input_index,
+        config_.rope_freqs_cos + input_pos_ * config_.head_dim / 2);
     set_input(
         method,
-        rope_freqs_sin_index_,
-        rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_sin_input_index,
+        config_.rope_freqs_sin + input_pos_ * config_.head_dim / 2);
   }
 
   /**
@@ -430,6 +450,36 @@ class StaticAttentionIOManager {
     }
   }
 
+  template <typename TokenT>
+  std::vector<TokenT> decode(
+      TokenT prev_tok,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method,
+      std::function<TokenT(executorch::runtime::Method&)>& sample,
+      std::function<bool(TokenT)>& should_stop) {
+    set_input(method, 0, input_buffer.data());
+    auto& mask = get_mask(input_buffer.size());
+    set_input(method, config_.attn_mask_input_index, mask.get());
+
+    std::vector<TokenT> generated_tokens;
+    while (kCaches_.size() + 1 <= config_.cache_len) {
+      input_buffer[0] = prev_tok;
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          1);
+      prev_tok = sample(method);
+      generated_tokens.emplace_back(prev_tok);
+      if (should_stop(prev_tok)) {
+        break;
+      }
+    }
+    return generated_tokens;
+  }
+
  private:
   template <typename T>
   void set_input(executorch::runtime::Method& method, size_t idx, T* data) {
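To show how the new decode() entry point is meant to be driven, a hedged usage sketch (not established by this diff): the greedy sampler, the EOS id, the logits output index, and the assumption that a mask for input length 1 was already registered on this manager are all illustrative.

// Sketch only: assumes output 0 of `method` holds the logits, kEosTokenId is
// the tokenizer's EOS id, and a length-1 attention mask was created earlier.
std::function<int32_t(executorch::runtime::Method&)> sample =
    [](executorch::runtime::Method& m) {
      auto logits = m.get_output(0).toTensor();
      const float* p = logits.const_data_ptr<float>();
      // Greedy decoding: pick the highest-scoring token.
      return static_cast<int32_t>(std::max_element(p, p + logits.numel()) - p);
    };
std::function<bool(int32_t)> should_stop = [](int32_t tok) {
  return tok == kEosTokenId;
};
std::vector<int32_t> token_buf(1); // decode() writes prev_tok into slot 0
auto generated = io_mgr.decode<int32_t>(
    last_prefill_token,
    executorch::runtime::Span<int32_t>(token_buf.data(), token_buf.size()),
    method,
    sample,
    should_stop);

Note that the loop condition `kCaches_.size() + 1 <= config_.cache_len` relies on the new StaticKVCache::size() accessor from the first hunk: generation stops once the cache has no room for another position, even if should_stop never fires.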
@@ -447,19 +497,12 @@ class StaticAttentionIOManager {
     ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
   }
 
-  size_t cache_len_;
-  size_t input_len_;
-  size_t head_dim_;
+  StaticAttentionIOConfig config_;
   size_t input_pos_;
-  StaticAttentionUpdateStyle style_;
   StaticKVCache<CacheT, CacheAllocatorT> kCaches_;
   StaticKVCache<CacheT, CacheAllocatorT> vCaches_;
   std::unordered_map<size_t, StaticAttentionMask<MaskT, MaskAllocatorT>>
       attentionMasks_;
-  size_t rope_freqs_cos_index_;
-  size_t rope_freqs_sin_index_;
-  RopeT* rope_freqs_cos_;
-  RopeT* rope_freqs_sin_;
 };
 
 } // namespace example