Skip to content

Commit d9b6910

Browse files
authored
kv-cache : opt mask set input (#14600)
ggml-ci
1 parent ad57d3e commit d9b6910

File tree

1 file changed

+18
-29
lines changed

1 file changed

+18
-29
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tps     = n_tokens/n_stream;
     const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
 
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1306,44 +1308,31 @@
 
                 const llama_pos p1 = ubatch->pos[i];
 
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    float f = 0.0f;
-
-                    bool masked = false;
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
+                for (uint32_t j = 0; j < n_kv; ++j) {
                     if (cells.is_empty(j)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(j);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(j, seq_id));
+                        continue;
+                    }
 
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
 
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
+                    const llama_pos p0 = cells.pos_get(j);
 
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    if (masked) {
-                        f = -INFINITY;
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
                     }
 
-                    data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
-                }
-
-                // mask padded tokens
-                if (data) {
-                    for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
-                        for (uint32_t j = 0; j < n_kv; ++j) {
-                            data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
-                        }
-                    }
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
             }
         }

0 commit comments

Comments (0)