#include <stdexcept>
#include <cinttypes>

+//
+// helpers
+//
+
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
+
//
// llama_context
//

llama_context::llama_context(
        const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
        llama_graph_type gtype) :
    llama_graph_i(gtype),
    model(model) {
-    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
+    LLAMA_LOG_INFO("%s: constructing llama_context, gtype = %d\n", __func__, gtype);

    t_start_us = model.t_start_us;
    t_load_us  = model.t_load_us;

+    switch (gtype) {
+        case LLAMA_GRAPH_TYPE_DEFAULT:
+        case LLAMA_GRAPH_TYPE_DECODER:
+            {
+            } break;
+        case LLAMA_GRAPH_TYPE_ENCODER:
+            {
+                params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
+                params.embeddings     = true;
+            } break;
+    }
+
    const auto & hparams = model.hparams;

    cparams.n_seq_max = std::max(1u, params.n_seq_max);
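
The helper above implements T5-style relative position bucketing: offsets smaller than max_exact each get their own bucket, larger offsets fall into logarithmically spaced buckets up to max_distance, and in the bidirectional (encoder) case half of the buckets are reserved for positive offsets. A minimal standalone sketch of that behavior follows; the driver, the n_buckets value of 32, and the sample offsets are illustrative and not part of the patch (the real value comes from hparams.n_rel_attn_bkts):

// standalone sketch mirroring llama_relative_position_bucket() from the hunk above
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <initializer_list>

static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1; // half the buckets for each sign of the offset
    }

    const int64_t max_exact = n_buckets >> 1; // offsets below this get one bucket each

    int32_t relative_position = x - y;
    int32_t relative_bucket   = 0;

    if (bidirectional) {
        relative_bucket  += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    // log-spaced bucket, only selected when relative_position >= max_exact
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0f * relative_position / max_exact) * (n_buckets - max_exact) / logf(1.0f * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);

    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);

    return relative_bucket;
}

int main() {
    const uint64_t n_buckets = 32; // illustrative; taken from hparams.n_rel_attn_bkts in the real code

    for (int32_t d : {1, 2, 7, 8, 20, 60, 127, 300}) {
        printf("offset %4d -> encoder bucket %2d, decoder bucket %2d\n",
               d,
               relative_position_bucket(d, 0, n_buckets, /*bidirectional=*/true),
               relative_position_bucket(0, d, n_buckets, /*bidirectional=*/false));
    }

    return 0;
}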
@@ -45,20 +88,6 @@ llama_context::llama_context(
    cparams.rope_freq_base  = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
    cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

-    // with causal attention, the batch size is limited by the context size
-    cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
-
-    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
-
    cparams.n_ctx_orig_yarn = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
                              hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
                                                             hparams.n_ctx_train;
@@ -95,13 +124,28 @@ llama_context::llama_context(
        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
    }

+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
+    }
+
+    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
@@ -1207,6 +1251,23 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
        }
    }

+    if (inp.pos_bucket) {
+        const int64_t n_tokens = ubatch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp.pos_bucket->buffer));
+        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+
+        int32_t * data = (int32_t *) inp.pos_bucket->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, true);
+                }
+            }
+        }
+    }
+
    GGML_ASSERT(
            // (!a || b) is a logical implication (a -> b)
            // !hparams.causal_attn -> !cparams.causal_attn
@@ -1604,6 +1665,15 @@ ggml_tensor * llama_context::build_inp_pos(
    return inp.pos;
}

+ggml_tensor * llama_context::build_inp_pos_bucket(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    inp.pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
+    ggml_set_input(inp.pos_bucket);
+
+    return inp.pos_bucket;
+}
+
ggml_tensor * llama_context::build_inp_out_ids(
        ggml_context * ctx0) {
    const int32_t n_out_ids = n_outputs;
@@ -1656,6 +1726,7 @@ ggml_tensor * llama_context::build_attn(
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
        int32_t n_tokens,
        float kq_scale,
        int il) {
@@ -1690,6 +1761,8 @@ ggml_tensor * llama_context::build_attn(
    GGML_UNUSED(model);
    GGML_UNUSED(n_ctx);

+    GGML_ASSERT(kq_b == nullptr);
+
    struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
    v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);

@@ -1720,10 +1793,14 @@ ggml_tensor * llama_context::build_attn(

        if (hparams.attn_soft_cap) {
            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh(ctx0, kq);
+            kq = ggml_tanh (ctx0, kq);
            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
        }

+        if (kq_b) {
+            kq = ggml_add(ctx0, kq, kq_b);
+        }
+
        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
        //cb(kq, "kq_soft_max_ext", il);

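
The new kq_b argument lets a caller add an optional bias to the attention scores right before the softmax; on the paths that do not supply one it is asserted to be nullptr, and for T5-style models it would carry the relative-position bias derived from the pos_bucket input added elsewhere in this change. A small self-contained illustration of that step follows; the score and bias values are made up and a plain row-wise softmax stands in for ggml_soft_max_ext:

// illustration only: adding a bias term to raw attention scores before softmax,
// which is what `kq = ggml_add(ctx0, kq, kq_b)` does in the graph above
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n = 3; // tiny n_tokens x n_tokens score matrix, single head

    std::vector<float> kq   = {0.1f, 0.2f, 0.3f,
                               0.0f, 0.5f, 0.1f,
                               0.2f, 0.2f, 0.4f}; // raw attention scores (made up)
    std::vector<float> kq_b = {0.0f, -1.0f, -2.0f,
                               1.0f,  0.0f, -1.0f,
                               2.0f,  1.0f,  0.0f}; // relative-position bias (made up)

    for (int j = 0; j < n; ++j) {
        // kq += kq_b, then a numerically stable softmax over each row
        float max_val = -INFINITY;
        for (int i = 0; i < n; ++i) {
            kq[j*n + i] += kq_b[j*n + i];
            max_val = std::max(max_val, kq[j*n + i]);
        }

        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            kq[j*n + i] = expf(kq[j*n + i] - max_val);
            sum += kq[j*n + i];
        }

        for (int i = 0; i < n; ++i) {
            printf("%.3f ", kq[j*n + i]/sum);
        }
        printf("\n");
    }

    return 0;
}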
@@ -2281,7 +2358,7 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_

llama_context_kv_self::llama_context_kv_self(
        const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
        llama_graph_type gtype) :
    llama_context(model, params, gtype),
    kv_self(model.hparams) {
@@ -3053,53 +3130,19 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
        }
    }

-    if (inp_pos_bucket) {
+    if (inp.self_pos_bucket) {
        const int64_t n_tokens = ubatch.n_tokens;

-        GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_pos_bucket->buffer));
        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing

-        static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-            // TODO move to hparams if a T5 variant appears that uses a different value
-            const int64_t max_distance = 128;
-
-            if (bidirectional) {
-                n_buckets >>= 1;
-            }
+        int32_t * data = (int32_t *) inp.self_pos_bucket->data;

-            const int64_t max_exact = n_buckets >> 1;
-
-            int32_t relative_position = x - y;
-            int32_t relative_bucket = 0;
-            if (bidirectional) {
-                relative_bucket += (relative_position > 0) * n_buckets;
-                relative_position = abs(relative_position);
-            } else {
-                relative_position = -std::min<int32_t>(relative_position, 0);
-            }
-            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-            relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-            relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-            return relative_bucket;
-        };
-
-        int32_t * data = (int32_t *) inp_pos_bucket->data;
-
-        if (!is_encoding) {
-            const int64_t n_kv = kv_self.n;
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
-                    }
-                }
-            }
-        } else {
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
-                    }
+        const int64_t n_kv = kv_self.n;
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false);
                }
            }
        }
@@ -3146,7 +3189,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {

ggml_cgraph * llama_context_kv_self::graph_init() {
    inp_embd_enc = nullptr;
-    inp_pos_bucket = nullptr;
    inp_kq_mask_cross = nullptr;

    inp = {};
@@ -3161,6 +3203,17 @@ ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0)
    return inp.self_k_shift;
}

+ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
+        ggml_context * ctx0,
+        int32_t n_tokens) {
+    const auto n_kv = kv_self.n;
+
+    inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+    ggml_set_input(inp.self_pos_bucket);
+
+    return inp.self_pos_bucket;
+}
+
void llama_context_kv_self::build_attn_inp(
        ggml_context * ctx0,
        int32_t n_tokens,
@@ -3199,6 +3252,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
        int32_t n_tokens,
        float kq_scale,
        int il) {
@@ -3293,6 +3347,8 @@ ggml_tensor * llama_context_kv_self::build_attn(
    GGML_UNUSED(model);
    GGML_UNUSED(n_ctx);

+    GGML_ASSERT(kq_b == nullptr);
+
    // split cached v into n_head heads (not transposed)
    struct ggml_tensor * v =
        ggml_view_3d(ctx0, kv_self.v_l[il],
@@ -3329,10 +3385,14 @@ ggml_tensor * llama_context_kv_self::build_attn(

        if (hparams.attn_soft_cap) {
            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh(ctx0, kq);
+            kq = ggml_tanh (ctx0, kq);
            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
        }

+        if (kq_b) {
+            kq = ggml_add(ctx0, kq, kq_b);
+        }
+
        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
        //cb(kq, "kq_soft_max_ext", il);

@@ -3753,7 +3813,7 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq

llama_context_recurrent::llama_context_recurrent(
        const llama_model & model,
-        const llama_context_params & params,
+        llama_context_params params,
        llama_graph_type gtype) :
    llama_context(model, params, gtype),
    kv_self(model.hparams) {
@@ -4629,7 +4689,7 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s

llama_context_enc_dec::llama_context_enc_dec(
        const llama_model & model,
-        const llama_context_params & params) :
+        llama_context_params params) :
    llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER),
    ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) {
    LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__);