
Commit c6b84e2

fix modern-bert swa logic
1 parent ad2a19a commit c6b84e2

5 files changed (+18, -83 lines)

src/llama-graph.cpp

Lines changed: 8 additions & 71 deletions
@@ -279,60 +279,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     if (kq_mask) {
         // Check if we're using sliding window attention
-        if (n_swa > 0) {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t n_stride     = ubatch->n_tokens;
-            const int64_t half_n_swa   = n_swa / 2;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            // Implement symmetric sliding window attention
-            // token i attends to tokens [i - n_swa/2, i + n_swa/2]
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-                        const int64_t pos_j = ubatch->pos[tj];
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        const int64_t pos_i = ubatch->pos[ti];
-                                        const int64_t pos_diff = pos_j - pos_i;
-
-                                        // Apply sliding window constraint
-                                        // [i - n_swa/2, i + n_swa/2]
-                                        if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
-                                            if (hparams.use_alibi) {
-                                                f = -std::abs(pos_diff);
-                                            } else {
-                                                f = 0.0f;
-                                            }
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        } else if (cparams.causal_attn) {
+        if (cparams.causal_attn) {
             const int64_t n_kv         = ubatch->n_tokens;
             const int64_t n_tokens     = ubatch->n_tokens;
             const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -375,6 +322,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
             const int64_t n_seq_tokens = ubatch->n_seq_tokens;
             const int64_t n_seqs       = ubatch->n_seqs;
             const int64_t n_stride     = ubatch->n_tokens;
+            const int64_t half_n_swa   = hparams.n_swa / 2;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
 
@@ -386,6 +334,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
                     for (int j = 0; j < n_seq_tokens; ++j) {
                         const int32_t tj = s1*n_seq_tokens + j;
+                        const int64_t pos_j = ubatch->pos[tj];
 
                         for (int s0 = 0; s0 < n_seqs; ++s0) {
                             for (int i = 0; i < n_seq_tokens; ++i) {
@@ -394,7 +343,11 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                     if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
+                                        const int64_t pos_i = ubatch->pos[ti];
+                                        const int64_t pos_diff = pos_j - pos_i;
+
+                                        if (hparams.use_alibi &&
+                                            (pos_diff >= -half_n_swa && pos_diff <= half_n_swa)) {
                                             f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
                                         } else {
                                             f = 0.0f;
@@ -1242,22 +1195,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
-llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache_iswa() const {
-    // Use the sliding window size from hyperparameters
-    // If hparams.n_swa is 0, use a default value (128)
-    const int n_swa = hparams.n_swa > 0 ? hparams.n_swa : 128;
-
-    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams, n_swa);
-
-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    ggml_set_input(inp->kq_mask);
-
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
-
-    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
         ggml_cgraph * gf,
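Taken together, these hunks drop the ad-hoc n_swa branch from the no-cache mask path: the generic loop now computes pos_diff and applies the ALiBi bias only when a pair falls inside the symmetric window (half_n_swa = hparams.n_swa / 2), and the dedicated build_attn_inp_no_cache_iswa() builder is removed. For intuition, here is a minimal standalone sketch of the symmetric band the removed branch used to build, restricted to a single sequence; the helper name and the flat row-major mask layout are assumptions of this sketch, not llama.cpp API.

// Illustration only: single-sequence symmetric window mask.
// Token j may attend to token i when |pos[j] - pos[i]| <= n_swa/2.
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<float> build_symmetric_mask(const std::vector<int64_t> & pos, int64_t n_swa) {
    const size_t  n    = pos.size();
    const int64_t half = n_swa / 2;
    std::vector<float> mask(n * n, -INFINITY);  // masked by default
    for (size_t j = 0; j < n; ++j) {            // query token
        for (size_t i = 0; i < n; ++i) {        // key token
            const int64_t d = pos[j] - pos[i];
            if (d >= -half && d <= half) {
                mask[j*n + i] = 0.0f;           // inside [pos_j - half, pos_j + half]
            }
        }
    }
    return mask;
}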

src/llama-graph.h

Lines changed: 1 addition & 10 deletions
@@ -228,14 +228,7 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 public:
     llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
         hparams(hparams),
-        cparams(cparams),
-        n_swa(0) {
-    }
-
-    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams, int n_swa) :
-        hparams(hparams),
-        cparams(cparams),
-        n_swa(n_swa) {
+        cparams(cparams) {
     }
 
     ~llm_graph_input_attn_no_cache() = default;
@@ -249,7 +242,6 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
-    const int n_swa; // Sliding window attention size (0 = disabled)
 };
 
 class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
@@ -551,7 +543,6 @@ struct llm_graph_context {
             float kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
-    llm_graph_input_attn_no_cache * build_attn_inp_no_cache_iswa() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ enum llama_swa_type {
     LLAMA_SWA_TYPE_NONE      = 0,
     LLAMA_SWA_TYPE_STANDARD  = 1,
     LLAMA_SWA_TYPE_CHUNKED   = 2,
+    LLAMA_SWA_TYPE_SYMMETRIC = 3,
 };
 
 struct llama_hparams_posnet {

src/llama-kv-cache-unified.cpp

Lines changed: 6 additions & 0 deletions
@@ -1265,6 +1265,12 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
                     return true;
                 }
             } break;
+        case LLAMA_SWA_TYPE_SYMMETRIC:
+            {
+                if ( p1 - p0 <= (int32_t) n_swa / 2 || p0 - p1 >= (int32_t) n_swa / 2) {
+                    return true;
+                }
+            } break;
     }
 
     return false;
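The new case gives the unified KV cache a notion of a symmetric window: with window size n_swa, a key at position p0 should remain visible to a query at p1 only while the two positions are within n_swa / 2 of each other. A standalone sketch of that predicate, under the assumption that a pair is masked exactly when the distance exceeds half the window (illustrative helper, not the library API):

// Sketch of a symmetric sliding-window check.
// Assumption: mask the pair when |p1 - p0| > n_swa / 2.
#include <cstdint>
#include <cstdlib>

static bool is_masked_symmetric(int32_t p0, int32_t p1, int32_t n_swa) {
    const int32_t half = n_swa / 2;
    return std::abs(p1 - p0) > half;
}

// Worked example with n_swa = 128 (half = 64):
//   is_masked_symmetric(100, 160, 128) -> false  (distance 60, inside the band)
//   is_masked_symmetric(100, 165, 128) -> true   (distance 65, outside the band)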

src/llama-model.cpp

Lines changed: 2 additions & 2 deletions
@@ -718,7 +718,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.rope_freq_base_train_swa = 10000.0f;
                 hparams.n_swa = 128;
 
-                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
                 hparams.set_swa_pattern(3, 0);
 
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -6144,7 +6144,7 @@ struct llm_build_modern_bert : public llm_graph_context {
         inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
         cb(inpL, "inp_norm", -1);
 
-        auto * inp_attn = build_attn_inp_no_cache_iswa();
+        auto * inp_attn = build_attn_inp_no_cache();
 
         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
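For ModernBERT the loader now pairs n_swa = 128 with the symmetric window type and keeps the set_swa_pattern(3, 0) layer pattern, while the graph builder falls back to the plain build_attn_inp_no_cache() input. The sketch below only illustrates the resulting layer interleave; the exact meaning of set_swa_pattern's arguments is an assumption here (two sliding-window layers followed by one full-attention layer in every group of three), as is the layer count.

// Sketch: which layers would use the symmetric window under the stated
// assumption about set_swa_pattern(3, ...). Layer count is illustrative.
#include <cstdio>

int main() {
    const int n_layer = 12;
    for (int il = 0; il < n_layer; ++il) {
        const bool is_swa = (il % 3) < 2;  // assumed pattern: 2 local layers, then 1 global
        std::printf("layer %2d: %s\n", il, is_swa ? "symmetric window (n_swa = 128)" : "full attention");
    }
    return 0;
}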
