@@ -279,60 +279,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     if (kq_mask) {
         // Check if we're using sliding window attention
-        if (n_swa > 0) {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t n_stride     = ubatch->n_tokens;
-            const int64_t half_n_swa   = n_swa / 2;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            // Implement symmetric sliding window attention
-            // token i attends to tokens [i - n_swa/2, i + n_swa/2]
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-                        const int64_t pos_j = ubatch->pos[tj];
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        const int64_t pos_i = ubatch->pos[ti];
-                                        const int64_t pos_diff = pos_j - pos_i;
-
-                                        // Apply sliding window constraint
-                                        // [i - n_swa/2, i + n_swa/2]
-                                        if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
-                                            if (hparams.use_alibi) {
-                                                f = -std::abs(pos_diff);
-                                            } else {
-                                                f = 0.0f;
-                                            }
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        } else if (cparams.causal_attn) {
+        if (cparams.causal_attn) {
             const int64_t n_kv         = ubatch->n_tokens;
             const int64_t n_tokens     = ubatch->n_tokens;
             const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -375,6 +322,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
             const int64_t n_seq_tokens = ubatch->n_seq_tokens;
             const int64_t n_seqs       = ubatch->n_seqs;
             const int64_t n_stride     = ubatch->n_tokens;
+            const int64_t half_n_swa   = hparams.n_swa / 2;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
 
@@ -386,6 +334,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
                     for (int j = 0; j < n_seq_tokens; ++j) {
                         const int32_t tj = s1*n_seq_tokens + j;
+                        const int64_t pos_j = ubatch->pos[tj];
 
                         for (int s0 = 0; s0 < n_seqs; ++s0) {
                             for (int i = 0; i < n_seq_tokens; ++i) {
@@ -394,7 +343,11 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                     if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
+                                        const int64_t pos_i = ubatch->pos[ti];
+                                        const int64_t pos_diff = pos_j - pos_i;
+
+                                        if (hparams.use_alibi &&
+                                            (pos_diff >= -half_n_swa && pos_diff <= half_n_swa)) {
                                             f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
                                         } else {
                                             f = 0.0f;
@@ -1242,22 +1195,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
-llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache_iswa() const {
-    // Use the sliding window size from hyperparameters
-    // If hparams.n_swa is 0, use a default value (128)
-    const int n_swa = hparams.n_swa > 0 ? hparams.n_swa : 128;
-
-    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams, n_swa);
-
-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    ggml_set_input(inp->kq_mask);
-
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
-
-    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
         ggml_cgraph * gf,
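
For context, both the removed branch and the added window check implement the rule spelled out in the deleted comment ("token i attends to tokens [i - n_swa/2, i + n_swa/2]"). Below is a minimal standalone sketch of that symmetric sliding-window masking for a single sequence; the function name, the flat row-major `std::vector` mask, and the single-sequence setting are illustrative assumptions, not llama.cpp API.

```cpp
// Illustrative sketch only (not part of this patch or of llama.cpp):
// symmetric sliding-window mask for one sequence. Token j may attend to
// token i only when |pos[j] - pos[i]| <= n_swa/2; every other entry stays
// masked at -INFINITY.
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<float> build_symmetric_swa_mask(
        const std::vector<int64_t> & pos, int64_t n_swa, bool use_alibi) {
    const size_t  n_tokens   = pos.size();
    const int64_t half_n_swa = n_swa / 2;

    // row-major [n_tokens x n_tokens] mask, fully masked by default
    std::vector<float> mask(n_tokens*n_tokens, -INFINITY);

    for (size_t j = 0; j < n_tokens; ++j) {
        for (size_t i = 0; i < n_tokens; ++i) {
            const int64_t pos_diff = pos[j] - pos[i];
            if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
                // inside the window: unmasked (0.0f), or an ALiBi-style distance bias
                mask[j*n_tokens + i] = use_alibi ? -std::abs((float) pos_diff) : 0.0f;
            }
        }
    }

    return mask;
}
```

With `n_swa = 4` and consecutive positions, each row of the returned mask leaves only the five entries within ±2 of the diagonal unmasked, which is the symmetric window the patch folds into the existing causal/non-causal paths.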