@@ -252,90 +252,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    // TODO: repace this if with GGML_ASSERT(kq_mask)
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = ubatch->n_tokens;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t n_stride     = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+    const int64_t n_kv     = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+    float * data = (float *) kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
                         }
+                        break;
                     }
                 }
+
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
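The rewrite above drops the old causal/non-causal split: the mask is now an n_tokens x n_tokens grid filled straight from the ubatch's per-token sequence ids and positions, with causality folded into the single predicate !cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1]. The standalone sketch below reproduces that rule with plain vectors; it simplifies each token to one sequence id, and the token_info/build_kq_mask names are illustrative only, not llama.cpp API.

// Standalone sketch of the masking rule the new set_input() applies, using plain
// vectors instead of the ubatch/ggml types. token_info and build_kq_mask are
// illustrative names and not part of llama.cpp.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

struct token_info {
    int32_t seq_id; // primary sequence id of the token (ubatch->seq_id[i][0])
    int32_t pos;    // position of the token within its sequence (ubatch->pos[i])
};

// Fill an n_tokens x n_tokens row-major mask: row i1 = query token, column i0 = key token.
// A key is visible (0.0f, or an ALiBi distance bias) iff it belongs to the same sequence
// and, with causal attention, does not come after the query; otherwise it stays -INFINITY.
static std::vector<float> build_kq_mask(const std::vector<token_info> & tokens,
                                        bool causal_attn, bool use_alibi) {
    const size_t n_tokens = tokens.size();
    std::vector<float> mask(n_tokens*n_tokens, -INFINITY);

    for (size_t i1 = 0; i1 < n_tokens; ++i1) {
        for (size_t i0 = 0; i0 < n_tokens; ++i0) {
            const bool same_seq = tokens[i0].seq_id == tokens[i1].seq_id;
            const bool visible  = !causal_attn || tokens[i0].pos <= tokens[i1].pos;
            if (same_seq && visible) {
                mask[i1*n_tokens + i0] = use_alibi
                    ? -std::abs(float(tokens[i0].pos - tokens[i1].pos))
                    : 0.0f;
            }
        }
    }

    return mask;
}

int main() {
    // two sequences interleaved in one ubatch: seq 0 has positions 0 and 1, seq 1 has position 0
    const std::vector<token_info> tokens = { {0, 0}, {0, 1}, {1, 0} };
    const auto mask = build_kq_mask(tokens, /*causal_attn=*/true, /*use_alibi=*/false);

    for (size_t i1 = 0; i1 < tokens.size(); ++i1) {
        for (size_t i0 = 0; i0 < tokens.size(); ++i0) {
            std::printf("%9.1f", mask[i1*tokens.size() + i0]);
        }
        std::printf("\n");
    }
}

With this example batch and causal attention, the second token of sequence 0 can see its predecessor, while the sequence-1 token sees only itself.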
@@ -358,34 +304,36 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
-    if (cross_kq_mask) {
-        const int64_t n_enc    = cross_kq_mask->ne[0];
-        const int64_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(cross_kq_mask);
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    const int64_t n_enc    = cross_kq_mask->ne[0];
+    const int64_t n_tokens = ubatch->n_tokens;
 
-        float * data = (float *) cross_kq_mask->data;
+    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_enc; ++i) {
-                    float f = -INFINITY;
-                    // TODO: fix indexing [UBATCH_IDX]
-                    for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = ubatch->seq_id[j][s];
-                        if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
+    float * data = (float *) cross_kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                        f = 0.0f;
                     }
-                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
+
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
+        }
 
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_enc; ++j) {
-                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
-                }
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
             }
         }
     }
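In the cross-attention path the mask has shape [n_enc, n_tokens]: decoder token i may attend to encoder output j only if one of the token's sequence ids appears in cross->seq_ids_enc[j], and the rows from n_tokens up to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) are filled with -INFINITY. A minimal sketch of that rule with STL containers; build_cross_kq_mask, pad_to, and the padding constant are stand-ins, not the real GGML macros or llama.cpp API.

// Standalone sketch of the cross-attention mask built above, with plain STL
// containers in place of the ubatch and cross state.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

static constexpr int64_t k_kq_mask_pad = 32; // stand-in for GGML_KQ_MASK_PAD

// stand-in for GGML_PAD: round n up to a multiple of mult
static int64_t pad_to(int64_t n, int64_t mult) {
    return ((n + mult - 1)/mult)*mult;
}

// seq_ids[i]     : sequence ids of decoder token i (ubatch->seq_id[i][0..n_seq_id[i]-1])
// seq_ids_enc[j] : sequence ids that produced encoder output j (cross->seq_ids_enc[j])
// Returns a row-major [pad_to(n_tokens), n_enc] mask: 0.0f where decoder token i may
// attend to encoder position j, -INFINITY elsewhere (including the padded rows).
static std::vector<float> build_cross_kq_mask(const std::vector<std::set<int32_t>> & seq_ids,
                                              const std::vector<std::set<int32_t>> & seq_ids_enc) {
    const int64_t n_tokens = (int64_t) seq_ids.size();
    const int64_t n_enc    = (int64_t) seq_ids_enc.size();
    const int64_t n_rows   = pad_to(n_tokens, k_kq_mask_pad);

    std::vector<float> mask(n_rows*n_enc, -INFINITY);

    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_enc; ++j) {
            for (int32_t s : seq_ids[i]) {
                if (seq_ids_enc[j].count(s) > 0) {
                    mask[i*n_enc + j] = 0.0f;
                    break;
                }
            }
        }
    }

    return mask;
}

int main() {
    // two decoder tokens (sequences 0 and 1), three encoder outputs:
    // outputs 0 and 1 belong to sequence 0, output 2 belongs to sequence 1
    const std::vector<std::set<int32_t>> seq_ids     = { {0}, {1} };
    const std::vector<std::set<int32_t>> seq_ids_enc = { {0}, {0}, {1} };

    const auto mask = build_cross_kq_mask(seq_ids, seq_ids_enc);

    for (size_t i = 0; i < seq_ids.size(); ++i) {
        for (size_t j = 0; j < seq_ids_enc.size(); ++j) {
            std::printf("%9.1f", mask[i*seq_ids_enc.size() + j]);
        }
        std::printf("\n");
    }
}

The padded rows correspond to the rounded-up token dimension and are fully masked, so they never select any encoder position.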