@@ -6319,57 +6319,56 @@ struct llm_build_neo_bert : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_no_cache();
 
-        ggml_tensor * inp_out_ids = build_inp_out_ids();
-
+        // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * cur = inpL;
 
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
             // pre-norm
             cur = build_norm(inpL,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, il);
 
-            {
-                ggml_tensor * Qcur;
-                ggml_tensor * Kcur;
-                ggml_tensor * Vcur;
-
-                // self-attention
-                cur = build_lora_mm(model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-                // RoPE
-                Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+            // self-attention
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // RoPE
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
 
-                Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow
-                        );
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, gf,
-                        model.layers[il].wo, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-                cb(cur, "kqv_out", il);
-            }
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
 
-            if (il == n_layer - 1 && inp_out_ids) {
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
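For reference, a minimal standalone sketch (not part of the commit) of the offset arithmetic behind the three `ggml_view_2d` slices above: the fused `wqkv` projection packs Q, K and V contiguously per token, so Q starts at byte 0, K at `n_embd*sizeof(float)`, and V at `(n_embd + n_embd_gqa)*sizeof(float)`. The dimensions below are hypothetical, chosen only to make the arithmetic concrete.

```cpp
// Standalone illustration of the wqkv slicing offsets used in the hunk above.
// Dimensions are hypothetical; the offset formulas mirror the ggml_view_2d calls.
#include <cstdio>

int main() {
    const int n_embd      = 768;              // hypothetical hidden size
    const int n_head      = 12;
    const int n_head_kv   = 12;               // plain MHA: n_head_kv == n_head
    const int n_embd_head = n_embd / n_head;  // 64
    const int n_embd_gqa  = n_embd_head * n_head_kv;

    // per-token layout of the fused projection output: [ Q | K | V ]
    const size_t q_off = 0 * sizeof(float) *  n_embd;                // start of Q
    const size_t k_off = 1 * sizeof(float) *  n_embd;                // start of K
    const size_t v_off = 1 * sizeof(float) * (n_embd + n_embd_gqa);  // start of V

    printf("Q at byte %zu, K at byte %zu, V at byte %zu (row size %zu bytes)\n",
           q_off, k_off, v_off, sizeof(float) * (n_embd + 2 * n_embd_gqa));
    return 0;
}
```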
@@ -14798,6 +14797,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
+        case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_ARCEE:
             return LLAMA_ROPE_TYPE_NORM;
 
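The second hunk registers `LLM_ARCH_NEO_BERT` so `llama_model_rope_type()` reports `LLAMA_ROPE_TYPE_NORM` for NeoBERT models. A hedged usage sketch (not part of the commit) of how a caller might observe this through the public API; the model path is hypothetical.

```cpp
// Hedged sketch: query the rope type registered by the second hunk.
// Assumes a NeoBERT GGUF file at a hypothetical local path.
#include <cstdio>
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("neobert.gguf", mparams);  // hypothetical path
    if (!model) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // with this change, a NeoBERT model should report LLAMA_ROPE_TYPE_NORM
    const llama_rope_type rt = llama_model_rope_type(model);
    printf("rope type: %d (NORM = %d)\n", (int) rt, (int) LLAMA_ROPE_TYPE_NORM);

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```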