
Commit 0838286

ngxson and Vaibhavs10 authored
model : add SmolLM3 (#14581)
* Init - first pass.
* Model -> ModelBase.
* fix errors in conversion.
* Update the graph.
* up.
* up.
* wip
* cgraph ok
* rm redundant code

---------

Co-authored-by: Vaibhavs10 <vaibhavs10@gmail.com>
1 parent bb4f7a9 commit 0838286

6 files changed: +233 additions, -9 deletions

convert_hf_to_gguf.py

Lines changed: 5 additions & 0 deletions
@@ -6687,6 +6687,11 @@ def prepare_tensors(self):
         if len(experts) > 0:
             raise ValueError(f"Unprocessed experts: {experts}")
 
+
+@ModelBase.register("SmolLM3ForCausalLM")
+class SmolLM3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.SMOLLM3
+
 
 ###### CONVERSION LOGIC ######
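The conversion change is just a registration: the Hugging Face `architectures` entry "SmolLM3ForCausalLM" is mapped to a converter class that inherits everything from `LlamaModel` and only overrides the target GGUF architecture. Below is a minimal, self-contained sketch of that decorator-registry pattern; it is illustrative only (the real `ModelBase.register` in `convert_hf_to_gguf.py` does more, and `_model_classes` here is a made-up stand-in).

# Illustrative sketch of the registration pattern, not the actual ModelBase code.
_model_classes: dict[str, type] = {}

def register(*architectures: str):
    """Map one or more HF `architectures` strings to a converter class."""
    def decorator(cls: type) -> type:
        for arch in architectures:
            _model_classes[arch] = cls
        return cls
    return decorator

class LlamaModel:
    model_arch = "llama"  # stands in for gguf.MODEL_ARCH.LLAMA

@register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
    # only the target GGUF architecture changes; the tensor mapping and
    # conversion logic are inherited from LlamaModel
    model_arch = "smollm3"

# the converter reads "architectures" from config.json and dispatches on it:
print(_model_classes["SmolLM3ForCausalLM"].model_arch)  # smollm3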

docs/development/HOWTO-add-model.md

Lines changed: 11 additions & 9 deletions
@@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv
 
 ### 2. Define the model architecture in `llama.cpp`
 
-The model params and tensors layout must be defined in `llama.cpp`:
-1. Define a new `llm_arch`
-2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non-standard metadata in `llm_load_hparams`
-4. Create the tensors for inference in `llm_load_tensors`
-5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
+The model params and tensors layout must be defined in `llama.cpp` source files:
+1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
+2. In `src/llama-arch.cpp`:
+    - Add the architecture name to the `LLM_ARCH_NAMES` map.
+    - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
+3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
+4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.
 
 NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
 
 ### 3. Build the GGML graph implementation
 
-This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
-
-Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
+This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
+Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
+Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
+Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.
 
 Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
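The NOTE about dimension order in the doc above is worth keeping in mind when reading the `create_tensor` calls later in this commit: ggml lists the fastest-varying dimension first, so shapes appear reversed relative to PyTorch. A tiny sketch, with hypothetical SmolLM3-like sizes:

# illustrates the NOTE above: a PyTorch weight of shape (n_vocab, n_embd)
# shows up as {n_embd, n_vocab} in ggml (illustrative sizes only)
n_vocab, n_embd = 128_256, 2_048          # hypothetical SmolLM3-like sizes
pytorch_shape = (n_vocab, n_embd)         # torch: (rows, cols)
ggml_ne = tuple(reversed(pytorch_shape))  # ggml: ne[0] is the innermost dim
print(pytorch_shape, "->", ggml_ne)       # (128256, 2048) -> (2048, 128256)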

gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions
@@ -358,6 +358,7 @@ class MODEL_ARCH(IntEnum):
     ARCEE = auto()
     ERNIE4_5 = auto()
     HUNYUAN_MOE = auto()
+    SMOLLM3 = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):

@@ -662,6 +663,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.ARCEE: "arcee",
     MODEL_ARCH.ERNIE4_5: "ernie4_5",
     MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
+    MODEL_ARCH.SMOLLM3: "smollm3",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {

@@ -2234,6 +2236,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.SMOLLM3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }

src/llama-arch.cpp

Lines changed: 18 additions & 0 deletions
@@ -79,6 +79,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_ARCEE,       "arcee"       },
     { LLM_ARCH_ERNIE4_5,    "ernie4_5"    },
     { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
+    { LLM_ARCH_SMOLLM3,     "smollm3"     },
     { LLM_ARCH_UNKNOWN,     "(unknown)"   },
 };
 

@@ -1724,6 +1725,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
         },
     },
+    {
+        LLM_ARCH_SMOLLM3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT,      "output" },
+            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
+        },
+    },
 };
 
 static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
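The per-layer entries registered above are printf-style templates: the `%d` receives the layer index and a suffix such as `.weight` is appended when the tensor is created. A quick illustration of the expansion (plain Python, only to show the naming; the actual formatting happens in llama.cpp's tensor-name helper):

# how a per-layer template expands into the tensor names stored in the GGUF file
template = "blk.%d.attn_q"
for il in (0, 1, 35):
    print(template % il + ".weight")
# blk.0.attn_q.weight
# blk.1.attn_q.weight
# blk.35.attn_q.weight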

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ enum llm_arch {
    LLM_ARCH_ARCEE,
    LLM_ARCH_ERNIE4_5,
    LLM_ARCH_HUNYUAN_MOE,
+   LLM_ARCH_SMOLLM3,
    LLM_ARCH_UNKNOWN,
 };

src/llama-model.cpp

Lines changed: 180 additions & 0 deletions
@@ -1561,6 +1561,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                hparams.n_no_rope_layer_step = 4;
+
+                switch (hparams.n_layer) {
+                    case 36: type = LLM_TYPE_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
@@ -4524,6 +4534,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_SMOLLM3:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -14846,6 +14885,142 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
         cb(cur, "result_norm", -1);
         res->t_embd = cur;
 
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_smollm3 : public llm_graph_context {
+    llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (use_rope) {
+                    Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+                    Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
         // lm_head
         cur = build_lora_mm(model.output, cur);
@@ -15240,6 +15415,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
             } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -15391,6 +15570,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
         case LLM_ARCH_NEO_BERT:
+        case LLM_ARCH_SMOLLM3:
         case LLM_ARCH_ARCEE:
         case LLM_ARCH_ERNIE4_5:
             return LLAMA_ROPE_TYPE_NORM;
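One detail worth calling out: `load_hparams` sets `hparams.n_no_rope_layer_step = 4`, and the graph skips RoPE wherever `(il + 1) % n_no_rope_layer_step == 0`, so every fourth layer runs without rotary position embeddings (a NoPE-style interleaving). A small sketch of the resulting pattern for the 36-layer config, mirroring the condition in `llm_build_smollm3`:

# which layers apply RoPE for SmolLM3, per the use_rope condition above
n_layer = 36               # LLM_TYPE_3B per the load_hparams change
n_no_rope_layer_step = 4   # set in load_hparams for LLM_ARCH_SMOLLM3

nope_layers = [il for il in range(n_layer) if (il + 1) % n_no_rope_layer_step == 0]
rope_layers = [il for il in range(n_layer) if (il + 1) % n_no_rope_layer_step != 0]

print(nope_layers)        # [3, 7, 11, 15, 19, 23, 27, 31, 35] -> 9 layers without RoPE
print(len(rope_layers))   # 27 layers with RoPE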

0 commit comments
