
Commit 547740a

[Model] Gemma3 Support (#3172)
Adds model support for Gemma3 (work in progress). QK normalization has been added as described in Gemma3's technical documentation (a minimal sketch of this step is included below), but there is currently an issue with short/truncated generation and a significant slowdown after the first few prompts. Still missing are local/global attention layer interweaving (as in Gemma2 in MLC-LLM) and support for Gemma3's multimodal capabilities.
1 parent bd72d21 commit 547740a
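
For context, the following is a minimal, stand-alone sketch of the QK-normalization step mentioned in the commit message: the query and key projections are RMS-normalized per head before rotary embeddings and attention scores are computed. The helper name rms_norm, the toy shapes, and the zero-initialized norm weights are illustrative assumptions, not the MLC-LLM implementation.

import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # RMS-normalize over the last (head) dimension, then scale by (1 + weight),
    # following the Gemma convention of storing the scale as an offset from 1.
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * (1.0 + weight)

# Toy shapes: [batch, seq_len, num_heads, head_dim] (illustrative only).
batch, seq_len, num_heads, head_dim = 1, 4, 2, 8
q = np.random.randn(batch, seq_len, num_heads, head_dim).astype(np.float32)
k = np.random.randn(batch, seq_len, num_heads, head_dim).astype(np.float32)
q_norm_weight = np.zeros(head_dim, dtype=np.float32)  # per-dim scale offsets (zeros = identity scale)
k_norm_weight = np.zeros(head_dim, dtype=np.float32)

q = rms_norm(q, q_norm_weight)  # q_norm applied to queries
k = rms_norm(k, k_norm_weight)  # k_norm applied to keys
# ...rotary position embeddings and scaled dot-product attention would follow.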

File tree

5 files changed: +841 −0 lines changed


python/mlc_llm/model/gemma3/__init__.py

Whitespace-only changes.
Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
"""
This file specifies how MLC's Gemma3 parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .gemma3_model import Gemma3Config, Gemma3ForCausalLM


def huggingface(model_config: Gemma3Config, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : Gemma3Config
        The configuration of the Gemma model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = Gemma3ForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.text_config.num_hidden_layers):
        # Add gates in MLP
        mlp = f"language_model.model.layers.{i}.mlp"
        mlc_name = f"{mlp}.gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{mlp}.gate_proj.weight",
                f"{mlp}.up_proj.weight",
            ],
            functools.partial(
                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )
        # Modify RMS layernorm weights, since Gemma model adds 1 to the weights
        # We add 1 to the weights here for efficiency purpose
        mlc_name = f"language_model.model.layers.{i}.input_layernorm.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_name].dtype,
            ),
        )

        mlc_name = f"language_model.model.layers.{i}.post_attention_layernorm.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_name].dtype,
            ),
        )

        mlc_name = f"language_model.model.layers.{i}.pre_feedforward_layernorm.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_name].dtype,
            ),
        )

        mlc_name = f"language_model.model.layers.{i}.post_feedforward_layernorm.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_name].dtype,
            ),
        )

        mlc_name = f"language_model.model.layers.{i}.self_attn.k_norm.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_name].dtype,
            ),
        )

        mlc_name = f"language_model.model.layers.{i}.self_attn.q_norm.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_name].dtype,
            ),
        )

    mlc_name = "language_model.model.norm.weight"
    mlc_param = named_parameters[mlc_name]
    mapping.add_mapping(
        mlc_name,
        [mlc_name],
        functools.partial(
            lambda x, dtype: (x + 1).astype(dtype),
            dtype=named_parameters[mlc_name].dtype,
        ),
    )

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
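
To make the two non-trivial transforms above concrete, here is a small stand-alone NumPy sketch (toy shapes and values, not the real Gemma3 dimensions) of what the registered mapping functions do at weight-conversion time: fusing gate_proj and up_proj into one gate_up_proj matrix, and folding the Gemma RMSNorm convention of scaling by (1 + weight) into the stored weight.

import numpy as np

# Toy dimensions; the real values come from the model config.
hidden_size, intermediate_size = 4, 8
dtype = np.float16

# 1) MLP gate/up fusion: HuggingFace stores gate_proj and up_proj separately,
#    while MLC keeps one fused [2 * intermediate_size, hidden_size] matrix.
gate = np.random.randn(intermediate_size, hidden_size).astype(np.float32)
up = np.random.randn(intermediate_size, hidden_size).astype(np.float32)
gate_up = np.concatenate([gate, up], axis=0).astype(dtype)
assert gate_up.shape == (2 * intermediate_size, hidden_size)

# 2) RMSNorm weight folding: HuggingFace Gemma parameterizes the norm scale as
#    an offset from 1 (the layer computes x_hat * (1 + w)), so adding 1 once at
#    conversion time lets the runtime use a plain x_hat * w multiply.
hf_norm_weight = np.zeros(hidden_size, dtype=np.float32)   # identity scale in HF form
mlc_norm_weight = (hf_norm_weight + 1).astype(dtype)       # stored as all-ones for MLC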
