@@ -2933,3 +2933,240 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     return true;
 }
+
+//
+// llama_kv_cache_hybrid
+//
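+// A composite cache that fans every KV-cache operation out to a set of child
+// caches, each owning a disjoint set of layers (e.g. attention layers in one
+// child and recurrent layers in another).
+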
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(
+        const llama_hparams & hparams,
+        std::vector<child_cache> children) :
+    m_hparams(hparams),
+    m_children(
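+        // An immediately-invoked lambda lets this member be populated in the
+        // initializer list while still mutating `children` along the way
+        // (sorting it and releasing ownership of each child).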
+        [](std::vector<child_cache> & caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
+            // Sort the caches by the lowest layer ID so the order is repeatable
+            for (auto & cache : caches) {
+                GGML_ASSERT(cache.layer_ids.size() > 0);
+                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
+            }
+            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
+                return a.layer_ids[0] < b.layer_ids[0];
+            });
+            std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
+            for (auto & cache : caches) {
+                unique_caches.emplace(cache.child.release());
+            }
+            return unique_caches;
+        }(children)
+    ),
+    m_has_recurrent(
+        [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
+            for (const auto & cache : caches) {
+                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
+                    return true;
+                }
+            }
+            return false;
+        }(m_children)
+    )
+{
+    // Ensure at least one child
+    GGML_ASSERT(m_children.size() > 0);
+
+    // Ensure layers are non-overlapping and contiguous
+    std::set<size_t> seen_layers;
+    size_t max_layer = 0;
+    for (const auto & cache : children) {
+        for (const auto & layer_id : cache.layer_ids) {
+            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
+            seen_layers.insert(layer_id);
+            if (layer_id > max_layer) {
+                max_layer = layer_id;
+            }
+        }
+    }
+    LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size());
+    GGML_ASSERT(max_layer + 1 == seen_layers.size());
+}
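+// Construction sketch (hypothetical layer split and child caches, assuming
+// child_cache aggregates a cache pointer plus its layer IDs): a 4-layer model
+// with attention on layers 0/2 and recurrent state on layers 1/3 might be
+// wired up roughly as
+//
+//     std::vector<child_cache> children;
+//     children.push_back({std::move(attn_cache),      {0, 2}});
+//     children.push_back({std::move(recurrent_cache), {1, 3}});
+//     llama_kv_cache_hybrid kv(hparams, std::move(children));
+//
+// The constructor then asserts that every layer in [0, max_layer] is owned by
+// exactly one child.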
+
+void llama_kv_cache_hybrid::clear() {
+    for (const auto & cache : m_children) {
+        cache->clear();
+    }
+}
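+
+// Removal must be all-or-nothing across the children: if one child dropped the
+// range while another refused, the children would disagree on the sequence
+// contents with no way to roll back.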
+bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // First check if we can do this removal. This checks all children so that
+    // no mutation happens before we know it's possible
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+
+    // Do the removal from each child; this should never fail
+    for (const auto & cache : m_children) {
+        const bool removed = cache->seq_rm(seq_id, p0, p1);
+        GGML_ASSERT(removed);
+    }
+    return true;
+}
+
+void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    for (const auto & cache : m_children) {
+        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
+    for (const auto & cache : m_children) {
+        cache->seq_keep(seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    for (const auto & cache : m_children) {
+        cache->seq_add(seq_id, p0, p1, delta);
+    }
+}
+
+void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    for (const auto & cache : m_children) {
+        cache->seq_div(seq_id, p0, p1, d);
+    }
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    llama_pos min_pos = -1;
+    for (const auto & cache : m_children) {
+        const auto child_min_pos = cache->seq_pos_min(seq_id);
+        min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos);
+    }
+    return min_pos;
+}
+
+llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos max_pos = 0;
+    for (const auto & cache : m_children) {
+        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
+    }
+    return max_pos;
+}
+
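+// restore() and commit() bracket a decode: restore() rolls each child back to
+// its last committed state (e.g. after a failed decode), while commit() makes
+// the pending changes permanent.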
+void llama_kv_cache_hybrid::restore() {
+    for (const auto & cache : m_children) {
+        cache->restore();
+    }
+}
+
+void llama_kv_cache_hybrid::commit() {
+    for (const auto & cache : m_children) {
+        cache->commit();
+    }
+}
+
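+// Apply any pending internal updates in each child (e.g. K-shift or
+// defragmentation). Note the operand order below: the call comes first so that
+// `||` never short-circuits past a child.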
+bool llama_kv_cache_hybrid::update(llama_context & ctx) {
+    bool updated = false;
+    for (const auto & cache : m_children) {
+        updated = cache->update(ctx) || updated;
+    }
+    return updated;
+}
+
+void llama_kv_cache_hybrid::defrag_sched(float thold) {
+    for (const auto & cache : m_children) {
+        cache->defrag_sched(thold);
+    }
+}
+
+void llama_kv_cache_hybrid::set_full() {
+    for (const auto & cache : m_children) {
+        cache->set_full();
+    }
+}
+
+bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    for (const auto & cache : m_children) {
+        if (!cache->can_seq_rm(seq_id, p0, p1)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // If any of the caches are recurrent, require equal split
+    return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
+}
+
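+// Choose how the sbatch is carved into ubatches: sequence-wise splits for
+// pooled embeddings, equal-length splits when a recurrent child is present
+// (recurrent state updates need the same number of tokens per sequence), and
+// the simple contiguous split otherwise.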
+llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+    if (m_has_recurrent) {
+        return sbatch.split_equal(n_ubatch);
+    }
+    return sbatch.split_simple(n_ubatch);
+}
+
+bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
+    bool found = true;
+    for (const auto & cache : m_children) {
+        found = cache->find_slot(batch) && found;
+    }
+    return found;
+}
+
+int32_t llama_kv_cache_hybrid::get_n_tokens() const {
+    // The number of tokens should be the same across all child caches
+    int32_t n_tokens = -1;
+    for (const auto & cache : m_children) {
+        const auto cache_n_tokens = cache->get_n_tokens();
+        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
+        n_tokens = cache_n_tokens;
+    }
+    return n_tokens;
+}
+
+int32_t llama_kv_cache_hybrid::get_used_cells() const {
+    // TODO: Is this correct?
+    // Return the largest number of used cells
+    int32_t used_cells = -1;
+    for (const auto & cache : m_children) {
+        used_cells = std::max(used_cells, cache->get_used_cells());
+    }
+    return used_cells;
+}
+
+llama_pos llama_kv_cache_hybrid::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cache : m_children) {
+        pos_max = std::max(pos_max, cache->get_pos_max());
+    }
+    return pos_max;
+}
+
+bool llama_kv_cache_hybrid::get_can_shift() const {
+    // TODO: Is this correct?
+    // If any children can shift, return true
+    for (const auto & cache : m_children) {
+        if (cache->get_can_shift()) {
+            return true;
+        }
+    }
+    return false;
+}
+
+void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    // Write each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_write(io, seq_id);
+    }
+}
+
+void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    // Read each cache state in order. Note that order is guaranteed at
+    // initialization by using an ordered set sorted by lowest layer ID
+    for (const auto & cache : m_children) {
+        cache->state_read(io, seq_id);
+    }
+}