
Commit 3683408

ggerganov authored and Minh141120 committed
model : more uniform output id handling (ggml-org#14275)
* model : more uniform output id handling ggml-ci
* cont : revert n_outputs < n_tokens optimization ggml-ci
* cont : fix out_ids initialization ggml-ci
1 parent 26439ad commit 3683408

File tree

1 file changed: +48, -50 lines changed


src/llama-model.cpp

Lines changed: 48 additions & 50 deletions
@@ -6319,56 +6319,57 @@ struct llm_build_neo_bert : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_no_cache();
 
-        // iterate layers
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * cur = inpL;
 
-            ggml_tensor * Qcur;
-            ggml_tensor * Kcur;
-            ggml_tensor * Vcur;
-
             // pre-norm
             cur = build_norm(inpL,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, il);
 
-            // self-attention
-            cur = build_lora_mm(model.layers[il].wqkv, cur);
-            cb(cur, "wqkv", il);
-
-            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-            // RoPE
-            Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                    );
+            {
+                ggml_tensor * Qcur;
+                ggml_tensor * Kcur;
+                ggml_tensor * Vcur;
 
-            Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                    );
+                // self-attention
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                // RoPE
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
 
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
 
-            cur = build_attn(inp_attn, gf,
-                    model.layers[il].wo, nullptr,
-                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-            cb(cur, "kqv_out", il);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
-                // skip computing output for unused tokens
-                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, nullptr,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                cb(cur, "kqv_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
@@ -7929,13 +7930,8 @@ struct llm_build_plamo : public llm_graph_context {
                     model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
-            ggml_tensor * sa_out = cur;
-
-            cur = attention_norm;
 
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                ggml_tensor * inp_out_ids = build_inp_out_ids();
+            if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 sa_inp = ggml_get_rows(ctx0, sa_inp, inp_out_ids);
                 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
@@ -9638,6 +9634,8 @@ struct llm_build_mamba : public llm_graph_context {
 
         auto * rs_inp = build_rs_inp();
 
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
         for (int il = 0; il < n_layer; ++il) {
             // norm
             cur = build_norm(inpL,
@@ -13995,6 +13993,8 @@ struct llm_build_dots1 : public llm_graph_context {
 
         auto * inp_attn = build_attn_inp_kv_unified();
 
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -14047,9 +14047,7 @@ struct llm_build_dots1 : public llm_graph_context {
                     Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                ggml_tensor * inp_out_ids = build_inp_out_ids();
+            if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
@@ -14147,6 +14145,8 @@ struct llm_build_arcee : public llm_graph_context {
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
@@ -14209,9 +14209,7 @@ struct llm_build_arcee : public llm_graph_context {
                 cb(cur, "attn_out", il);
             }
 
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                ggml_tensor * inp_out_ids = build_inp_out_ids();
+            if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
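
The hunks above all apply the same restructuring: build_inp_out_ids() is called once, before the layer loop, and the last-layer ggml_get_rows gather is guarded by the resulting tensor instead of a per-builder special case (for example, the pooling_type == LLAMA_POOLING_TYPE_NONE check removed from llm_build_neo_bert). Below is a minimal, self-contained sketch of that control flow; Tensor, build_inp_out_ids, and get_rows here are hypothetical stand-ins, not the actual ggml/llama.cpp API.

// Sketch only: hypothetical stand-in types, not the ggml/llama.cpp API.
#include <cstdio>
#include <vector>

struct Tensor { std::vector<float> rows; };

// Stand-in for build_inp_out_ids(): ids of the tokens whose outputs are
// actually needed.
static std::vector<int> build_inp_out_ids() { return {0, 3}; }

// Stand-in for ggml_get_rows(): keep only the selected rows.
static Tensor get_rows(const Tensor & t, const std::vector<int> & ids) {
    Tensor out;
    for (int i : ids) out.rows.push_back(t.rows[i]);
    return out;
}

int main() {
    const int n_layer = 4;
    Tensor cur{{0.f, 1.f, 2.f, 3.f, 4.f}};   // one row per token

    // Built once, before the layer loop -- the uniform part of the change.
    const std::vector<int> inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        // ... per-layer graph construction would go here ...

        // Uniform last-layer guard: gather only the rows that feed outputs,
        // and only when an output-id selection exists.
        if (il == n_layer - 1 && !inp_out_ids.empty()) {
            cur = get_rows(cur, inp_out_ids);
        }
    }

    std::printf("rows kept: %zu\n", cur.rows.size());
    return 0;
}

Because the selection is hoisted out of the loop, every builder touched by this commit can use the identical "if (il == n_layer - 1 && inp_out_ids)" guard, which is what the repeated small hunks above reflect.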
