Skip to content

Commit 8d24073

Browse files
committed
fix: Fix initialization of child states
Since initially writing this PR, the logic in the child state types changed such that using the "init full" signature and keeping the ubatches on the parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent beddd62 commit 8d24073

File tree

2 files changed

+11
-18
lines changed

2 files changed

+11
-18
lines changed

src/llama-kv-cache-hybrid-recurrent.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
100100

101101
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
102102
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
103-
this,
104103
static_cast<llama_kv_cache_unified_state *>( kv_attn ->init_update(lctx, optimize).release()),
105104
static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
106105
}
@@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla
179178

180179
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
181180
: status(LLAMA_MEMORY_STATUS_SUCCESS),
182-
kv(kv),
183181
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
184182
state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
185183

186184
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
187-
llama_kv_cache_hybrid_recurrent * kv,
188185
llama_kv_cache_unified_state * state_unified,
189186
llama_kv_cache_recurrent_state * state_recurrent)
190187
: status(LLAMA_MEMORY_STATUS_NO_UPDATE),
191-
kv(kv),
192188
state_attn(state_unified),
193189
state_recurrent(state_recurrent) {}
194190

@@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
198194
std::vector<uint32_t> heads_attn,
199195
std::vector<llama_ubatch> ubatches)
200196
: status(LLAMA_MEMORY_STATUS_SUCCESS),
201-
kv(kv),
202197
sbatch(std::move(sbatch)),
203-
heads_attn(std::move(heads_attn)),
204198
ubatches(std::move(ubatches)),
205-
// NOTE: these child states are only used as wrapper APIs for the
206-
// const methods, so we use the "init full" signature since the
207-
// actual state is not used.
208-
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
209-
state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}
199+
// note: here we copy the ubatches. not sure if this is ideal
200+
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
201+
state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
210202

211203

212204
bool llama_kv_cache_hybrid_recurrent_state::next() {
213205
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
214206

207+
state_attn ->next();
208+
state_recurrent->next();
209+
215210
if (++i_next >= ubatches.size()) {
216211
return false;
217212
}
@@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::next() {
222217
bool llama_kv_cache_hybrid_recurrent_state::apply() {
223218
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
224219

225-
kv->get_kv_attn() ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
226-
kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
220+
bool res = true;
227221

228-
return true;
222+
res = res & state_attn ->apply();
223+
res = res & state_recurrent->apply();
224+
225+
return res;
229226
}
230227

231228
std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {

src/llama-kv-cache-hybrid-recurrent.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
104104

105105
// init update
106106
explicit llama_kv_cache_hybrid_recurrent_state(
107-
llama_kv_cache_hybrid_recurrent * kv,
108107
llama_kv_cache_unified_state * state_unified,
109108
llama_kv_cache_recurrent_state * state_recurrent);
110109

@@ -135,14 +134,11 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
135134
private:
136135
const llama_memory_status status;
137136

138-
llama_kv_cache_hybrid_recurrent * kv;
139-
140137
llama_sbatch sbatch;
141138

142139
// the index of the next ubatch to process
143140
size_t i_next = 0;
144141

145-
std::vector<uint32_t> heads_attn;
146142
std::vector<llama_ubatch> ubatches;
147143

148144
const llama_memory_state_ptr state_attn;

0 commit comments

Comments (0)