Commit bd2f795

Config object for StaticAttentionIOManager and decoding helper

This change replaces StaticAttentionIOManager's long constructor parameter list with a StaticAttentionIOConfig struct, which also carries the model's KV-cache, attention-mask, and RoPE I/O indices, and adds a decode() helper that runs the autoregressive decoding loop.

Differential Revision: D78113341
Pull Request resolved: #12379

1 parent 108b685

1 file changed: +85 −42 lines

examples/models/llama/runner/static_attention_io_manager.h
@@ -168,6 +168,10 @@ class StaticKVCache {
     }
   }
 
+  size_t size() {
+    return valid_len_;
+  }
+
  private:
   void init_ptrs() {
     input_ptrs_.resize(n_caches_);
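The new size() accessor exposes valid_len_, the number of cache positions currently occupied; the decode() helper added below uses it to stop generating before the static cache overflows. A minimal sketch of that check, assuming a cache object kv and a fixed capacity cache_len (names are illustrative, not from this commit):

    // Hypothetical capacity check: decoding one more token is only safe
    // while the static cache still has room for it.
    bool can_decode_one_more = kv.size() + 1 <= cache_len;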
@@ -339,26 +343,40 @@ template <
     typename MaskAllocatorT = std::allocator<MaskT>>
 class StaticAttentionIOManager {
  public:
-  StaticAttentionIOManager(
-      size_t n_caches,
-      size_t cache_len,
-      size_t head_dim,
-      size_t max_input_len,
-      size_t rope_freqs_cos_index,
-      size_t rope_freqs_sin_index,
-      RopeT* rope_freqs_cos,
-      RopeT* rope_freqs_sin,
-      StaticAttentionUpdateStyle style =
-          StaticAttentionUpdateStyle::SLIDING_CACHE)
-      : cache_len_(cache_len),
-        head_dim_(head_dim),
-        style_(style),
-        kCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        vCaches_(n_caches, cache_len, head_dim, max_input_len, false, style),
-        rope_freqs_cos_index_(rope_freqs_cos_index),
-        rope_freqs_sin_index_(rope_freqs_sin_index),
-        rope_freqs_cos_(rope_freqs_cos),
-        rope_freqs_sin_(rope_freqs_sin) {}
+  struct StaticAttentionIOConfig {
+    size_t n_caches{};
+    size_t cache_len{};
+    size_t head_dim{};
+    size_t max_input_len{};
+    size_t attn_mask_input_index{};
+    size_t rope_freqs_cos_input_index{};
+    size_t rope_freqs_sin_input_index{};
+    std::vector<size_t> k_cache_input_indices;
+    std::vector<size_t> k_cache_output_indices;
+    std::vector<size_t> v_cache_input_indices;
+    std::vector<size_t> v_cache_output_indices;
+    RopeT* rope_freqs_cos;
+    RopeT* rope_freqs_sin;
+    StaticAttentionUpdateStyle style =
+        StaticAttentionUpdateStyle::SLIDING_CACHE;
+  };
+
+  StaticAttentionIOManager(StaticAttentionIOConfig config)
+      : config_(std::move(config)),
+        kCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style),
+        vCaches_(
+            config_.n_caches,
+            config_.cache_len,
+            config_.head_dim,
+            config_.max_input_len,
+            false,
+            config_.style) {}
 
   /**
    * Create a new StaticAttentionMask that will be managed by this object.
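A sketch of how a runner might fill in the new config and construct the manager. Everything below is illustrative: the template arguments, the I/O indices, and the RoPE table pointers depend on the exported model and are placeholders, not values from this commit:

    // Hypothetical wiring; assumes this header is included and that float is
    // used for the cache, mask, and RoPE element types (the actual template
    // parameter order may differ).
    using IOManager = example::StaticAttentionIOManager<float, float, float>;

    IOManager::StaticAttentionIOConfig config;
    config.n_caches = 32;       // placeholder: one K and one V cache per layer
    config.cache_len = 1024;    // placeholder context capacity
    config.head_dim = 64;       // placeholder head dimension
    config.max_input_len = 128; // placeholder prefill chunk length
    config.attn_mask_input_index = 1;       // placeholder method input index
    config.rope_freqs_cos_input_index = 2;  // placeholder
    config.rope_freqs_sin_input_index = 3;  // placeholder
    config.k_cache_input_indices = {4, 5};  // placeholders, one per layer
    config.k_cache_output_indices = {1, 2}; // placeholders
    config.v_cache_input_indices = {6, 7};  // placeholders
    config.v_cache_output_indices = {3, 4}; // placeholders
    config.rope_freqs_cos = rope_cos_table; // hypothetical precomputed tables
    config.rope_freqs_sin = rope_sin_table;
    IOManager io_mgr(std::move(config));

Because the struct's members are value-initialized with {} and named at the call site, new fields can be added later without breaking existing callers, unlike the old positional parameter list.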
@@ -369,36 +387,38 @@ class StaticAttentionIOManager {
         std::piecewise_construct,
         std::forward_as_tuple(input_len),
         std::forward_as_tuple(
-            cache_len_, input_len, head_dim_, zero_val, mask_val, style_));
+            config_.cache_len,
+            input_len,
+            config_.head_dim,
+            zero_val,
+            mask_val,
+            config_.style));
     return it.first->second;
   }
 
   /**
    * Retrieve a mask suitable for given input length.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>& getMask(size_t input_len) {
+  StaticAttentionMask<MaskT, MaskAllocatorT>& get_mask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
 
   /**
    * Set I/O pointers for KV cache and RoPE freqencies.
    */
-  void prepare(
-      torch::executor::Method& method,
-      const std::vector<size_t>& k_cache_input_indices,
-      const std::vector<size_t>& k_cache_output_indices,
-      const std::vector<size_t>& v_cache_input_indices,
-      const std::vector<size_t>& v_cache_output_indices) {
-    kCaches_.prepare(method, k_cache_input_indices, k_cache_output_indices);
-    vCaches_.prepare(method, v_cache_input_indices, v_cache_output_indices);
+  void prepare(torch::executor::Method& method) {
+    kCaches_.prepare(
+        method, config_.k_cache_input_indices, config_.k_cache_output_indices);
+    vCaches_.prepare(
+        method, config_.v_cache_input_indices, config_.v_cache_output_indices);
     set_input(
         method,
-        rope_freqs_cos_index_,
-        rope_freqs_cos_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_cos_input_index,
+        config_.rope_freqs_cos + input_pos_ * config_.head_dim / 2);
     set_input(
         method,
-        rope_freqs_sin_index_,
-        rope_freqs_sin_ + input_pos_ * head_dim_ / 2);
+        config_.rope_freqs_sin_input_index,
+        config_.rope_freqs_sin + input_pos_ * config_.head_dim / 2);
   }
 
   /**
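With the cache and RoPE indices folded into the config, per-step I/O setup at the call site collapses to a single call (sketch, continuing the hypothetical io_mgr and a loaded method from the example above):

    // Previously: io_mgr.prepare(method, k_in, k_out, v_in, v_out);
    // Now the four index vectors come from config_, so each step is just:
    io_mgr.prepare(method);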
@@ -430,6 +450,36 @@ class StaticAttentionIOManager {
     }
   }
 
+  template <typename TokenT>
+  std::vector<TokenT> decode(
+      TokenT prev_tok,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method,
+      std::function<TokenT(executorch::runtime::Method&)>& sample,
+      std::function<bool(TokenT)>& should_stop) {
+    set_input(method, 0, input_buffer.data());
+    auto& mask = get_mask(input_buffer.size());
+    set_input(method, config_.attn_mask_input_index, mask.get());
+
+    std::vector<TokenT> generated_tokens;
+    while (kCaches_.size() + 1 <= config_.cache_len) {
+      input_buffer[0] = prev_tok;
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          1);
+      prev_tok = sample(method);
+      generated_tokens.emplace_back(prev_tok);
+      if (should_stop(prev_tok)) {
+        break;
+      }
+    }
+    return generated_tokens;
+  }
+
  private:
   template <typename T>
   void set_input(executorch::runtime::Method& method, size_t idx, T* data) {
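A sketch of driving the new decode() helper. The token type, sampler, and EOS id below are hypothetical; decode() assumes the token IDs are method input 0, and it takes the std::function callbacks by non-const reference, so they must be bound to named variables rather than passed as temporaries:

    // Hypothetical greedy decode; argmax_over_logits and kEosTokenId stand in
    // for a real sampler and the tokenizer's EOS id.
    std::vector<int64_t> token_buf(1); // decode feeds one token per step
    std::function<int64_t(executorch::runtime::Method&)> sample =
        [](executorch::runtime::Method& m) -> int64_t {
          // Placeholder: read the logits from the method's outputs and
          // pick the highest-scoring token.
          return argmax_over_logits(m.get_output(0));
        };
    std::function<bool(int64_t)> should_stop = [](int64_t tok) {
      return tok == kEosTokenId; // placeholder stop condition
    };
    auto generated = io_mgr.decode(
        last_prompt_token, // last prompt token, produced by the prefill phase
        executorch::runtime::Span<int64_t>(token_buf.data(), token_buf.size()),
        method,
        sample,
        should_stop);

The loop runs until should_stop fires or the KV cache fills (the kCaches_.size() + 1 <= config_.cache_len check), advancing the caches by one position per iteration.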
@@ -447,19 +497,12 @@ class StaticAttentionIOManager {
     ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
   }
 
-  size_t cache_len_;
-  size_t input_len_;
-  size_t head_dim_;
+  StaticAttentionIOConfig config_;
   size_t input_pos_;
-  StaticAttentionUpdateStyle style_;
   StaticKVCache<CacheT, CacheAllocatorT> kCaches_;
   StaticKVCache<CacheT, CacheAllocatorT> vCaches_;
   std::unordered_map<size_t, StaticAttentionMask<MaskT, MaskAllocatorT>>
       attentionMasks_;
-  size_t rope_freqs_cos_index_;
-  size_t rope_freqs_sin_index_;
-  RopeT* rope_freqs_cos_;
-  RopeT* rope_freqs_sin_;
 };
 
 } // namespace example
