Skip to content

Commit caeb8ee

Browse files
committed
Merge branch 'compilade/mamba2' into GraniteFour
* compilade/mamba2: (24 commits) kv-cache : allow context shift for recurrent models convert : avoid AutoConfig for Mamba and Mamba2 hparams kv-cache : remove const_cast when setting inputs for s_copy metal : single-user mamba2 inference works metal : add missing args for nb references in ssm_scan_f32_group metal : fix confusion between ; and , convert : fix flake8 lint ggml : avoid multiply by D in GGML_OP_SSM_SCAN ggml : remove unused fast broadcast path in GGML_MUL metal : fix wrong number of tokens per sequence in SSM_SCAN metal : fix SSM_SCAN state head offset metal : add back n_seqs to SSM_SCAN args metal : remove unused arguments for SSM_SCAN metal : use log and exp instead of log1pf and expf in SSM_SCAN metal : fix SSM_SCAN pipeline scope metal : attempt to adapt SSM_SCAN for Mamba-2 llama : avoid redundant state copy for Mamba 1 and 2 convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present llama : add missing break llama : remove unused variable ...
2 parents 089c968 + e94f393 commit caeb8ee

22 files changed

+859
-326
lines changed

convert_hf_to_gguf.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4386,6 +4386,14 @@ def set_gguf_parameters(self):
43864386
class MambaModel(TextModel):
43874387
model_arch = gguf.MODEL_ARCH.MAMBA
43884388

4389+
def __init__(self, dir_model: Path, *args, **kwargs):
4390+
# Avoid using AutoConfig for hparams
4391+
hparams = kwargs.pop("hparams", None)
4392+
if hparams is None:
4393+
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
4394+
hparams = json.load(f)
4395+
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
4396+
43894397
def set_vocab(self):
43904398
vocab_size = self.hparams["vocab_size"]
43914399
# Round vocab size to next multiple of 8
@@ -4460,6 +4468,100 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
44604468
return [(new_name, data_torch)]
44614469

44624470

