@@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
 
 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
     return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-        this,
         static_cast<llama_kv_cache_unified_state *>  (kv_attn     ->init_update(lctx, optimize).release()),
         static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
 }
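For reference, below is a minimal standalone sketch of the ownership pattern that init_update() now follows: each child cache builds its own update state, and the combined state adopts the released raw pointers. The diff does not show how the hybrid state stores its children, so the unique_ptr members, and all type names here (child_cache, child_state, hybrid_state), are illustrative assumptions rather than the actual llama.cpp classes.

// Illustrative sketch only, not the real llama.cpp types: each child cache
// produces its own update state, and the combined state adopts the released
// raw pointers (unique_ptr ownership here is an assumption).
#include <cstdio>
#include <memory>

struct child_state { };                        // plays the role of a per-cache update state

struct child_cache {
    // mirrors kv_attn->init_update(...): the child decides what its state looks like
    std::unique_ptr<child_state> init_update() { return std::make_unique<child_state>(); }
};

struct hybrid_state {
    // adopts the raw pointers released by the child caches
    hybrid_state(child_state * attn, child_state * recurrent)
        : state_attn(attn), state_recurrent(recurrent) {}

    std::unique_ptr<child_state> state_attn;
    std::unique_ptr<child_state> state_recurrent;
};

// mirrors init_update() after this change: the hybrid object itself is no longer
// passed down, only the two freshly built child states are
std::unique_ptr<hybrid_state> init_update(child_cache & kv_attn, child_cache & kv_recurrent) {
    return std::make_unique<hybrid_state>(
        kv_attn     .init_update().release(),
        kv_recurrent.init_update().release());
}

int main() {
    child_cache attn, recurrent;
    auto st = init_update(attn, recurrent);
    std::printf("child states owned: %s\n", (st->state_attn && st->state_recurrent) ? "yes" : "no");
    return 0;
}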
@@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
-      kv(kv),
       state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
       state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_hybrid_recurrent * kv,
         llama_kv_cache_unified_state * state_unified,
         llama_kv_cache_recurrent_state * state_recurrent)
     : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
-      kv(kv),
       state_attn(state_unified),
       state_recurrent(state_recurrent) {}
 
@@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
         std::vector<uint32_t> heads_attn,
         std::vector<llama_ubatch> ubatches)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
-      kv(kv),
       sbatch(std::move(sbatch)),
-      heads_attn(std::move(heads_attn)),
       ubatches(std::move(ubatches)),
-      // NOTE: these child states are only used as wrapper APIs for the
-      //  const methods, so we use the "init full" signature since the
-      //  actual state is not used.
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
-      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}
+      // note: here we copy the ubatches. not sure if this is ideal
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
+      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
 
 
 bool llama_kv_cache_hybrid_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    state_attn     ->next();
+    state_recurrent->next();
+
     if (++i_next >= ubatches.size()) {
         return false;
     }
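One detail worth noting in the constructor change above: heads_attn is moved into the attention child state, while the ubatches vector is first moved into the parent's member and then handed to both child states, which copy it (hence the in-code note). For this->ubatches to be fully initialized at that point, the ubatches member has to be declared before state_attn and state_recurrent in the class, since members are initialized in declaration order, not initializer-list order. The sketch below illustrates that ordering rule with made-up names; it is not the actual llama.cpp code.

// Illustrative sketch: member initializers run in *declaration* order, so an
// initializer later in the list may safely read a member declared earlier.
// This is what passing this->ubatches down to the child states relies on.
#include <cstdio>
#include <vector>

struct child {
    explicit child(std::vector<int> data) : data(std::move(data)) {}
    std::vector<int> data;                       // each child keeps its own copy
};

struct parent {
    explicit parent(std::vector<int> ubatches_in)
        : ubatches(std::move(ubatches_in)),      // ok: ubatches is declared first
          child_a(this->ubatches),               // copies the already-initialized member
          child_b(this->ubatches) {}

    std::vector<int> ubatches;                   // declared before the children
    child child_a;
    child child_b;
};

int main() {
    parent p({1, 2, 3});
    std::printf("parent=%zu child_a=%zu child_b=%zu\n",
                p.ubatches.size(), p.child_a.data.size(), p.child_b.data.size());
    return 0;
}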
@@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::next() {
 bool llama_kv_cache_hybrid_recurrent_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
-    kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+    bool res = true;
 
-    return true;
+    res = res & state_attn     ->apply();
+    res = res & state_recurrent->apply();
+
+    return res;
 }
 
 std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
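The net effect of the last two hunks is that the hybrid state behaves as a plain composite: next() and apply() no longer touch the caches directly, they forward to the two child states and aggregate the results. Note the plain '&' instead of '&&' in apply(): it does not short-circuit, so the recurrent child is still applied even if the attention child reports failure. Below is a simplified, self-contained sketch of that pattern; the types and member names are illustrative stand-ins, not the actual llama.cpp classes.

// Illustrative composite-delegation sketch, not the real llama.cpp types:
// the parent state forwards next()/apply() to its children and combines results.
#include <cstdio>
#include <vector>

struct child_state {
    bool next()  { return true; }   // advance the child's own cursor
    bool apply() { return true; }   // stand-in for writing the current ubatch into the cache
};

struct hybrid_state {
    bool next() {
        // keep the children in lockstep with the parent cursor
        attn.next();
        recurrent.next();
        return ++i_next < n_ubatches;
    }

    bool apply() {
        bool res = true;
        // plain '&' (not '&&') so the recurrent child is applied even if the
        // attention child fails, mirroring the non-short-circuiting code above
        res = res & attn.apply();
        res = res & recurrent.apply();
        return res;
    }

    child_state attn;
    child_state recurrent;
    size_t i_next     = 0;
    size_t n_ubatches = 0;
};

int main() {
    hybrid_state st;
    st.n_ubatches = 2;
    // typical consumption loop for such a state: apply the current ubatch, then advance
    do {
        if (!st.apply()) { std::printf("apply failed\n"); break; }
    } while (st.next());
    return 0;
}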