@@ -1721,50 +1721,67 @@ void llama_context::build_attn_inp(
 ggml_tensor * llama_context::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
-    const auto & hparams = model.hparams;
+    GGML_UNUSED(il);

-    const auto & n_ctx = cparams.n_ctx;
+    const auto & kq_mask = inp.kq_mask_cnv;

-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);

-    const auto & kq_mask = inp.kq_mask_cnv;
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);

-    const int64_t n_head    = hparams.n_head(il);
-    const int64_t n_head_kv = hparams.n_head_kv(il);
+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);

-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
+    ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, false, kq_scale);

-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    const auto n_kv = n_tokens;
+    return cur;
+}

-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
+ggml_tensor * llama_context::build_attn_mha(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * kq_b,
+        ggml_tensor * kq_mask,
+        bool v_trans,
+        float kq_scale) {
+    const auto & hparams = model.hparams;

-    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, k_cur, 0, 2, 1, 3));
-    //cb(k, "k", il);
+    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    //const int64_t n_head    = hparams.n_head(il);
+    //const int64_t n_head_kv = hparams.n_head_kv(il);
+
+    //const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+    const auto n_kv     = k->ne[1];

     struct ggml_tensor * cur;

-    //if (cparams.flash_attn) {
-    if (false) { // TODO: need to pad the batch size to a multiple of GGML_KQ_MASK_PAD
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
         GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);

-        GGML_ASSERT(kq_b == nullptr);
+        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
-        v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);
+        if (v_trans) {
+            v = ggml_transpose(ctx0, v);
+        }

         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
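
The new `build_attn_mha` helper no longer takes `n_tokens`/`n_head`/`n_kv` arguments; it reads them from the shapes of the permuted tensors (`q->ne[1]`, `q->ne[2]`, `k->ne[1]`, and `v->ne[0]` or `v->ne[1]` depending on `v_trans`). The standalone sketch below is not part of the commit; it uses made-up sizes and plain ggml calls to show which dimension ends up where after the `ggml_permute(..., 0, 2, 1, 3)` calls above.

```cpp
// Standalone sketch, not part of the commit: reproduces the dimension bookkeeping that
// build_attn_mha now derives from tensor shapes instead of explicit arguments.
// The sizes below are made-up example values; without a KV cache, n_kv == n_tokens.
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = {16*1024*1024, nullptr, /*no_alloc =*/ true};
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 128, n_head = 32, n_head_kv = 8, n_tokens = 7;

    // incoming layout assumed from the callers: [n_embd_head, n_head(_kv), n_tokens]
    ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head,    n_tokens);
    ggml_tensor * k_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, n_tokens);
    ggml_tensor * v_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head_kv, n_tokens);

    ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); // -> [n_embd_head, n_tokens, n_head]
    ggml_tensor * k = ggml_permute(ctx, k_cur, 0, 2, 1, 3); // -> [n_embd_head, n_kv,     n_head_kv]
    ggml_tensor * v = ggml_permute(ctx, v_cur, 0, 2, 1, 3); // -> [n_embd_head, n_kv,     n_head_kv]

    // same derivations as build_attn_mha with v_trans == false
    printf("n_tokens = %lld, n_head = %lld, n_kv = %lld, n_embd_head_v = %lld\n",
            (long long) q->ne[1], (long long) q->ne[2], (long long) k->ne[1], (long long) v->ne[0]);

    ggml_free(ctx);
    return 0;
}
```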
@@ -1774,7 +1791,6 @@ ggml_tensor * llama_context::build_attn(
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);

         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1802,22 +1818,17 @@ ggml_tensor * llama_context::build_attn(
         }

         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens)));

-        v = ggml_reshape_3d(ctx0, v, n_kv, n_embd_head_v, n_head_kv);
-        //cb(v, "v", il);
+        if (!v_trans) {
+            // note: avoid this branch
+            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+        }

         struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);

         struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);

         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
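
In the non-flash branch, `ggml_mul_mat(ctx0, v, kq)` contracts over `ne[0]` of both operands, so V must be laid out as `[n_kv, n_embd_head_v, n_head_kv]`; that is why the `!v_trans` case above has to insert a `ggml_cont(ggml_transpose(...))`. A standalone sketch (made-up sizes, not part of the commit) of the shape bookkeeping, including the GQA head broadcast:

```cpp
// Standalone sketch: ggml_mul_mat(a, b) contracts over ne[0] of both operands and
// broadcasts a's head dimension over b's, so with V laid out as
// [n_kv, n_embd_head_v, n_head_kv] the result comes out as [n_embd_head_v, n_tokens, n_head].
#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = {16*1024*1024, nullptr, /*no_alloc =*/ true};
    ggml_context * ctx = ggml_init(params);

    const int64_t n_kv = 32, n_tokens = 8, n_head = 32, n_head_kv = 8, n_embd_head_v = 128;

    ggml_tensor * kq = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_tokens,      n_head);
    ggml_tensor * v  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_embd_head_v, n_head_kv);

    ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); // -> [n_embd_head_v, n_tokens, n_head]

    printf("kqv: %lld x %lld x %lld\n",
            (long long) kqv->ne[0], (long long) kqv->ne[1], (long long) kqv->ne[2]);

    ggml_free(ctx);
    return 0;
}
```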
@@ -1827,18 +1838,6 @@ ggml_tensor * llama_context::build_attn(
 
     ggml_build_forward_expand(gf, cur);

-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
     return cur;
 }
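
With `wo` and `wo_b` removed from the signature, `build_attn` no longer applies the output projection; judging by the deleted lines, that step presumably moves to the call sites. A minimal sketch of what it looks like, with `ggml_mul_mat` standing in for the repo's `build_lora_mm` (the helper name below is hypothetical):

```cpp
#include "ggml.h"

// Hypothetical helper, not from the commit: the output projection that build_attn
// used to apply. wo/wo_b are the per-layer attention output weight and bias;
// LoRA handling (build_lora_mm in the real code) is deliberately left out.
static ggml_tensor * apply_attn_out_proj(
        ggml_context * ctx0,
        ggml_tensor  * cur,     // tensor returned by build_attn
        ggml_tensor  * wo,      // may be nullptr
        ggml_tensor  * wo_b) {  // may be nullptr
    if (wo) {
        cur = ggml_mul_mat(ctx0, wo, cur);
    }
    if (wo_b) {
        cur = ggml_add(ctx0, cur, wo_b);
    }
    return cur;
}
```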
@@ -3274,13 +3273,10 @@ void llama_context_kv_self::build_attn_inp(
 ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     const auto & hparams = model.hparams;
@@ -3290,6 +3286,10 @@ ggml_tensor * llama_context_kv_self::build_attn(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+    const auto n_tokens = q_cur->ne[2];
+
+    const bool v_trans = !cparams.flash_attn;
+
     // store to KV cache
     {
         GGML_ASSERT(!kv_self.recurrent);
@@ -3308,7 +3308,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
 
         struct ggml_tensor * v_cache_view = nullptr;

-        if (cparams.flash_attn) {
+        if (!v_trans) {
             v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
         } else {
             // note: the V cache is transposed when not using flash attention
@@ -3351,117 +3351,35 @@ ggml_tensor * llama_context_kv_self::build_attn(
 
     const auto n_kv = kv_self.n;

-    const int64_t n_head = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);

     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;

-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);

-    struct ggml_tensor * k =
+    ggml_tensor * k =
         ggml_view_3d(ctx0, kv_self.k_l[il],
                 n_embd_head_k, n_kv, n_head_kv,
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                 0);
     //cb(k, "k", il);

-    struct ggml_tensor * cur;
-
-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
-        GGML_ASSERT(kq_b == nullptr);
-
-        // split cached v into n_head heads (not transposed)
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
-                    0);
-        //cb(v, "v", il);
-
-        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
-    } else {
-        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);
-
-        // note: this op tends to require high floating point range
-        // while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (model.arch == LLM_ARCH_GROK) {
-            // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
-            // and then :
-            // kq = 30 * tanh(kq / 30)
-            // before the softmax below
-
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx0, kq, 30);
-        }
-
-        if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh(ctx0, kq);
-            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
-        }
-
-        if (kq_b) {
-            kq = ggml_add(ctx0, kq, kq_b);
-        }
-
-        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        GGML_ASSERT(kv_self.size == n_ctx);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_kv, n_embd_head_v, n_head_kv,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                    0);
-        //cb(v, "v", il);
-
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);
-
-        struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);
-
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);
-
-        if (!cparams.offload_kqv) {
-            // all nodes between the KV store and the attention output are run on the CPU
-            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
-        }
-    }
-
-    ggml_build_forward_expand(gf, cur);
-
-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
+    ggml_tensor * v = !v_trans ?
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+                0) :
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv_self.v_l[il])*n_ctx,
+                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                0);

-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
+    struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);

     return cur;
 }
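
The conditional view above selects between the two V-cache layouts: the flash-attention layout keeps per-token rows of `n_embd_v_gqa` values, while the transposed layout (used whenever `v_trans` is set, i.e. flash attention is off) strides over `n_ctx` elements per row, matching how the cache was written earlier in this function. A standalone sketch (made-up sizes, not part of the commit) that builds both views over a dummy cache tensor and prints their shapes and byte strides:

```cpp
// Standalone sketch: the two per-head V-cache views used by build_attn, built over a
// stand-in for kv_self.v_l[il]. Sizes are example values only.
#include "ggml.h"
#include <cstdio>

static void print_view(const char * name, const ggml_tensor * t) {
    printf("%s: ne = [%lld, %lld, %lld], nb = [%zu, %zu, %zu]\n", name,
            (long long) t->ne[0], (long long) t->ne[1], (long long) t->ne[2],
            t->nb[0], t->nb[1], t->nb[2]);
}

int main() {
    ggml_init_params params = {16*1024*1024, nullptr, /*no_alloc =*/ true};
    ggml_context * ctx = ggml_init(params);

    const int64_t n_ctx = 4096, n_kv = 256, n_head_kv = 8, n_embd_head_v = 128;
    const int64_t n_embd_v_gqa = n_embd_head_v*n_head_kv;

    // stand-in for kv_self.v_l[il]
    ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_v_gqa*n_ctx);

    // flash-attention layout: [n_embd_head_v, n_kv, n_head_kv]
    ggml_tensor * v_flash = ggml_view_3d(ctx, v_l,
            n_embd_head_v, n_kv, n_head_kv,
            ggml_row_size(v_l->type, n_embd_v_gqa),
            ggml_row_size(v_l->type, n_embd_head_v),
            0);

    // transposed layout used without flash attention: [n_kv, n_embd_head_v, n_head_kv]
    ggml_tensor * v_transposed = ggml_view_3d(ctx, v_l,
            n_kv, n_embd_head_v, n_head_kv,
            ggml_element_size(v_l)*n_ctx,
            ggml_element_size(v_l)*n_ctx*n_embd_head_v,
            0);

    print_view("v (flash)", v_flash);
    print_view("v (transposed)", v_transposed);

    ggml_free(ctx);
    return 0;
}
```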