4471+
@ModelBase.register("Mamba2ForCausalLM")
4472+
class Mamba2Model(TextModel):
4473+
model_arch = gguf.MODEL_ARCH.MAMBA2
4474+
4475+
def __init__(self, dir_model: Path, *args, **kwargs):
4476+
# Avoid using AutoConfig for hparams
4477+
# It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
4478+
hparams = kwargs.pop("hparams", None)
4479+
if hparams is None:
4480+
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
4481+
hparams = json.load(f)
4482+
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
4483+
4484+
def set_vocab(self):
4485+
vocab_size = self.hparams["vocab_size"]
4486+
# Round vocab size to next multiple of 16
4487+
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
4488+
# pad using ceiling division
4489+
# ref: https://stackoverflow.com/a/17511341/22827863
4490+
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
4491+
self.hparams["vocab_size"] = vocab_size
4492+
4493+
if (self.dir_model / "tokenizer.model").is_file():
4494+
self._set_vocab_sentencepiece()
4495+
elif (self.dir_model / "tokenizer.model.v3").is_file():
4496+
# mamba-codestral
4497+
raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
4498+
elif (self.dir_model / "tokenizer.json").is_file():
4499+
self._set_vocab_gpt2()
4500+
else:
4501+
# Use the GPT-NeoX tokenizer when no tokenizer files are present
4502+
self._set_vocab_builtin("gpt-neox", vocab_size)
4503+
4504+
def set_gguf_parameters(self):
4505+
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4506+
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
4507+
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
4508+
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
4509+
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
4510+
n_group = self.find_hparam(["n_groups"], optional=True) or 1
4511+
4512+
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
4513+
4514+
# Fail early for models which don't have a block expansion factor of 2
4515+
# TODO: does this really matter?
4516+
assert d_inner == 2 * d_model
4517+
assert d_inner % head_dim == 0
4518+
4519+
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
4520+
self.gguf_writer.add_embedding_length(d_model)
4521+
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
4522+
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
4523+
self.gguf_writer.add_block_count(self.block_count)
4524+
self.gguf_writer.add_ssm_conv_kernel(d_conv)
4525+
self.gguf_writer.add_ssm_inner_size(d_inner)
4526+
self.gguf_writer.add_ssm_state_size(d_state)
4527+
self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
4528+
self.gguf_writer.add_ssm_group_count(n_group)
4529+
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
4530+
self.gguf_writer.add_file_type(self.ftype)
4531+
4532+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
4533+
4534+
if name.startswith("model.backbone") or name.startswith("model.lm_head"):
4535+
# map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
4536+
name = name.removeprefix("model.")
4537+
4538+
if name.endswith(".dt_bias"):
4539+
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
4540+
4541+
new_name = self.map_tensor_name(name)
4542+
4543+
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
4544+
data_torch = data_torch.squeeze()
4545+
elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
4546+
gguf.MODEL_TENSOR.SSM_A,
4547+
gguf.MODEL_TENSOR.SSM_D,
4548+
]):
4549+
# unsqueeze A to use similar shape semantics as Mamba-1
4550+
# (D is also unsqueezed, but for more straightforward broadcast internally)
4551+
data_torch = data_torch.reshape((*data_torch.shape, 1))
4552+
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
4553+
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4554+
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
4555+
n_group = self.hparams.get("n_groups", 1)
4556+
data_torch = data_torch.reshape((n_group, d_inner // n_group))
4557+
4558+
if name.endswith(".A_log"):
4559+
logger.debug("A_log --> A ==> " + new_name)
4560+
data_torch = -torch.exp(data_torch)
4561+
4562+
yield (new_name, data_torch)
4563+
4564+
44634565
@ModelBase.register("CohereForCausalLM")
44644566
class CommandR2Model(TextModel):
44654567
model_arch = gguf.MODEL_ARCH.COMMAND_R
@@ -6226,12 +6328,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
62266328
# maybe we should fallback to text model's arch in that case, since not many models have both
62276329
text_config = hparams.get("text_config", {})
62286330
vision_config = hparams.get("vision_config", {})
6229-
arch = hparams["architectures"][0]
6331+
arch = None
6332+
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
6333+
arch = arches[0]
6334+
elif "ssm_cfg" in hparams:
6335+
# For non-hf Mamba and Mamba2 models
6336+
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
6337+
62306338
# if "architectures" is found in the sub-config, use that instead
62316339
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
62326340
arch = text_config["architectures"][0]
62336341
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
62346342
arch = vision_config["architectures"][0]
6343+
if arch is None:
6344+
raise ValueError("Failed to detect model architecture")
62356345
return arch
62366346

62376347

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1869,7 +1869,8 @@ extern "C" {
18691869
struct ggml_tensor * dt,
18701870
struct ggml_tensor * A,
18711871
struct ggml_tensor * B,
1872-
struct ggml_tensor * C);
1872+
struct ggml_tensor * C,
1873+
struct ggml_tensor * ids);
18731874

18741875
// partition into non-overlapping windows with padding if needed
18751876
// example:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 128 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7596,74 +7596,151 @@ void ggml_compute_forward_ssm_conv(
75967596
static void ggml_compute_forward_ssm_scan_f32(
75977597
const ggml_compute_params * params,
75987598
ggml_tensor * dst) {
7599-
const ggml_tensor * src0 = dst->src[0]; // s
7600-
const ggml_tensor * src1 = dst->src[1]; // x
7601-
const ggml_tensor * src2 = dst->src[2]; // dt
7602-
const ggml_tensor * src3 = dst->src[3]; // A
7603-
const ggml_tensor * src4 = dst->src[4]; // B
7604-
const ggml_tensor * src5 = dst->src[5]; // C
7599+
const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
7600+
const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
7601+
const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
7602+
const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
7603+
const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
7604+
const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
7605+
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
76057606

76067607
const int ith = params->ith;
76077608
const int nth = params->nth;
76087609

7609-
const int64_t nc = src0->ne[0]; // d_state
7610-
const int64_t nr = src0->ne[1]; // d_inner
7611-
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
7612-
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
7610+
const int64_t nc = src0->ne[0]; // d_state
7611+
const int64_t nr = src0->ne[1]; // dim
7612+
const int64_t nh = src1->ne[1]; // n_head
7613+
const int64_t ng = src4->ne[1];
7614+
const int64_t nt = src1->ne[2]; // number of tokens per sequence
7615+
const int64_t ns = src1->ne[3]; // number of sequences in the batch
76137616

7614-
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
7617+
// can't use ggml_nbytes because src1 is not necessarily contiguous
7618+
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
7619+
7620+
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
76157621
GGML_ASSERT(src0->nb[0] == sizeof(float));
76167622
GGML_ASSERT(src1->nb[0] == sizeof(float));
76177623
GGML_ASSERT(src2->nb[0] == sizeof(float));
76187624
GGML_ASSERT(src3->nb[0] == sizeof(float));
76197625
GGML_ASSERT(src4->nb[0] == sizeof(float));
76207626
GGML_ASSERT(src5->nb[0] == sizeof(float));
7621-
// required for the dot product between s and C
7622-
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
7623-
// required for per-sequence offsets for states
7624-
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
7625-
// required to get correct offset for state destination (i.e. src1->nb[3])
7626-
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
7627+
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
7628+
// allows optimizing the modulo since n_group should be a power of 2
7629+
GGML_ASSERT((ng & -ng) == ng);
7630+
7631+
// heads per thread
7632+
const int dh = (nh + nth - 1)/nth;
7633+
7634+
// head range for this thread
7635+
const int ih0 = dh*ith;
7636+
const int ih1 = MIN(ih0 + dh, nh);
7637+
7638+
const int32_t * ids = (const int32_t *) src6->data;
7639+
7640+
for (int i3 = 0; i3 < ns; ++i3) {
7641+
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
7642+
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
7643+
7644+
for (int i2 = 0; i2 < nt; ++i2) {
7645+
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
7646+
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
7647+
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
7648+
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
7649+
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
7650+
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
7651+
7652+
if (src3->ne[0] == 1) {
7653+
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
7654+
7655+
// n_head
7656+
for (int h = ih0; h < ih1; ++h) {
7657+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
7658+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
7659+
const float dA = expf(dt_soft_plus * A[h]);
7660+
7661+
// dim
7662+
for (int i1 = 0; i1 < nr; ++i1) {
7663+
const int ii = i1 + h*nr;
7664+
const float x_dt = x[ii] * dt_soft_plus;
7665+
float sumf = 0.0f;
7666+
#if defined(GGML_SIMD)
7667+
const int np = (nc & ~(GGML_F32_STEP - 1));
76277668

7628-
// rows per thread
7629-
const int dr = (nr + nth - 1)/nth;
7669+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
76307670

7631-
// row range for this thread
7632-
const int ir0 = dr*ith;
7633-
const int ir1 = MIN(ir0 + dr, nr);
7634-
const int ir = ir1 - ir0;
7671+
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
7672+
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
76357673

7636-
for (int i3 = 0; i3 < n_s; ++i3) {
7637-
for (int i2 = 0; i2 < n_t; ++i2) {
7638-
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7639-
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7640-
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7641-
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7642-
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7643-
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7644-
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7645-
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7646-
7647-
// use the output as the source for the next token-wise iterations
7648-
if (i2 > 0) { s0 = s; }
7674+
GGML_F32_VEC ax[GGML_F32_ARR];
7675+
GGML_F32_VEC ay[GGML_F32_ARR];
7676+
GGML_F32_VEC az[GGML_F32_ARR];
76497677

7650-
// d_inner
7651-
for (int i1 = 0; i1 < ir; ++i1) {
7652-
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7653-
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7654-
float x_dt = x[i1] * dt_soft_plus;
7655-
float sumf = 0.0f;
7656-
// d_state
7657-
for (int i0 = 0; i0 < nc; ++i0) {
7658-
int i = i0 + i1*nc;
7659-
// state = prev_state * dA + dB * x
7660-
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7661-
// y = rowwise_dotprod(state, C)
7662-
sumf += state * C[i0];
7663-
s[i] = state;
7678+
for (int i = 0; i < np; i += GGML_F32_STEP) {
7679+
for (int j = 0; j < GGML_F32_ARR; j++) {
7680+
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
7681+
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
7682+
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
7683+
7684+
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
7685+
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
7686+
7687+
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
7688+
7689+
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
7690+
7691+
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
7692+
}
7693+
}
7694+
7695+
// reduce sum0..sum3 to sum0
7696+
GGML_F32_VEC_REDUCE(sumf, sum);
7697+
#else
7698+
const int np = 0;
7699+
#endif
7700+
// d_state
7701+
for (int i0 = np; i0 < nc; ++i0) {
7702+
const int i = i0 + ii*nc;
7703+
const int ig = i0 + (h & (ng - 1))*nc;
7704+
// state = prev_state * dA + dB * x
7705+
const float state = (s0[i] * dA) + (B[ig] * x_dt);
7706+
// y = rowwise_dotprod(state, C)
7707+
sumf += state * C[ig];
7708+
s[i] = state;
7709+
}
7710+
y[ii] = sumf;
7711+
}
7712+
}
7713+
} else {
7714+
// Mamba-1 has an element-wise decay factor for the states
7715+
7716+
// n_head
7717+
for (int h = ih0; h < ih1; ++h) {
7718+
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
7719+
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
7720+
7721+
// dim
7722+
for (int i1 = 0; i1 < nr; ++i1) {
7723+
const int ii = i1 + h*nr;
7724+
const float x_dt = x[ii] * dt_soft_plus;
7725+
float sumf = 0.0f;
7726+
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
7727+
// and also because expf is used within the loop.
7728+
// d_state
7729+
for (int i0 = 0; i0 < nc; ++i0) {
7730+
const int i = i0 + ii*nc;
7731+
const int ig = i0 + (h & (ng - 1))*nc;
7732+
// state = prev_state * dA + dB * x
7733+
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
7734+
// y = rowwise_dotprod(state, C)
7735+
sumf += state * C[ig];
7736+
s[i] = state;
7737+
}
7738+
y[ii] = sumf;
7739+
}
76647740
}
7665-
y[i1] = sumf;
76667741
}
7742+
// use the output as the source when it's not the first token-wise iteration
7743+
s0 = s;
76677744
}
76687745
}
76697746
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -488,26 +488,25 @@ typedef struct {
488488
typedef struct {
489489
int64_t d_state;
490490
int64_t d_inner;
491+
int64_t n_head;
492+
int64_t n_group;
491493
int64_t n_seq_tokens;
492494
int64_t n_seqs;
493-
uint64_t nb00;
494495
uint64_t nb01;
495496
uint64_t nb02;
496-
uint64_t nb10;
497+
uint64_t nb03;
497498
uint64_t nb11;
498499
uint64_t nb12;
499500
uint64_t nb13;
500-
uint64_t nb20;
501501
uint64_t nb21;
502502
uint64_t nb22;
503-
uint64_t nb30;
504503
uint64_t nb31;
505-
uint64_t nb40;
506504
uint64_t nb41;
507505
uint64_t nb42;
508-
uint64_t nb50;
506+
uint64_t nb43;
509507
uint64_t nb51;
510508
uint64_t nb52;
509+
uint64_t nb53;
511510
} ggml_metal_kargs_ssm_scan;
512511

513512
typedef struct {

0 commit comments

Comments
 (0)