Commit b9912ac

batch : auto-gen positions + verify multi-sequence input (#14177)
* batch : verify multi-sequence input batches (ggml-ci)
* cont : auto-gen positions + verify multi-seq input (ggml-ci)
* cont : first print debug info, then perform validation (ggml-ci)
* cont : fix position auto-gen + add comments (ggml-ci)
1 parent 00ba772 commit b9912ac

File tree: 5 files changed, +155 / -26 lines

include/llama.h

Lines changed: 2 additions & 2 deletions
@@ -243,14 +243,14 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
-    // Input data for llama_decode
+    // Input data for llama_encode/llama_decode
     // A llama_batch object can contain input about one or many sequences
     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
     //
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
+    //            (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
     // - seq_id : the sequence to which the respective token belongs
     //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
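To illustrate the updated llama.h comment, here is a minimal usage sketch (not part of this commit; `ctx` and the token ids are assumed to exist): a two-sequence batch is decoded with `pos` left NULL, so each token's position is generated per sequence.

#include "llama.h"

// sketch: decode one new token for each of two independent sequences,
// relying on auto-generated positions (batch.pos == NULL)
static int32_t decode_two_seqs(llama_context * ctx, llama_token tok0, llama_token tok1) {
    llama_token    tokens[2]   = { tok0, tok1 };
    llama_seq_id   s0[1]       = { 0 };
    llama_seq_id   s1[1]       = { 1 };
    llama_seq_id * seq_ids[2]  = { s0, s1 };
    int32_t        n_seq_id[2] = { 1, 1 };
    int8_t         outputs[2]  = { 1, 1 };

    llama_batch batch = {
        /*.n_tokens =*/ 2,
        /*.token    =*/ tokens,
        /*.embd     =*/ nullptr,
        /*.pos      =*/ nullptr,   // auto-generated: each token continues its own sequence
        /*.n_seq_id =*/ n_seq_id,
        /*.seq_id   =*/ seq_ids,
        /*.logits   =*/ outputs,
    };

    // a non-zero return indicates failure, including the new batch validation errors
    return llama_decode(ctx, batch);
}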

src/llama-batch.cpp

Lines changed: 136 additions & 17 deletions
@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-cparams.h"
 #include "llama-vocab.h"
+#include "llama-memory.h"
 
 #include <cassert>
 #include <cstring>
@@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    }
 }
 
-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
     clear();
 
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
 
-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
+    //
+    // validate input batch
+    //
 
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         }
     }
 
-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
+    //
+    // auto-generate missing fields
+    //
 
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
@@ -349,20 +351,69 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.seq_id = seq_id.data();
     }
 
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
         output[output.size() - 1] = true;
         batch.logits = output.data();
     }
 
+    //
+    // compute stats
+    //
+
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
         LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
         LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
         LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
@@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
                     batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
             LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
         }
     }
 

@@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;
 
@@ -426,6 +537,14 @@ void llama_batch_allocr::clear() {
     n_seq_id.clear();
     seq_id.clear();
     output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
 }
 
 //
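The position auto-generation above can be summarized in isolation. A simplified standalone sketch (assumed types, not the library code): each sequence starts from the position after the last one stored in the memory (or 0 without memory), and a token advances every sequence it belongs to.

#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// sketch: derive positions for a batch with pos == NULL
// p0[s] is the next position of sequence s, e.g. memory->seq_pos_max(s) + 1, or 0 without memory
std::vector<llama_pos> auto_gen_positions(
        const std::vector<std::vector<llama_seq_id>> & seq_ids, // sequence ids of each token
        std::vector<llama_pos> p0) {
    std::vector<llama_pos> pos(seq_ids.size());

    for (size_t i = 0; i < seq_ids.size(); ++i) {
        pos[i] = p0[seq_ids[i][0]];   // the position comes from the token's first sequence

        for (llama_seq_id s : seq_ids[i]) {
            p0[s] = pos[i] + 1;       // every sequence of this token advances past it
        }
    }

    return pos;
}

For a shared prompt token assigned to sequences {0, 1}, both p0[0] and p0[1] advance together, which is what the coupled-sequence consistency check later relies on.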

src/llama-batch.h

Lines changed: 14 additions & 3 deletions
@@ -4,6 +4,7 @@
 
 #include <array>
 #include <vector>
+#include <set>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
@@ -77,18 +78,25 @@
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
+// a helper for sanitizing and fulfilling a batch
 class llama_batch_allocr {
 public:
     llama_batch_allocr();
 
-    // optionally fulfill the batch returned by llama_batch_get_one
-    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory);
 
     const llama_batch & get_batch() const;
 
     uint32_t get_n_outputs() const;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
 private:
     void clear();
 
@@ -103,5 +111,8 @@
     std::vector<llama_seq_id *> seq_id;
     std::vector<int8_t> output;
 
+    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<std::vector<bool>> seq_cpl;   // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
     int debug;
 };
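Because seq_pos[s] is an ordered std::set, seq_pos_min/seq_pos_max are just its first and last elements, and the continuity check reduces to comparing the position span against the set size. A small self-contained sketch (not the library code):

#include <cstdint>
#include <set>

using llama_pos = int32_t;

// sketch: a sequence's positions are continuous exactly when the span [min, max]
// has as many slots as the set has elements (duplicates collapse in a set)
bool positions_continuous(const std::set<llama_pos> & sp) {
    if (sp.empty()) {
        return true;
    }

    const llama_pos p_min = *sp.begin();   // seq_pos_min
    const llama_pos p_max = *sp.rbegin();  // seq_pos_max

    return p_max - p_min + 1 <= (llama_pos) sp.size();
}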

src/llama-context.cpp

Lines changed: 2 additions & 4 deletions
@@ -727,9 +727,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -895,8 +894,7 @@
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
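On the decode path the memory (KV cache) is now the source of the next position, so repeated calls with pos == NULL keep appending to each sequence. A hypothetical sketch (not from this commit; assumes an initialized `ctx`, and the token ids are placeholders):

#include "llama.h"

// sketch: two decode calls on sequence 0 without explicit positions;
// llama_batch_get_one leaves pos/seq_id NULL, so positions come from the memory
static int32_t decode_in_two_steps(llama_context * ctx) {
    llama_token prompt[3] = { 1, 2, 3 };
    if (llama_decode(ctx, llama_batch_get_one(prompt, 3)) != 0) { // positions 0, 1, 2
        return -1;
    }

    llama_token next = 4;
    return llama_decode(ctx, llama_batch_get_one(&next, 1));      // position auto-generated as 3
}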

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 
 #include <cstdint>
 
+// TODO: rename to something shorter
 #define LLAMA_MAX_PARALLEL_SEQUENCES 64
 
 struct llama_cparams {
