@@ -2583,7 +2583,7 @@ struct llama_hparams {
         return n_embd_head_v * n_head_kv;
     }
 
-    uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings
+    uint32_t n_embd_r(uint32_t il) const {  // dimension of the rolling state embeddings
         // TODO: support using an SSM in place of the MLP of a Transformer
         if (n_head_kv(il) != 0) { return 0; }
         // corresponds to Mamba's conv_states size or RWKV's token_shift states size
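For reference, a sketch of what n_embd_r plausibly computes for the two architectures named in the comment, assuming the Mamba/RWKV hparams fields from mainline llama.cpp (ssm_d_conv, ssm_d_inner, wkv_head_size, token_shift_count, n_embd); the actual body is outside this hunk:

    uint32_t n_embd_r_sketch(uint32_t il) const {
        if (n_head_kv(il) != 0) { return 0; }  // attention layer: no rolling state
        if (wkv_head_size != 0) {
            // RWKV: one embedding row per token_shift slot
            return token_shift_count * n_embd;
        }
        // Mamba: conv_states keep the last (d_conv - 1) columns of the inner dim
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }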
@@ -2597,7 +2597,7 @@ struct llama_hparams {
         }
     }
 
-    uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings
+    uint32_t n_embd_s(uint32_t il) const {  // dimension of the recurrent state embeddings
         // TODO: support using an SSM in place of the MLP of a Transformer
         if (n_head_kv(il) != 0) { return 0; }
 
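And its recurrent-state counterpart, under the same assumptions (for Mamba this is the d_state x d_inner ssm_states buffer; attention layers again report zero):

    uint32_t n_embd_s_sketch(uint32_t il) const {
        if (n_head_kv(il) != 0) { return 0; }  // attention layer: no recurrent state
        // corresponds to Mamba's ssm_states size
        return ssm_d_state * ssm_d_inner;
    }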
@@ -2875,17 +2875,13 @@ struct llama_kv_self_cache {
 
 struct llama_rs_cell {
     llama_pos pos = -1;
-    int32_t src = -1; // copy source id (cleared next when -1)
+    int32_t src = -1;  // copy source id (cleared next when -1)
 
     std::set<llama_seq_id> seq_id;
 
-    bool has_seq_id(const llama_seq_id & id) const {
-        return seq_id.find(id) != seq_id.end();
-    }
+    bool has_seq_id(const llama_seq_id & id) const { return seq_id.find(id) != seq_id.end(); }
 
-    bool is_empty() const {
-        return seq_id.empty();
-    }
+    bool is_empty() const { return seq_id.empty(); }
 };
 
 struct llama_rs_seq_meta {
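A tiny usage sketch of llama_rs_cell as defined above (values hypothetical):

    llama_rs_cell cell;
    cell.pos = 42;
    cell.seq_id.insert(0);                  // the cell now holds state for sequence 0
    const bool owned = cell.has_seq_id(0);  // true
    const bool empty = cell.is_empty();     // false; only empty cells are reusable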
@@ -2895,46 +2891,45 @@ struct llama_rs_seq_meta {
 
 // ring-buffered tree of cached recurrent state data
 struct llama_rs_self_cache {
-
-    uint32_t head = 0; // first state used for the last slot
+    uint32_t head = 0;  // first state used for the last slot
     uint32_t size = 0;
     uint32_t used = 0;
 
     // computed when finding a slot
-    uint32_t n = 0; // range of states used for the last slot
+    uint32_t n = 0;  // range of states used for the last slot
 
     // with state models, a cell can hold the state for more than one past token
    std::vector<llama_rs_cell> cells;
 
     // find tail cells faster
-    std::vector<llama_rs_seq_meta> seq_tails; // map seq_ids to cell ids
+    std::vector<llama_rs_seq_meta> seq_tails;  // map seq_ids to cell ids
 
     // per layer
     // NOTE: the naming of r and s is arbitrary
-    std::vector<struct ggml_tensor *> r_l; // rolling/shift states
-    std::vector<struct ggml_tensor *> s_l; // ssm (recurrent) states
+    std::vector<struct ggml_tensor *> r_l;  // rolling/shift states
+    std::vector<struct ggml_tensor *> s_l;  // ssm (recurrent) states
 
     // Inefficient, but thorough verification and rebuilding of the rs cache
     // from only the cells list with `pos` and seq_ids.
     // Should not be called in a hot loop except when desperate and/or debugging.
     bool rebuild(bool debug) {
         bool was_valid = true;
         // skip for non-recurrent models
-        if (size == 0) { return true; }
+        if (size == 0) {
+            return true;
+        }
         // the source of truth is the cells list
         // buffer sizes
         if (size != cells.size()) {
             if (debug) {
-                LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n",
-                    __func__, cells.size(), size);
+                LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", __func__, cells.size(), size);
             }
             cells.resize(size);
             was_valid = false;
         }
         if (size != seq_tails.size()) {
             if (debug) {
-                LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n",
-                    __func__, seq_tails.size(), size);
+                LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", __func__, seq_tails.size(), size);
             }
             seq_tails.resize(size);
             was_valid = false;
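Both checks above follow the same verify-then-repair idiom: test an invariant, log the violation in debug mode, force the invariant so later checks can proceed, and remember that the cache was invalid. Condensed into a hypothetical helper (not part of this PR):

    template <typename T>
    static bool ensure_size(std::vector<T> & v, size_t expected, bool debug, const char * name) {
        if (v.size() == expected) { return true; }
        if (debug) {
            LLAMA_LOG_ERROR("%s: %s has wrong size (%zu instead of %zu)\n", __func__, name, v.size(), expected);
        }
        v.resize(expected);  // repair, so the remaining checks can still run
        return false;        // the cache was not valid
    }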
@@ -2994,7 +2989,7 @@ struct llama_rs_self_cache {
         for (uint32_t cell_id = 0; cell_id < size; ++cell_id) {
             llama_rs_cell & cell = cells[cell_id];
             if (cell.has_seq_id(seq_id)) {
-                seq_cells.push_back({cell.pos, cell_id});
+                seq_cells.push_back({ cell.pos, cell_id });
             }
         }
         // sort by pos and then by cell_id
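Assuming seq_cells is a std::vector<std::pair<llama_pos, uint32_t>> (its declaration is outside this hunk), the sort the comment describes needs no custom comparator, because std::pair compares lexicographically:

    // sorts by pos first, then by cell_id for equal positions
    std::sort(seq_cells.begin(), seq_cells.end());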
@@ -3718,16 +3713,16 @@ static bool llama_kv_cache_init(
         }
 
         if (has_kv) {
-            ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)* kv_size);
-            ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)* kv_size);
+            ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i) * kv_size);
+            ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i) * kv_size);
             ggml_format_name(k, "cache_k_l%d", i);
             ggml_format_name(v, "cache_v_l%d", i);
             cache.kv.k_l.push_back(k);
             cache.kv.v_l.push_back(v);
         }
         if (has_rs) {
-            ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)* rs_size);
-            ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)* rs_size);
+            ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i) * rs_size);
+            ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i) * rs_size);
             ggml_format_name(r, "cache_r_l%d", i);
             ggml_format_name(s, "cache_s_l%d", i);
             cache.rs.r_l.push_back(r);
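Each cache tensor is a flat 1D allocation with one row per cell, so the per-layer footprint is n_embd * n_cells * type_size. A back-of-the-envelope check with hypothetical numbers:

    // e.g. n_embd_k_gqa(i) = 1024, kv_size = 4096, type_k = GGML_TYPE_F16
    const size_t n_elems = (size_t) 1024 * 4096;  // elements in one layer's K cache
    const size_t n_bytes = n_elems * 2;           // F16 is 2 bytes/element -> 8 MiB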
@@ -4370,8 +4365,8 @@ struct llama_kv_slot_restorer {
     bool do_restore = false;
 
     explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
-        old_state.head = cache.kv.head;
-        old_state.n = cache.kv.n;
+        old_state.head = cache.kv.head;
+        old_state.n    = cache.kv.n;
     }
 
     // saves a slot information for future restoration
@@ -4388,10 +4383,10 @@ struct llama_kv_slot_restorer {
     // and rollback changes from all llama_kv_cache_find_slot calls
     void restore(struct llama_kv_cache & cache) {
         if (do_restore) {
-            cache.kv.head = old_state.head;
-            cache.kv.n = old_state.n;
+            cache.kv.head = old_state.head;
+            cache.kv.n    = old_state.n;
 
-            if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased
+            if (cache.rs.size > 0) {  // recurrent models like Mamba or RWKV can't have a state partially erased
                 llama_kv_cache_seq_rm(cache, -1, -1, -1);
             } else {
                 for (auto & slot : slot_boundaries) {
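A sketch of how a caller is meant to drive the restorer (the loop and ubatch handling are illustrative stand-ins; save() and llama_kv_cache_find_slot sit outside these hunks):

    llama_kv_slot_restorer restorer(cache);  // snapshot kv.head and kv.n up front
    while (has_more_ubatches()) {            // hypothetical driver loop
        const auto slot = llama_kv_cache_find_slot(cache, ubatch);
        if (!slot) {
            restorer.restore(cache);         // roll back head/n and any partially-filled slots
            return 1;                        // no free slot for this ubatch
        }
        restorer.save(slot);                 // remember the boundary in case a later ubatch fails
        // ... build and evaluate the graph for this ubatch ...
    }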