
Commit 372fa3a

cont : enc should work now, next is dec

ggml-ci

1 parent f5e8020 commit 372fa3a

File tree

5 files changed: +293 -217 lines

src/llama-context.cpp

Lines changed: 124 additions & 64 deletions
@@ -10,21 +10,64 @@
 #include <stdexcept>
 #include <cinttypes>

+//
+// helpers
+//
+
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
+
 //
 // llama_context
 //

 llama_context::llama_context(
         const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
         llama_graph_type gtype) :
     llama_graph_i(gtype),
     model(model) {
-    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
+    LLAMA_LOG_INFO("%s: constructing llama_context, gtype = %d\n", __func__, gtype);

     t_start_us = model.t_start_us;
     t_load_us = model.t_load_us;

+    switch (gtype) {
+        case LLAMA_GRAPH_TYPE_DEFAULT:
+        case LLAMA_GRAPH_TYPE_DECODER:
+            {
+            } break;
+        case LLAMA_GRAPH_TYPE_ENCODER:
+            {
+                params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
+                params.embeddings = true;
+            } break;
+    }
+
     const auto & hparams = model.hparams;

     cparams.n_seq_max = std::max(1u, params.n_seq_max);
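
Note: the helper promoted to file scope above implements T5-style relative-position bucketing: small offsets each get their own bucket, larger offsets share logarithmically spaced buckets up to max_distance, and in the bidirectional case the sign of the offset selects the lower or upper half of the bucket range. A standalone sketch, not part of this commit: plain int32_t stands in for llama_pos, and the large-offset branch is restructured slightly so logf is never evaluated at zero.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket   = 0;

    if (bidirectional) {
        // positive offsets (key after query) use the upper half of the buckets
        relative_bucket  += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        // unidirectional: only distances into the past matter
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    if (relative_position < max_exact) {
        // small offsets: one bucket per offset
        relative_bucket += relative_position;
    } else {
        // large offsets: log-spaced buckets, clamped at the last bucket
        const int32_t large = (int32_t) floorf(max_exact + logf(1.0f*relative_position/max_exact)*(n_buckets - max_exact)/logf(1.0f*max_distance/max_exact));
        relative_bucket += std::min<int32_t>(large, n_buckets - 1);
    }

    return relative_bucket;
}

int main() {
    // 32 buckets, bidirectional, as an encoder self-attention layer would use them
    const int32_t offsets[] = { -64, -8, -1, 0, 1, 8, 64 };
    for (int32_t d : offsets) {
        printf("offset %4d -> bucket %d\n", d, relative_position_bucket(d, 0, 32, true));
    }
    return 0;
}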
@@ -45,20 +88,6 @@ llama_context::llama_context(
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

-    // with causal attention, the batch size is limited by the context size
-    cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
-
-    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
-
     cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                               hparams.n_ctx_train;
@@ -95,13 +124,28 @@
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }

+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
+    }
+
+    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
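
Note: the n_batch clamping now runs after cparams.causal_attn has been resolved, so the context-size limit keys off the effective attention mode rather than hparams.causal_attn. A minimal standalone sketch of the resulting rules; resolve_batch_sizes and the KQ_MASK_PAD constant below are stand-ins for illustration, not llama.cpp API.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// stand-in for GGML_KQ_MASK_PAD; the real constant comes from ggml
static const uint32_t KQ_MASK_PAD = 32;

struct batch_sizes {
    uint32_t n_batch;
    uint32_t n_ubatch;
};

// mirrors the order of operations in the constructor after this change
static batch_sizes resolve_batch_sizes(uint32_t n_ctx, uint32_t n_batch, uint32_t n_ubatch, bool causal_attn) {
    // with causal attention, the batch size is limited by the context size
    uint32_t nb = causal_attn ? std::min(n_ctx, n_batch) : n_batch;

    // the KQ mask is padded, so the batch must not be smaller than the pad
    nb = std::max(nb, KQ_MASK_PAD);

    // n_ubatch defaults to the requested n_batch and never exceeds the effective n_batch
    const uint32_t nub = std::min(nb, n_ubatch == 0 ? n_batch : n_ubatch);

    return { nb, nub };
}

int main() {
    const batch_sizes bs = resolve_batch_sizes(/*n_ctx=*/4096, /*n_batch=*/8192, /*n_ubatch=*/0, /*causal_attn=*/true);
    printf("n_batch = %u, n_ubatch = %u\n", bs.n_batch, bs.n_ubatch); // n_batch = 4096, n_ubatch = 4096
    return 0;
}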
@@ -1207,6 +1251,23 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
         }
     }

+    if (inp.pos_bucket) {
+        const int64_t n_tokens = ubatch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp.pos_bucket->buffer));
+        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+
+        int32_t * data = (int32_t *) inp.pos_bucket->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, true);
+                }
+            }
+        }
+    }
+
     GGML_ASSERT(
             // (!a || b) is a logical implication (a -> b)
             // !hparams.causal_attn -> !cparams.causal_attn
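
Note: in the cache-less (encoder) context the new inp.pos_bucket is a square n_tokens x n_tokens matrix, filled per (query j, key i) pair from ubatch.pos with bidirectional buckets. A layout-only sketch; the bucket lambda below is a stand-in that returns the raw signed offset, where llama.cpp uses llama_relative_position_bucket(..., true).

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t n_tokens = 4;
    std::vector<int32_t> pos = { 0, 1, 2, 3 };       // ubatch.pos
    std::vector<int32_t> data(n_tokens * n_tokens);  // backing of inp.pos_bucket (I32, n_tokens x n_tokens)

    // stand-in for llama_relative_position_bucket: just the raw signed offset
    auto bucket = [](int32_t x, int32_t y) { return x - y; };

    for (int j = 0; j < n_tokens; ++j) {             // query token (row)
        for (int i = 0; i < n_tokens; ++i) {         // key token (column)
            data[j*n_tokens + i] = bucket(pos[i], pos[j]);
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_tokens; ++i) {
            printf("%3d ", data[j*n_tokens + i]);
        }
        printf("\n");
    }
    return 0;
}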
@@ -1604,6 +1665,15 @@ ggml_tensor * llama_context::build_inp_pos(
     return inp.pos;
 }

+ggml_tensor * llama_context::build_inp_pos_bucket(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    inp.pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
+    ggml_set_input(inp.pos_bucket);
+
+    return inp.pos_bucket;
+}
+
 ggml_tensor * llama_context::build_inp_out_ids(
         ggml_context * ctx0) {
     const int32_t n_out_ids = n_outputs;
@@ -1656,6 +1726,7 @@ ggml_tensor * llama_context::build_attn(
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
         int32_t n_tokens,
         float kq_scale,
         int il) {
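
Note: build_attn gains a kq_b argument, an optional additive bias on the raw attention scores (the base llama_context and llama_context_kv_self paths still assert it is null; the hunks further down add it into kq right before the softmax). A scalar illustration of where such a bias enters, not ggml code; the numbers and the scale are made up for the sketch.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // raw q.k scores for one query row against 4 keys (made-up values)
    std::vector<float> kq   = { 1.0f, 0.5f, -0.2f, 2.0f };
    // additive bias for the same row, e.g. a relative-position attention bias
    std::vector<float> kq_b = { 0.0f, -0.1f, -0.3f, -0.6f };

    // kq = kq + kq_b, mirroring "kq = ggml_add(ctx0, kq, kq_b)" in this diff
    for (size_t i = 0; i < kq.size(); ++i) {
        kq[i] += kq_b[i];
    }

    // softmax(kq * kq_scale) over the row: a simplified stand-in for
    // ggml_soft_max_ext (mask and ALiBi bias omitted)
    const float kq_scale = 1.0f; // models that rely on such a bias often use no 1/sqrt(d) scaling
    float max_v = -INFINITY;
    for (float v : kq) {
        max_v = std::max(max_v, v*kq_scale);
    }

    std::vector<float> p(kq.size());
    float sum = 0.0f;
    for (size_t i = 0; i < kq.size(); ++i) {
        p[i] = std::exp(kq[i]*kq_scale - max_v);
        sum += p[i];
    }

    for (size_t i = 0; i < p.size(); ++i) {
        printf("p[%zu] = %.4f\n", i, p[i]/sum);
    }
    return 0;
}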
@@ -1690,6 +1761,8 @@ ggml_tensor * llama_context::build_attn(
     GGML_UNUSED(model);
     GGML_UNUSED(n_ctx);

+    GGML_ASSERT(kq_b == nullptr);
+
     struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
     v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);

@@ -1720,10 +1793,14 @@

     if (hparams.attn_soft_cap) {
         kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-        kq = ggml_tanh(ctx0, kq);
+        kq = ggml_tanh (ctx0, kq);
         kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
     }

+    if (kq_b) {
+        kq = ggml_add(ctx0, kq, kq_b);
+    }
+
     kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     //cb(kq, "kq_soft_max_ext", il);

@@ -2281,7 +2358,7 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_

 llama_context_kv_self::llama_context_kv_self(
         const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
         llama_graph_type gtype) :
     llama_context(model, params, gtype),
     kv_self(model.hparams) {
@@ -3053,53 +3130,19 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
         }
     }

-    if (inp_pos_bucket) {
+    if (inp.self_pos_bucket) {
         const int64_t n_tokens = ubatch.n_tokens;

-        GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_pos_bucket->buffer));
         GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing

-        static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-            // TODO move to hparams if a T5 variant appears that uses a different value
-            const int64_t max_distance = 128;
-
-            if (bidirectional) {
-                n_buckets >>= 1;
-            }
+        int32_t * data = (int32_t *) inp.self_pos_bucket->data;

-            const int64_t max_exact = n_buckets >> 1;
-
-            int32_t relative_position = x - y;
-            int32_t relative_bucket = 0;
-            if (bidirectional) {
-                relative_bucket += (relative_position > 0) * n_buckets;
-                relative_position = abs(relative_position);
-            } else {
-                relative_position = -std::min<int32_t>(relative_position, 0);
-            }
-            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-            relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-            relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-            return relative_bucket;
-        };
-
-        int32_t * data = (int32_t *) inp_pos_bucket->data;
-
-        if (!is_encoding) {
-            const int64_t n_kv = kv_self.n;
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
-                    }
-                }
-            }
-        } else {
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
-                    }
+        const int64_t n_kv = kv_self.n;
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false);
                 }
             }
         }
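
Note: the kv_self variant keeps only the decoder path: the local lambda is gone in favor of the file-scope helper, bidirectional is now hard-coded to false, and the matrix is n_kv wide (one column per cache cell) by n_tokens tall. A companion layout sketch, again with a stand-in bucket function (clamped distance into the past) in place of llama_relative_position_bucket.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t n_kv     = 6;                              // kv_self.n
    const int64_t n_tokens = 2;
    std::vector<int32_t> cell_pos   = { 0, 1, 2, 3, 4, 5 };  // kv_self.cells[i].pos
    std::vector<int32_t> ubatch_pos = { 4, 5 };              // ubatch.pos
    std::vector<int32_t> data(n_kv * n_tokens);              // inp.self_pos_bucket (I32, n_kv x n_tokens)

    // stand-in: unidirectional "distance into the past", clamped at 0
    auto bucket = [](int32_t x, int32_t y) { return -std::min<int32_t>(x - y, 0); };

    for (int j = 0; j < n_tokens; ++j) {          // query token in the current ubatch (row)
        for (int i = 0; i < n_kv; ++i) {          // key slot in the KV cache (column)
            data[j*n_kv + i] = bucket(cell_pos[i], ubatch_pos[j]);
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            printf("%2d ", data[j*n_kv + i]);
        }
        printf("\n");
    }
    return 0;
}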
@@ -3146,7 +3189,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {

 ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_embd_enc = nullptr;
-    inp_pos_bucket = nullptr;
     inp_kq_mask_cross = nullptr;

     inp = {};
@@ -3161,6 +3203,17 @@ ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0)
     return inp.self_k_shift;
 }

+ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    const auto n_kv = kv_self.n;
+
+    inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+    ggml_set_input(inp.self_pos_bucket);
+
+    return inp.self_pos_bucket;
+}
+
 void llama_context_kv_self::build_attn_inp(
         ggml_context * ctx0,
         int32_t n_tokens,
@@ -3199,6 +3252,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
         int32_t n_tokens,
         float kq_scale,
         int il) {
@@ -3293,6 +3347,8 @@ ggml_tensor * llama_context_kv_self::build_attn(
     GGML_UNUSED(model);
     GGML_UNUSED(n_ctx);

+    GGML_ASSERT(kq_b == nullptr);
+
     // split cached v into n_head heads (not transposed)
     struct ggml_tensor * v =
         ggml_view_3d(ctx0, kv_self.v_l[il],
@@ -3329,10 +3385,14 @@ ggml_tensor * llama_context_kv_self::build_attn(

     if (hparams.attn_soft_cap) {
         kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-        kq = ggml_tanh(ctx0, kq);
+        kq = ggml_tanh (ctx0, kq);
         kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
     }

+    if (kq_b) {
+        kq = ggml_add(ctx0, kq, kq_b);
+    }
+
     kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
     //cb(kq, "kq_soft_max_ext", il);

@@ -3753,7 +3813,7 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq

 llama_context_recurrent::llama_context_recurrent(
         const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
         llama_graph_type gtype) :
     llama_context(model, params, gtype),
     kv_self(model.hparams) {
@@ -4629,7 +4689,7 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s

 llama_context_enc_dec::llama_context_enc_dec(
         const llama_model & model,
-        const llama_context_params & params) :
+        llama_context_params params) :
     llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER),
     ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) {
     LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__);
