Commit e6ac4ac

batch : rename batch_allocr to balloc
ggml-ci
1 parent 1f6a916 commit e6ac4ac

10 files changed (+39, -38 lines)
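
The commit is mostly a mechanical rename, but it also changes llama_memory_i::init_batch to take the batch allocator by reference instead of by raw pointer, so call sites switch from passing balloc.get() to *balloc and implementations switch from balloc->... to balloc. The standalone snippet below is only an illustration of that pointer-to-reference pattern; the types batch_allocr, memory_i, and simple_memory are hypothetical stand-ins and are not code from this commit.

#include <memory>
#include <vector>

// stand-in for the batch allocator: hands out ubatch sizes
struct batch_allocr {
    std::vector<int> tokens;
    int split_simple(int n) { return n < (int) tokens.size() ? n : (int) tokens.size(); }
};

// stand-in for the memory interface: init_batch now binds to a reference,
// so a null argument is impossible by construction
struct memory_i {
    // before: virtual int init_batch(batch_allocr * balloc, int n_ubatch) = 0;
    virtual int init_batch(batch_allocr & balloc, int n_ubatch) = 0;
    virtual ~memory_i() = default;
};

struct simple_memory : memory_i {
    int init_batch(batch_allocr & balloc, int n_ubatch) override {
        return balloc.split_simple(n_ubatch); // '.' instead of '->', no null check needed
    }
};

int main() {
    // the owner still holds the allocator in a unique_ptr and dereferences it at the
    // call site, mirroring memory->init_batch(*balloc, ...) in the diff below
    auto balloc = std::make_unique<batch_allocr>();
    balloc->tokens = {1, 2, 3, 4};

    simple_memory mem;
    return mem.init_batch(*balloc, 2) == 2 ? 0 : 1;
}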

src/llama-context.cpp

Lines changed: 18 additions & 18 deletions
@@ -20,7 +20,7 @@ llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
     model(model),
-    batch_allocr(std::make_unique<llama_batch_allocr>()) {
+    balloc(std::make_unique<llama_batch_allocr>()) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -734,14 +734,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd;
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
-    const uint32_t n_tokens = batch_allocr->get_n_tokens();
+    const uint32_t n_tokens = balloc->get_n_tokens();
 
-    const llama_ubatch ubatch = batch_allocr->split_simple(n_tokens);
+    const llama_ubatch ubatch = balloc->split_simple(n_tokens);
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -859,7 +859,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
-        const auto & batch = batch_allocr->get_batch();
+        const auto & batch = balloc->get_batch();
 
         // remember the sequence ids used during the encoding - needed for cross attention later
         cross.seq_ids_enc.resize(n_tokens);
@@ -897,13 +897,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
    const bool output_all = cparams.embeddings;
 
-    if (!batch_allocr->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
-    const uint32_t n_tokens_all = batch_allocr->get_n_tokens();
-    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
+    const uint32_t n_tokens_all = balloc->get_n_tokens();
+    const uint32_t n_outputs_all = balloc->get_n_outputs();
 
     if (output_all) {
         // require that all tokens are output
@@ -934,7 +934,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, output_all);
+        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
         if (!mstate) {
             return -2;
         }
@@ -955,19 +955,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
                         did_optimize = true;
 
                         if (kv_self_update(true)) {
-                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch_allocr->get_n_tokens());
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
 
                            continue;
                         }
                     }
 
-                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch_allocr->get_n_tokens());
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
 
                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
-                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch_allocr->get_n_tokens());
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
 
                     return -2;
                 }
@@ -1133,7 +1133,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     if (n_outputs > 0) {
         bool sorted_output = true;
 
-        auto & out_ids = batch_allocr->get_out_ids();
+        auto & out_ids = balloc->get_out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
 
@@ -1306,8 +1306,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 
     this->n_outputs = n_outputs;
 
-    llama_batch_allocr batch_allocr;
-    llama_ubatch ubatch = batch_allocr.ubatch_reserve(n_tokens/n_seqs, n_seqs);
+    llama_batch_allocr balloc;
+    llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
@@ -2027,12 +2027,12 @@ void llama_context::opt_epoch_iter(
             batch.logits [pos_batch] = true;
         }
 
-        if (!batch_allocr->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
             LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
             return;
         }
 
-        const uint32_t n_tokens_all = batch_allocr->get_n_tokens();
+        const uint32_t n_tokens_all = balloc->get_n_tokens();
 
         n_queued_tokens += n_tokens_all;
 
@@ -2041,7 +2041,7 @@ void llama_context::opt_epoch_iter(
         uint32_t n_outputs_all = n_tokens_all;
 
         // TODO: fix
-        auto mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, true);
+        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
            break;

src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ struct llama_context {
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // reuse the batch_allocr to avoid unnecessary memory allocations
-    std::unique_ptr<llama_batch_allocr> batch_allocr;
+    std::unique_ptr<llama_batch_allocr> balloc;
 
     uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

src/llama-graph.cpp

Lines changed: 1 addition & 0 deletions
@@ -252,6 +252,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
+    // TODO: replace this if with GGML_ASSERT(kq_mask)
     if (kq_mask) {
         if (cparams.causal_attn) {
             const int64_t n_kv = ubatch->n_tokens;

src/llama-kv-cache-recurrent.cpp

Lines changed: 5 additions & 5 deletions
@@ -359,17 +359,17 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(llama_batch_allocr * batch_allocr, uint32_t n_ubatch, bool embd_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     std::vector<llama_ubatch> ubatches;
 
     while (true) {
         llama_ubatch ubatch;
 
         if (embd_all) {
             // if all tokens are output, split by sequence
-            ubatch = batch_allocr->split_seq(n_ubatch);
+            ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = batch_allocr->split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch);
         }
 
         if (ubatch.n_tokens == 0) {
@@ -824,9 +824,9 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
 
         seq_rm(dest_seq_id, -1, -1);
 
-        llama_batch_allocr batch_allocr;
+        llama_batch_allocr balloc;
 
-        llama_ubatch ubatch = batch_allocr.ubatch_reserve(cell_count, 1);
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;

src/llama-kv-cache-recurrent.h

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     //
 
     llama_memory_state_ptr init_batch(
-            llama_batch_allocr * batch_allocr,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 5 additions & 5 deletions
@@ -95,16 +95,16 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr * batch_allocr, uint32_t n_ubatch, bool embd_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
     do {
-        batch_allocr->split_reset();
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = batch_allocr->split_simple(n_ubatch);
+            auto ubatch = balloc.split_simple(n_ubatch);
 
             if (ubatch.n_tokens == 0) {
                 break;
@@ -131,11 +131,11 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 
     // if it fails, try equal split
     do {
-        batch_allocr->split_reset();
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = batch_allocr->split_equal(n_ubatch);
+            auto ubatch = balloc.split_equal(n_ubatch);
 
             if (ubatch.n_tokens == 0) {
                 break;

src/llama-kv-cache-unified-iswa.h

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
     //
 
     llama_memory_state_ptr init_batch(
-            llama_batch_allocr * batch_allocr,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;

src/llama-kv-cache-unified.cpp

Lines changed: 5 additions & 5 deletions
@@ -308,17 +308,17 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
-            llama_batch_allocr * batch_allocr,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) {
     GGML_UNUSED(embd_all);
 
     do {
-        batch_allocr->split_reset();
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
         while (true) {
-            auto ubatch = batch_allocr->split_simple(n_ubatch);
+            auto ubatch = balloc.split_simple(n_ubatch);
 
             if (ubatch.n_tokens == 0) {
                 break;
@@ -1505,9 +1505,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
         seq_rm(dest_seq_id, -1, -1);
 
-        llama_batch_allocr batch_allocr;
+        llama_batch_allocr balloc;
 
-        llama_ubatch ubatch = batch_allocr.ubatch_reserve(cell_count, 1);
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;

src/llama-kv-cache-unified.h

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ class llama_kv_cache_unified : public llama_memory_i {
     //
 
     llama_memory_state_ptr init_batch(
-            llama_batch_allocr * batch_allocr,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;

src/llama-memory.h

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ struct llama_memory_i {
     // return a state object containing the ubatches and KV cache state required to process them
     // check the llama_memory_state_i::get_status() for the result
     virtual llama_memory_state_ptr init_batch(
-            llama_batch_allocr * batch_allocr,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) = 0;
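
For out-of-tree code that implements the llama_memory_i interface, the adjustment is the new reference parameter plus '.' member access. The fragment below is only a rough sketch under those assumptions: the class name my_memory is hypothetical, the splitting pattern simply mirrors the kv-cache implementations above, and the final state construction is elided.

// hypothetical implementer of llama_memory_i after this commit (sketch, not commit code)
llama_memory_state_ptr my_memory::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    GGML_UNUSED(embd_all);

    balloc.split_reset();

    std::vector<llama_ubatch> ubatches;
    while (true) {
        auto ubatch = balloc.split_simple(n_ubatch); // was balloc->split_simple(...) on the raw pointer
        if (ubatch.n_tokens == 0) {
            break;
        }
        ubatches.push_back(std::move(ubatch));
    }

    // ... allocate cache slots for the collected ubatches and return the resulting state object ...
}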
