Skip to content

Commit c311ac6

Browse files
authored
cparams : rename LLAMA_MAX_PARALLEL_SEQUENCES to LLAMA_MAX_SEQ (#14188)
ggml-ci
1 parent b9912ac commit c311ac6

File tree

6 files changed

+29
-30
lines changed

6 files changed

+29
-30
lines changed

src/llama-batch.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,10 @@ llama_batch_allocr::llama_batch_allocr() {
289289
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
290290
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
291291

292-
seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
293-
seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
292+
seq_pos.resize(LLAMA_MAX_SEQ);
293+
seq_cpl.resize(LLAMA_MAX_SEQ);
294294
for (auto & cur : seq_cpl) {
295-
cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
295+
cur.resize(LLAMA_MAX_SEQ);
296296
}
297297
}
298298

@@ -322,8 +322,8 @@ bool llama_batch_allocr::init(
322322
if (batch.seq_id) {
323323
for (int32_t i = 0; i < batch.n_tokens; ++i) {
324324
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
325-
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
326-
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
325+
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
326+
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
327327
return false;
328328
}
329329
}
@@ -355,8 +355,8 @@ bool llama_batch_allocr::init(
355355
pos.resize(batch.n_tokens);
356356

357357
// initialize the starting position for each sequence based on the positions in the memory
358-
llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
359-
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
358+
llama_pos p0[LLAMA_MAX_SEQ];
359+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
360360
if (!memory) {
361361
p0[s] = 0;
362362
} else {
@@ -480,7 +480,7 @@ bool llama_batch_allocr::init(
480480
// consistency checks
481481
//
482482

483-
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
483+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
484484
if (seq_pos[s].empty()) {
485485
continue;
486486
}
@@ -497,8 +497,8 @@ bool llama_batch_allocr::init(
497497
}
498498

499499
if (memory) {
500-
for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
501-
for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
500+
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
501+
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
502502
if (seq_cpl[s0][s1]) {
503503
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
504504
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {

src/llama-context.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ llama_context::llama_context(
2929
const auto & hparams = model.hparams;
3030

3131
cparams.n_seq_max = std::max(1u, params.n_seq_max);
32-
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
33-
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
32+
if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
33+
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
3434
}
3535

3636
cparams.n_threads = params.n_threads;
@@ -1023,8 +1023,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
10231023

10241024
if (!res) {
10251025
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
1026-
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
1027-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
1026+
llama_pos pos_min[LLAMA_MAX_SEQ];
1027+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
10281028
pos_min[s] = std::numeric_limits<llama_pos>::max();
10291029
}
10301030

@@ -1035,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
10351035
pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
10361036
}
10371037

1038-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
1038+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
10391039
if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
10401040
continue;
10411041
}

src/llama-cparams.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "llama-cparams.h"
22

33
size_t llama_max_parallel_sequences(void) {
4-
return LLAMA_MAX_PARALLEL_SEQUENCES;
4+
return LLAMA_MAX_SEQ;
55
}

src/llama-cparams.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
#include <cstdint>
66

7-
// TODO: rename to something shorter
8-
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
7+
#define LLAMA_MAX_SEQ 64
98

109
struct llama_cparams {
1110
uint32_t n_ctx; // context size used during inference

src/llama-kv-cache-unified.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
572572
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
573573
}
574574

575-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
575+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
576576
if (cells.seq_pos_min(s) < 0) {
577577
continue;
578578
}
@@ -652,8 +652,8 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
652652

653653
// keep track of the max sequence position that we would overwrite with this ubatch
654654
// for non-SWA cache, this would be always empty
655-
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
656-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
655+
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
656+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
657657
seq_pos_max_rm[s] = -1;
658658
}
659659

@@ -684,7 +684,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
684684
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
685685
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
686686
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
687-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
687+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
688688
if (seq_pos_max_rm[s] == -1) {
689689
continue;
690690
}

src/llama-kv-cells.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class llama_kv_cells_unified {
2323

2424
used.clear();
2525

26-
for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
26+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
2727
seq_pos[s].clear();
2828
}
2929
}
@@ -240,7 +240,7 @@ class llama_kv_cells_unified {
240240
llama_seq_id seq_get(uint32_t i) const {
241241
assert(seq[i].count() == 1);
242242

243-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
243+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
244244
if (seq[i].test(s)) {
245245
return s;
246246
}
@@ -253,7 +253,7 @@ class llama_kv_cells_unified {
253253
// return -1 if the sequence is not present
254254
llama_pos seq_pos_min(llama_seq_id seq_id) const {
255255
assert(seq_id >= 0);
256-
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
256+
assert(seq_id < LLAMA_MAX_SEQ);
257257

258258
if (seq_pos[seq_id].empty()) {
259259
return -1;
@@ -266,7 +266,7 @@ class llama_kv_cells_unified {
266266
// return -1 if the sequence is not present
267267
llama_pos seq_pos_max(llama_seq_id seq_id) const {
268268
assert(seq_id >= 0);
269-
assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
269+
assert(seq_id < LLAMA_MAX_SEQ);
270270

271271
if (seq_pos[seq_id].empty()) {
272272
return -1;
@@ -384,20 +384,20 @@ class llama_kv_cells_unified {
384384
//
385385
std::vector<llama_pos> shift;
386386

387-
using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
387+
using bits_t = std::bitset<LLAMA_MAX_SEQ>;
388388

389389
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390390
std::vector<bits_t> seq;
391391

392392
// the set seq_pos[s] tells us which positions are currently present for sequence s
393393
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394-
std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
394+
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
395395

396396
// helper functions for updating `seq_pos`, once cell at a time:
397397

398398
// remove cell i
399399
void seq_pos_rm(uint32_t i) {
400-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
400+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
401401
if (seq[i].test(s)) {
402402
seq_pos[s].erase(pos[i]);
403403
}
@@ -406,7 +406,7 @@ class llama_kv_cells_unified {
406406

407407
// add cell i
408408
void seq_pos_add(uint32_t i) {
409-
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
409+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
410410
if (seq[i].test(s)) {
411411
seq_pos[s].insert(pos[i]);
412412
}

0 commit comments

Comments
 (0)