
Commit b1869d4

Merge remote-tracking branch 'origin/compilade/mamba2' into GraniteFour
* origin/compilade/mamba2: (24 commits)
  kv-cache : allow context shift for recurrent models
  convert : avoid AutoConfig for Mamba and Mamba2 hparams
  kv-cache : remove const_cast when setting inputs for s_copy
  metal : single-user mamba2 inference works
  metal : add missing args for nb references in ssm_scan_f32_group
  metal : fix confusion between ; and ,
  convert : fix flake8 lint
  ggml : avoid multiply by D in GGML_OP_SSM_SCAN
  ggml : remove unused fast broadcast path in GGML_MUL
  metal : fix wrong number of tokens per sequence in SSM_SCAN
  metal : fix SSM_SCAN state head offset
  metal : add back n_seqs to SSM_SCAN args
  metal : remove unused arguments for SSM_SCAN
  metal : use log and exp instead of log1pf and expf in SSM_SCAN
  metal : fix SSM_SCAN pipeline scope
  metal : attempt to adapt SSM_SCAN for Mamba-2
  llama : avoid redundant state copy for Mamba 1 and 2
  convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
  llama : add missing break
  llama : remove unused variable
  ...
2 parents 9aa941d + e94f393 commit b1869d4

22 files changed: +853, -315 lines

convert_hf_to_gguf.py

Lines changed: 111 additions & 1 deletion
@@ -4357,6 +4357,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
@@ -4431,6 +4439,100 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Mamba2ForCausalLM")
+class Mamba2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MAMBA2
+
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
+    def set_vocab(self):
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size to next multiple of 16
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        elif (self.dir_model / "tokenizer.model.v3").is_file():
+            # mamba-codestral
+            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
+        elif (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            self._set_vocab_builtin("gpt-neox", vocab_size)
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        # TODO: does this really matter?
+        assert d_inner == 2 * d_model
+        assert d_inner % head_dim == 0
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
+            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
+            name = name.removeprefix("model.")
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+        elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            data_torch = data_torch.reshape((*data_torch.shape, 1))
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
@@ -6136,12 +6238,20 @@ def split_str_to_n_bytes(split_str: str) -> int:
 def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
 
 
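As a reading aid for the Mamba2Model.set_vocab change above: -(vocab_size // -pad_vocab) is ceiling division, so the vocabulary size is rounded up to the next multiple of pad_vocab_size_multiple. A minimal standalone sketch of just that rounding; the helper name and the sample sizes below are made up for illustration, not taken from any real model:

def pad_vocab_size(vocab_size: int, pad_vocab: int = 16) -> int:
    # -(a // -b) is ceil(a / b) for positive integers
    # ref: https://stackoverflow.com/a/17511341/22827863
    return -(vocab_size // -pad_vocab) * pad_vocab

# hypothetical sizes, only to show the rounding behaviour
assert pad_vocab_size(50277, 16) == 50288  # rounded up to the next multiple of 16
assert pad_vocab_size(50288, 16) == 50288  # already a multiple, left unchanged
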
ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
@@ -1858,7 +1858,8 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C);
+            struct ggml_tensor  * C,
+            struct ggml_tensor  * ids);
 
     // partition into non-overlapping windows with padding if needed
     // example:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 128 additions & 51 deletions
@@ -7493,74 +7493,151 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // s
-    const ggml_tensor * src1 = dst->src[1]; // x
-    const ggml_tensor * src2 = dst->src[2]; // dt
-    const ggml_tensor * src3 = dst->src[3]; // A
-    const ggml_tensor * src4 = dst->src[4]; // B
-    const ggml_tensor * src5 = dst->src[5]; // C
+    const ggml_tensor * src0 = dst->src[0]; // s   {d_state, dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // x   {dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // dt  {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
+    const ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nc = src0->ne[0]; // d_state
+    const int64_t nr = src0->ne[1]; // dim
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1];
+    const int64_t nt = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch
 
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
+
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+    // allows optimizing the modulo since n_group should be a power of 2
+    GGML_ASSERT((ng & -ng) == ng);
+
+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;
+
+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
+
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+              float * s  = (      float *) ((      char *)  dst->data +  i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+                  float * y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+            if (src3->ne[0] == 1) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+#if defined(GGML_SIMD)
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
+                        GGML_F32_VEC adA  = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
 
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (      float *) ((      char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (      float *) ((      char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
 
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
+
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+#else
+                        const int np = 0;
+#endif
+                        // d_state
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+                    }
+                }
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+                    }
                 }
-                y[i1] = sumf;
             }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
     }
 }

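To make the new data flow easier to follow, here is a rough NumPy sketch of the Mamba-2 branch of ggml_compute_forward_ssm_scan_f32 above. It is only a reading aid, not the kernel itself: the function name and shapes are assumptions (written in row-major order, the reverse of ggml's ne[] order), a plain modulo replaces the h & (ng - 1) trick, and threading and SIMD are ignored.

import numpy as np

def ssm_scan_mamba2_ref(s, x, dt, A, B, C, ids):
    # Assumed shapes (row-major, reverse of ggml's ne[] order):
    #   s:   (n_slots, n_head, dim, d_state)   state buffer, gathered through ids
    #   x:   (n_seqs, n_tok, n_head, dim)
    #   dt:  (n_seqs, n_tok, n_head)
    #   A:   (n_head,)                         scalar decay factor per head (Mamba-2)
    #   B,C: (n_seqs, n_tok, n_group, d_state)
    #   ids: (n_seqs,) int32                   which state slot each sequence starts from
    ns, nt, nh, nr = x.shape
    ng = B.shape[2]
    y     = np.empty_like(x)
    s_out = np.empty((ns,) + s.shape[1:], dtype=s.dtype)

    for i3 in range(ns):
        st = s[ids[i3]].copy()  # gather the source state for this sequence
        for i2 in range(nt):
            # softplus with the same large-value cutoff as the kernel
            d = dt[i3, i2]
            dt_sp = np.where(d <= 20.0, np.log1p(np.exp(np.minimum(d, 20.0))), d)
            dA = np.exp(dt_sp * A)  # per-head scalar decay
            for h in range(nh):
                g    = h % ng                    # group shared by several heads
                x_dt = x[i3, i2, h] * dt_sp[h]   # (dim,)
                # state = prev_state * dA + dB * x
                st[h] = st[h] * dA[h] + np.outer(x_dt, B[i3, i2, g])
                # y = rowwise_dotprod(state, C)
                y[i3, i2, h] = st[h] @ C[i3, i2, g]
        s_out[i3] = st  # the updated state is written out per sequence
    return y, s_out

The ids input is what lets the scan read each sequence's previous state from an arbitrary slot of the state buffer instead of assuming slot i belongs to sequence i, which is how the "redundant state copy" mentioned in the commit list is avoided.
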
ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 5 additions & 6 deletions
@@ -488,26 +488,25 @@ typedef struct {
 typedef struct {
     int64_t d_state;
     int64_t d_inner;
+    int64_t n_head;
+    int64_t n_group;
     int64_t n_seq_tokens;
     int64_t n_seqs;
-    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
-    uint64_t nb10;
+    uint64_t nb03;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
-    uint64_t nb20;
     uint64_t nb21;
     uint64_t nb22;
-    uint64_t nb30;
     uint64_t nb31;
-    uint64_t nb40;
     uint64_t nb41;
     uint64_t nb42;
-    uint64_t nb50;
+    uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t nb53;
 } ggml_metal_kargs_ssm_scan;
 
 typedef struct {
