@@ -5163,7 +5163,10 @@ void llama_model::print_info() const {
}
}

- if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA) {
+ if (arch == LLM_ARCH_MAMBA ||
+ arch == LLM_ARCH_MAMBA2 ||
+ arch == LLM_ARCH_JAMBA ||
+ arch == LLM_ARCH_FALCON_H1) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@@ -10436,8 +10439,11 @@ struct llm_graph_context_mamba : public virtual llm_graph_context {
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

// grouped RMS norm
- y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
- y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+ if (model.layers[il].ssm_norm) {
+ y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+ y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+ }
+

y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
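The hunk above makes the grouped RMS norm in the shared `build_mamba2_layer` path conditional on the layer actually providing an `ssm_norm` tensor, presumably so that models routed through this helper without that weight skip the norm instead of dereferencing a null tensor. For intuition about what the reshape-plus-RMS-norm pair computes, here is a minimal standalone sketch in plain C++ (not ggml), with a hypothetical `group_rms_norm` helper: the `d_inner` channels are split into `n_group` contiguous groups and each group is RMS-normalized independently before the per-channel weight is applied.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper, not part of llama.cpp: RMS-normalize each of n_group
// contiguous groups of a d_inner-sized channel vector independently, then
// scale by a per-channel weight. This mirrors what the reshape to
// {d_inner/n_group, n_group, ...} followed by an RMS norm achieves.
static void group_rms_norm(std::vector<float> & x, const std::vector<float> & w,
                           int n_group, float eps = 1e-6f) {
    const int group_size = (int) x.size() / n_group;
    for (int g = 0; g < n_group; ++g) {
        float sum_sq = 0.0f;
        for (int i = 0; i < group_size; ++i) {
            const float v = x[g*group_size + i];
            sum_sq += v*v;
        }
        const float scale = 1.0f / std::sqrt(sum_sq/group_size + eps);
        for (int i = 0; i < group_size; ++i) {
            x[g*group_size + i] = x[g*group_size + i] * scale * w[g*group_size + i];
        }
    }
}

int main() {
    std::vector<float> y = {1, 2, 3, 4, 5, 6, 7, 8}; // d_inner = 8
    std::vector<float> w(y.size(), 1.0f);            // unit norm weights
    group_rms_norm(y, w, /*n_group=*/2);             // two groups of 4 channels
    for (float v : y) printf("%.3f ", v);
    printf("\n");
    return 0;
}
```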
@@ -15180,10 +15186,9 @@ struct llm_build_ernie4_5 : public llm_graph_context {
}
};

- struct llm_build_falcon_h1 : public llm_graph_context {
- const llama_model & model;
-
- llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+ llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+ : llm_graph_context(params), llm_graph_context_mamba(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

ggml_tensor * cur;
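The rebased builder above now inherits from `llm_graph_context_mamba`, which, per the `llm_graph_context_mamba` hunk earlier in the diff, derives virtually from `llm_graph_context`. With a virtual base the most-derived class is responsible for constructing that base, which is why the new initializer list names both `llm_graph_context(params)` and `llm_graph_context_mamba(params)`. A minimal sketch of that C++ mechanism, using hypothetical stand-in types rather than the real llama.cpp classes:

```cpp
#include <cstdio>

struct params_t { int n = 0; };               // stand-in for llm_graph_params

struct context {                              // stand-in for llm_graph_context
    explicit context(const params_t & p) : n(p.n) { printf("context(%d)\n", n); }
    int n;
};

// stand-in for llm_graph_context_mamba: a virtual base, as in the real code
struct context_mamba : public virtual context {
    explicit context_mamba(const params_t & p) : context(p) {}
};

// stand-in for llm_build_falcon_h1: the most-derived class constructs the
// virtual base directly; context_mamba's own context(p) initializer is
// ignored when it is not the most-derived type.
struct build_falcon_h1 : public context_mamba {
    explicit build_falcon_h1(const params_t & p) : context(p), context_mamba(p) {}
};

int main() {
    build_falcon_h1 b(params_t{7});   // prints "context(7)" exactly once
    printf("n = %d\n", b.n);
    return 0;
}
```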
@@ -15250,7 +15255,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
// Mamba2 layer
cb(cur, "ssm_in", il);

- ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
+ ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr() , gf, cur, model , ubatch, il);
cb(ssm_out, "ssm_out", il);

// // Aggregation
@@ -15306,139 +15311,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {

ggml_build_forward_expand(gf, cur);
}
-
- ggml_tensor * build_mamba2_layer(
- llm_graph_input_mem_hybrid * inp,
- ggml_cgraph * gf,
- ggml_tensor * cur,
- const llama_ubatch & ubatch,
- int il) const {
- const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-
- const auto kv_head = kv_state->get_head();
-
- const int64_t d_conv = hparams.ssm_d_conv;
- const int64_t d_inner = hparams.ssm_d_inner;
- const int64_t d_state = hparams.ssm_d_state;
- const int64_t n_head = hparams.ssm_dt_rank;
- const int64_t head_dim = d_inner / n_head;
- const int64_t n_group = hparams.ssm_n_group;
- const int64_t n_seqs = ubatch.n_seqs;
-
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
- GGML_ASSERT(n_seqs != 0);
- GGML_ASSERT(ubatch.equal_seqs);
- GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
- ggml_tensor * conv_states_all = kv_state->get_r_l(il);
- ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
-
- ggml_tensor * conv = build_rs(inp->get_recr(), gf, conv_states_all, hparams.n_embd_r(), n_seqs);
- conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-
- // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
- cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
- // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
-
- // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
- ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
- cb(zxBCdt, "zxBCdt", il);
-
- // split the above in three
- ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
- ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
- ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
-
- // conv
- {
- // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
- ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
-
- // copy last (d_conv - 1) columns back into the state cache
- ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
-
- ggml_build_forward_expand(gf,
- ggml_cpy(ctx0, last_conv,
- ggml_view_1d(ctx0, conv_states_all,
- (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
- kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
-
- // 1D convolution
- // The equivalent is to make a self-overlapping view of conv_x
- // over d_conv columns at each stride in the 3rd dimension,
- // then element-wise multiply that with the conv1d weight,
- // then sum the elements of each row,
- // (the last two steps are a dot product over rows (also doable with mul_mat))
- // then permute away the ne[0] dimension,
- // and then you're left with the resulting x tensor.
- // For simultaneous sequences, all sequences need to have the same length.
- xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
-
- // bias
- xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
-
- xBC = ggml_silu(ctx0, xBC);
- }
-
- // ssm
- {
- // These correspond to V K Q in SSM/attention duality
- ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
-
- ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
-
- ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
-
- // {n_head, n_seq_tokens, n_seqs}
- dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
-
- ggml_tensor * A = model.layers[il].ssm_a;
-
- // use the states and the indices provided by build_rs
- // (this is necessary in order to properly use the states before they are overwritten,
- // while avoiding to make unnecessary copies of the states)
- auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
- ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
-
- // TODO: use semistructured matrices to implement state-space duality
- // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
- return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
- };
-
- ggml_tensor * y_ssm = build_rs(inp->get_recr(), gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
- // store last states
- ggml_build_forward_expand(gf,
- ggml_cpy(ctx0,
- ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
- ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
-
- ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
-
- // TODO: skip computing output earlier for unused tokens
-
- y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
- y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
- // grouped RMS norm
- if (model.layers[il].ssm_norm) {
- y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
- y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
- }
-
- y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
-
- // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
- cur = build_lora_mm(model.layers[il].ssm_out, y);
- }
-
- // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
- cb(cur, "mamba_out", il);
- return cur;
- }
};

struct llm_build_arcee : public llm_graph_context {