
Commit 28dec76

cont : uniform ubatch indexing
ggml-ci
1 parent e6ac4ac commit 28dec76

2 files changed: +64, -116 lines changed

src/llama-graph.cpp

Lines changed: 49 additions & 101 deletions
@@ -252,90 +252,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    // TODO: repace this if with GGML_ASSERT(kq_mask)
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv = ubatch->n_tokens;
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-            const int64_t n_stride = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+    const int64_t n_kv = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));

-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+    float * data = (float *) kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
                         }
+                        break;
                     }
                 }
+
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
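
Note: the indexing scheme this hunk converges on can be read in isolation. Both loops now run directly over ubatch token indices (i1 for the query row, i0 for the key column), instead of the old (n_seqs, n_seq_tokens) decomposition (ti = s0*n_seq_tokens + i), and the mask is stored row-major at data[i1*n_kv + i0]. Below is a minimal standalone sketch of that pattern; the Token struct and build_mask helper are illustrative stand-ins, not llama.cpp types or API.

#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative stand-ins for the per-token ubatch fields (not llama.cpp types).
struct Token {
    int32_t seq_id; // primary sequence id of the token
    int32_t pos;    // position within its sequence
};

// Fill an n_tokens x n_tokens mask, row-major: mask[i1*n_tokens + i0].
// Row i1 is the query token, column i0 is the key token.
static void build_mask(const std::vector<Token> & tokens, bool causal, bool alibi, std::vector<float> & mask) {
    const int64_t n_tokens = (int64_t) tokens.size();

    mask.assign(n_tokens*n_tokens, -INFINITY);

    for (int64_t i1 = 0; i1 < n_tokens; ++i1) {
        for (int64_t i0 = 0; i0 < n_tokens; ++i0) {
            const bool same_seq = tokens[i0].seq_id == tokens[i1].seq_id;
            const bool visible  = !causal || tokens[i0].pos <= tokens[i1].pos;

            if (same_seq && visible) {
                // ALiBi replaces the 0.0f with a distance-based bias
                mask[i1*n_tokens + i0] = alibi ? -std::abs(float(tokens[i0].pos - tokens[i1].pos)) : 0.0f;
            }
        }
    }
}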
@@ -358,34 +304,36 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
 }

 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
-    if (cross_kq_mask) {
-        const int64_t n_enc = cross_kq_mask->ne[0];
-        const int64_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(cross_kq_mask);

-        GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    const int64_t n_enc = cross_kq_mask->ne[0];
+    const int64_t n_tokens = ubatch->n_tokens;

-        float * data = (float *) cross_kq_mask->data;
+    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing

-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_enc; ++i) {
-                    float f = -INFINITY;
-                    // TODO: fix indexing [UBATCH_IDX]
-                    for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = ubatch->seq_id[j][s];
-                        if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
+    float * data = (float *) cross_kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                        f = 0.0f;
                     }
-                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
+
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
+        }

-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_enc; ++j) {
-                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
-                }
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
             }
         }
     }
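
Note: the cross-attention hunk follows the same convention with the encoder on the column axis: i indexes the decoder-side ubatch token (the row), j indexes the encoder output (the column), and rows from n_tokens up to the padded row count are filled with -INFINITY. A hedged sketch of that padding step follows; pad_to merely mirrors the round-up-to-a-multiple behavior of GGML_PAD, and the function name is illustrative only.

#include <cmath>
#include <cstdint>
#include <vector>

// Round x up to a multiple of n (mirrors what GGML_PAD does in ggml).
static int64_t pad_to(int64_t x, int64_t n) {
    return ((x + n - 1)/n)*n;
}

// Cross-attention mask layout sketch: rows are decoder ubatch tokens (padded),
// columns are encoder outputs; element (i, j) lives at data[i*n_enc + j].
static std::vector<float> cross_mask_padding(int64_t n_tokens, int64_t n_enc, int64_t pad) {
    const int64_t n_rows = pad_to(n_tokens, pad);

    std::vector<float> data(n_rows*n_enc, 0.0f);

    // rows past the real tokens exist only for padding: mask them out entirely
    for (int64_t i = n_tokens; i < n_rows; ++i) {
        for (int64_t j = 0; j < n_enc; ++j) {
            data[i*n_enc + j] = -INFINITY;
        }
    }

    return data;
}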

src/llama-kv-cache-unified.cpp

Lines changed: 15 additions & 15 deletions
@@ -814,23 +814,23 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     // xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t b = 0; b < n_tokens; ++b) {
-            const llama_seq_id seq_id = ubatch->seq_id[b][0];
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];

-            const llama_pos p1 = ubatch->pos[b];
+            const llama_pos p1 = ubatch->pos[i];

-            for (uint32_t i = 0; i < n_kv; ++i) {
+            for (uint32_t j = 0; j < n_kv; ++j) {
                 float f = 0.0f;

                 bool masked = false;

-                if (cells.is_empty(i)) {
+                if (cells.is_empty(j)) {
                     masked = true;
                 } else {
-                    const llama_pos p0 = cells.pos_get(i);
+                    const llama_pos p0 = cells.pos_get(j);

                     // mask the token if not the same sequence
-                    masked = masked || (!cells.seq_has(i, seq_id));
+                    masked = masked || (!cells.seq_has(j, seq_id));

                     // mask future tokens
                     masked = masked || (causal_attn && p0 > p1);
@@ -847,15 +847,15 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                     f = -INFINITY;
                 }

-                data[h*(n_kv*n_tokens) + b*n_kv + i] = f;
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
             }
         }

         // mask padded tokens
         if (data) {
-            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
-                for (uint32_t i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                 }
             }
         }
@@ -883,12 +883,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     const int32_t n_kv = dst->ne[0];

     for (int h = 0; h < 1; ++h) {
-        for (int j = 0; j < n_tokens; ++j) {
-            for (int i = 0; i < n_kv; ++i) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_kv; ++j) {
                 // the position when the cells is empty is irrelevant - it will be masked out later in the attention
-                const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);

-                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
             }
         }
     }
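
Note: in llama_kv_cache_unified the columns are cache cells rather than ubatch tokens, so after the rename the convention reads: i = ubatch token (row), j = KV cell (column), stored at data[h*(n_kv*n_tokens) + i*n_kv + j], and set_input_pos_bucket uses the same i/j roles. A minimal sketch of the per-cell masking decision, assuming an illustrative Cell type (not the llama.cpp one):

#include <cmath>
#include <cstdint>
#include <set>

// Illustrative stand-in for one KV-cache cell (not the llama.cpp type).
struct Cell {
    bool              empty = true;
    int32_t           pos   = -1;
    std::set<int32_t> seq_ids; // sequences that have written this cell
};

// Mask value for query token (seq_id, p1) against one KV cell:
// -INFINITY if the cell is empty, belongs to another sequence, or lies in
// the future under causal attention; otherwise 0.0f (or an ALiBi bias).
static float kq_mask_value(const Cell & cell, int32_t seq_id, int32_t p1, bool causal_attn, bool use_alibi) {
    bool masked = cell.empty;

    if (!masked) {
        const int32_t p0 = cell.pos;

        masked = masked || (cell.seq_ids.count(seq_id) == 0); // different sequence
        masked = masked || (causal_attn && p0 > p1);          // future token

        if (!masked && use_alibi) {
            return -std::abs(float(p0 - p1)); // ALiBi distance bias
        }
    }

    return masked ? -INFINITY : 0.0f;
}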
