Skip to content

Commit f7c7a92

Browse files
compilade and CISC committed
model : use ggml_swiglu_split for Mamba
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 2f39cd7 commit f7c7a92

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/llama-model.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10057,7 +10057,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1005710057
// TODO: skip computing output earlier for unused tokens
1005810058

1005910059
y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
10060-
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
10060+
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
1006110061

1006210062
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
1006310063
cur = build_lora_mm(layer.ssm_out, y);
@@ -10181,7 +10181,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1018110181
// TODO: skip computing output earlier for unused tokens
1018210182

1018310183
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
10184-
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
10184+
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
1018510185

1018610186
// grouped RMS norm
1018710187
y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);

0 commit comments

Comments
 (0)