@@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     const int64_t n_tps     = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1306,44 +1308,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }
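The idea behind the change: instead of computing a `masked` flag per cell and also special-casing the padded tail of the mask, the whole buffer is pre-filled with `-INFINITY` and the loop only writes the entries that remain visible, skipping everything else with `continue`. Below is a small standalone sketch (not llama.cpp code) of that pattern under simplifying assumptions: a single stream, no padding, no sequence or SWA checks, and toy `cell_pos`/`tok_pos` arrays invented for illustration; only the fill-then-overwrite structure and the `idst` row offset mirror the diff.

```cpp
// Toy illustration of the "pre-fill with -INFINITY, then unmask" pattern.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    const int  n_kv        = 4;     // cells in the toy KV cache
    const int  n_tokens    = 2;     // tokens in the toy ubatch
    const bool causal_attn = true;
    const bool use_alibi   = false;

    const std::vector<int> cell_pos = { 0, 1, -1, 2 };  // position per cell, -1 = empty
    const std::vector<int> tok_pos  = { 2, 3 };         // position per ubatch token

    std::vector<float> mask(n_tokens * n_kv);

    // step 1: everything starts masked
    std::fill(mask.begin(), mask.end(), -INFINITY);

    // step 2: overwrite only the visible (token, cell) pairs
    for (int i = 0; i < n_tokens; ++i) {
        const int p1   = tok_pos[i];
        const int idst = i*n_kv;    // row offset, analogous to idst in the diff

        for (int j = 0; j < n_kv; ++j) {
            const int p0 = cell_pos[j];

            if (p0 < 0)                 { continue; }  // empty cell
            if (causal_attn && p0 > p1) { continue; }  // future token

            mask[idst + j] = use_alibi ? -std::abs(p0 - p1) : 0.0f;
        }
    }

    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            printf("%6.1f ", mask[i*n_kv + j]);
        }
        printf("\n");
    }
    return 0;
}
```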