Commit b5aa627

context : simplify sbatch logic
ggml-ci
1 parent c21671a commit b5aa627

6 files changed: +65 additions, -65 deletions


src/llama-batch.cpp

Lines changed: 5 additions & 1 deletion
@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
     for (size_t i = 0; i < n_tokens; ++i) {
         ids[i] = i;
     }
+
     if (simple_split) {
         seq.resize(1);
         llama_sbatch_seq & s = seq[0];
@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
         s.length = n_tokens;
         return;
     }
+
     std::sort(ids.begin(), ids.end(),
         [&batch](size_t a, size_t b) {
             int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
             return n_seq_a > n_seq_b;
         }
     );
+
     // init seq
     llama_sbatch_seq * last_seq = nullptr;
 
@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
         seq.push_back(new_seq);
         last_seq = &seq.back();
     }
+
     // keep shared prompts first at the end, then sort by length descending.
     std::sort(seq.begin(), seq.end(),
         [](llama_sbatch_seq & a, llama_sbatch_seq & b) {

src/llama-batch.h

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@ struct llama_sbatch {
     // sequence-wise split
     llama_ubatch split_seq(size_t n_ubatch);
 
-    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
 };
 
 // temporary allocate memory for the input batch if needed
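For context, this is roughly what the header change means at a call site: the default-constructed-then-refilled member is replaced by a value constructed directly from the batch. A minimal sketch, assuming the internal src/llama-batch.h header is available; the helper name is hypothetical and the body is adapted from the encode() hunk below rather than copied verbatim.

#include "llama-batch.h"

// hypothetical helper illustrating the call-site change
static void process_batch(const llama_batch & batch, size_t n_embd) {
    // before: llama_sbatch sbatch; sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
    // after : the constructor now does the work of from_batch
    llama_sbatch sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);

    // split into one simple ubatch, as encode() does
    const llama_ubatch ubatch = sbatch.split_simple(batch.n_tokens);

    // ... run the graph on ubatch ...
    (void) ubatch;
}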

src/llama-context.cpp

Lines changed: 40 additions & 58 deletions
@@ -810,9 +810,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
-    // reorder logits for backward compatibility
-    output_reorder();
-
     return logits;
 }
 
@@ -855,9 +852,6 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
-    // reorder embeddings for backward compatibility
-    output_reorder();
-
     return embd;
 }
 
@@ -1039,7 +1033,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1230,13 +1224,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    const bool logits_all = n_outputs_all == n_tokens_all;
-
-    const bool is_recurrent = llama_model_is_recurrent(&model);
-
-    sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !is_recurrent,
-            /* logits_all   */ logits_all);
+    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1393,18 +1381,52 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+        auto & out_ids = sbatch.out_ids;
+
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
         for (int64_t i = 0; i < n_outputs_all; ++i) {
-            int64_t out_id = sbatch.out_ids[i];
+            int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
                 sorted_output = false;
             }
         }
 
-        if (sorted_output) {
-            sbatch.out_ids.clear();
+        // make the outputs have the same order they had in the user-provided batch
+        // note: this is mostly relevant for recurrent models atm
+        if (!sorted_output) {
+            const uint32_t n_vocab = model.vocab.n_tokens();
+            const uint32_t n_embd  = model.hparams.n_embd;
+
+            GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+            // TODO: is there something more efficient which also minimizes swaps?
+            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+            for (int32_t i = 0; i < n_outputs - 1; ++i) {
+                int32_t j_min = i;
+                for (int32_t j = i + 1; j < n_outputs; ++j) {
+                    if (out_ids[j] < out_ids[j_min]) {
+                        j_min = j;
+                    }
+                }
+                if (j_min == i) { continue; }
+                std::swap(out_ids[i], out_ids[j_min]);
+                if (logits_size > 0) {
+                    for (uint32_t k = 0; k < n_vocab; k++) {
+                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                    }
+                }
+                if (embd_size > 0) {
+                    for (uint32_t k = 0; k < n_embd; k++) {
+                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                    }
+                }
+            }
+            std::fill(output_ids.begin(), output_ids.end(), -1);
+            for (int32_t i = 0; i < n_outputs; ++i) {
+                output_ids[out_ids[i]] = i;
+            }
         }
     }
 
@@ -1515,44 +1537,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
-void llama_context::output_reorder() {
-    auto & out_ids = sbatch.out_ids;
-    if (!out_ids.empty()) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint32_t n_embd  = model.hparams.n_embd;
-
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(output_ids.begin(), output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 //
 // graph
 //
@@ -1993,8 +1977,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
-        output_reorder();
-
         const auto n_outputs = this->n_outputs;
         const auto & output_ids = this->output_ids;
 
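The reordering that used to live in output_reorder() is now done once, inside decode(), only when the outputs came back out of order. Below is a self-contained toy of that reordering strategy: selection sort on out_ids, mirroring every swap into the flat logits buffer so rows end up in the order the user supplied. The array sizes and values are illustrative only, not taken from the repository.

#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const int n_outputs = 4;
    const int n_vocab   = 2; // toy vocabulary size

    // row i of logits belongs to batch position out_ids[i]
    std::vector<int>   out_ids = {2, 0, 3, 1};
    std::vector<float> logits  = {20, 21,  0, 1,  30, 31,  10, 11};

    // selection sort, to minimize the number of row swaps
    for (int i = 0; i < n_outputs - 1; ++i) {
        int j_min = i;
        for (int j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) { continue; }
        std::swap(out_ids[i], out_ids[j_min]);
        for (int k = 0; k < n_vocab; ++k) {
            std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
        }
    }

    // rows are now in batch order: [0,1], [10,11], [20,21], [30,31]
    for (int i = 0; i < n_outputs; ++i) {
        printf("row %d -> batch position %d: [%g, %g]\n",
               i, out_ids[i], logits[i*n_vocab + 0], logits[i*n_vocab + 1]);
    }
    return 0;
}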

src/llama-context.h

Lines changed: 0 additions & 5 deletions
@@ -137,10 +137,6 @@ struct llama_context {
     // Returns max number of outputs for which space was reserved.
     int32_t output_reserve(int32_t n_outputs);
 
-    // make the outputs have the same order they had in the user-provided batch
-    // TODO: maybe remove this
-    void output_reorder();
-
     //
     // graph
     //
@@ -197,7 +193,6 @@ struct llama_context {
     llama_cparams cparams;
     llama_adapter_cvec cvec;
     llama_adapter_loras loras;
-    llama_sbatch sbatch;
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
src/llama-kv-cache.cpp

Lines changed: 12 additions & 0 deletions
@@ -476,6 +476,12 @@ bool llama_kv_cache_unified::find_slot(
     return true;
 }
 
+llama_sbatch llama_kv_cache_unified::sbatch_init(
+        const llama_batch & batch,
+        bool logits_all) {
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
+
 llama_ubatch llama_kv_cache_unified::ubatch_next(
         llama_sbatch & sbatch,
         uint32_t n_ubatch,
@@ -1547,6 +1553,12 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
+llama_sbatch llama_kv_cache_recurrent::sbatch_init(
+        const llama_batch & batch,
+        bool logits_all) {
+    return llama_sbatch(batch, hparams.n_embd, false, logits_all);
+}
+
 llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
     if (embd_pooled) {
         // Pooled embeddings cannot be split across ubatches (yet)
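The two sbatch_init overrides encode the split strategy that decode() previously selected itself via llama_model_is_recurrent: the unified cache asks for a simple token-wise split, the recurrent cache for a sequence-aware one. A self-contained toy of that dispatch pattern, with all names prefixed toy_ to make clear they are hypothetical and not llama.cpp APIs:

#include <cstdio>
#include <memory>

struct toy_sbatch {
    bool simple_split;
    bool logits_all;
};

// the cache object owns the decision of how the batch gets split,
// so the caller no longer branches on the model type
struct toy_kv_cache {
    virtual ~toy_kv_cache() = default;
    virtual toy_sbatch sbatch_init(bool logits_all) const = 0;
};

struct toy_kv_cache_unified : toy_kv_cache {
    toy_sbatch sbatch_init(bool logits_all) const override {
        return { /* simple_split */ true, logits_all };  // token-wise split
    }
};

struct toy_kv_cache_recurrent : toy_kv_cache {
    toy_sbatch sbatch_init(bool logits_all) const override {
        return { /* simple_split */ false, logits_all }; // sequence-aware split
    }
};

int main() {
    std::unique_ptr<toy_kv_cache> kv = std::make_unique<toy_kv_cache_recurrent>();
    toy_sbatch sb = kv->sbatch_init(/* logits_all */ false);
    printf("simple_split = %d\n", sb.simple_split); // prints 0 for the recurrent cache
    return 0;
}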

src/llama-kv-cache.h

Lines changed: 6 additions & 0 deletions
@@ -45,6 +45,8 @@ struct llama_kv_cache : public llama_memory_i {
 
     virtual bool find_slot(const llama_ubatch & batch) = 0;
 
+    virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
+
     // different KV caches require different batch splitting strategies
     virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
 
@@ -143,6 +145,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;
 
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     static uint32_t get_padding(const llama_cparams & cparams);
@@ -269,6 +273,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     // to the first cell of the slot.
     bool find_slot(const llama_ubatch & batch) override;
 
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+
    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     // find how many cells are currently in use
