Commit cdc2c80

[Model] Update Gemma3 to support 1b variant (#3178)
This PR updates the Gemma3 weight loader implementation to support the 1b variant.
1 parent a1fa8be commit cdc2c80

File tree

4 files changed: +62 additions, -46 deletions


python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def create_flashinfer_paged_kv_cache(
         if (  # pylint: disable=too-many-boolean-expressions
             not self.flashinfer
             or self.target.kind.name != "cuda"
-            or str(kwargs["dtype"]) not in ["float16"]
+            or str(kwargs["dtype"]) not in ["float16", "bfloat16"]
             or (
                 kwargs["rope_mode"] == RopeMode.INLINE
                 and (
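
Note (not part of the diff): the guard above decides whether the FlashInfer paged KV cache is dispatched; with this change, bfloat16 joins float16 as an accepted KV-cache dtype on CUDA targets. A minimal standalone sketch of the same check, with hypothetical argument names standing in for the pass's kwargs:

# Sketch only: mirrors the updated condition, not the actual compiler-pass API.
SUPPORTED_FLASHINFER_KV_DTYPES = ("float16", "bfloat16")  # "bfloat16" is newly allowed

def can_dispatch_flashinfer(flashinfer_enabled: bool, target_kind: str, dtype: str) -> bool:
    # FlashInfer is only used on CUDA targets with a supported half-precision dtype.
    return (
        flashinfer_enabled
        and target_kind == "cuda"
        and dtype in SUPPORTED_FLASHINFER_KV_DTYPES
    )

assert can_dispatch_flashinfer(True, "cuda", "bfloat16")
assert not can_dispatch_flashinfer(True, "cuda", "float32")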

python/mlc_llm/model/gemma3/gemma3_loader.py

Lines changed: 42 additions & 40 deletions
@@ -41,16 +41,18 @@ def huggingface(model_config: Gemma3Config, quantization: Quantization) -> Exter
 
     mapping = ExternMapping()
 
+    mlc_prefix = "language_model."
+    hf_prefix = "language_model." if not model_config.is_text_model else ""
     for i in range(model_config.text_config.num_hidden_layers):
         # Add gates in MLP
-        mlp = f"language_model.model.layers.{i}.mlp"
-        mlc_name = f"{mlp}.gate_up_proj.weight"
+        mlp = f"model.layers.{i}.mlp"
+        mlc_name = f"{mlc_prefix + mlp}.gate_up_proj.weight"
         mlc_param = named_parameters[mlc_name]
         mapping.add_mapping(
             mlc_name,
             [
-                f"{mlp}.gate_proj.weight",
-                f"{mlp}.up_proj.weight",
+                f"{hf_prefix + mlp}.gate_proj.weight",
+                f"{hf_prefix + mlp}.up_proj.weight",
             ],
             functools.partial(
                 lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
@@ -59,88 +61,88 @@ def huggingface(model_config: Gemma3Config, quantization: Quantization) -> Exter
         )
         # Modify RMS layernorm weights, since Gemma model adds 1 to the weights
        # We add 1 to the weights here for efficiency purpose
-        mlc_name = f"language_model.model.layers.{i}.input_layernorm.weight"
-        mlc_param = named_parameters[mlc_name]
+        mlc_name = f"model.layers.{i}.input_layernorm.weight"
+        mlc_param = named_parameters[mlc_prefix + mlc_name]
         mapping.add_mapping(
-            mlc_name,
-            [mlc_name],
+            mlc_prefix + mlc_name,
+            [hf_prefix + mlc_name],
             functools.partial(
                 lambda x, dtype: (x + 1).astype(dtype),
-                dtype=named_parameters[mlc_name].dtype,
+                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
             ),
         )
 
-        mlc_name = f"language_model.model.layers.{i}.post_attention_layernorm.weight"
-        mlc_param = named_parameters[mlc_name]
+        mlc_name = f"model.layers.{i}.post_attention_layernorm.weight"
+        mlc_param = named_parameters[mlc_prefix + mlc_name]
         mapping.add_mapping(
-            mlc_name,
-            [mlc_name],
+            mlc_prefix + mlc_name,
+            [hf_prefix + mlc_name],
             functools.partial(
                 lambda x, dtype: (x + 1).astype(dtype),
-                dtype=named_parameters[mlc_name].dtype,
+                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
             ),
         )
 
-        mlc_name = f"language_model.model.layers.{i}.pre_feedforward_layernorm.weight"
-        mlc_param = named_parameters[mlc_name]
+        mlc_name = f"model.layers.{i}.pre_feedforward_layernorm.weight"
+        mlc_param = named_parameters[mlc_prefix + mlc_name]
         mapping.add_mapping(
-            mlc_name,
-            [mlc_name],
+            mlc_prefix + mlc_name,
+            [hf_prefix + mlc_name],
             functools.partial(
                 lambda x, dtype: (x + 1).astype(dtype),
-                dtype=named_parameters[mlc_name].dtype,
+                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
             ),
         )
 
-        mlc_name = f"language_model.model.layers.{i}.post_feedforward_layernorm.weight"
-        mlc_param = named_parameters[mlc_name]
+        mlc_name = f"model.layers.{i}.post_feedforward_layernorm.weight"
+        mlc_param = named_parameters[mlc_prefix + mlc_name]
         mapping.add_mapping(
-            mlc_name,
-            [mlc_name],
+            mlc_prefix + mlc_name,
+            [hf_prefix + mlc_name],
             functools.partial(
                 lambda x, dtype: (x + 1).astype(dtype),
-                dtype=named_parameters[mlc_name].dtype,
+                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
             ),
         )
 
-        mlc_name = f"language_model.model.layers.{i}.self_attn.k_norm.weight"
-        mlc_param = named_parameters[mlc_name]
+        mlc_name = f"model.layers.{i}.self_attn.k_norm.weight"
+        mlc_param = named_parameters[mlc_prefix + mlc_name]
         mapping.add_mapping(
-            mlc_name,
-            [mlc_name],
+            mlc_prefix + mlc_name,
+            [hf_prefix + mlc_name],
             functools.partial(
                 lambda x, dtype: (x + 1).astype(dtype),
-                dtype=named_parameters[mlc_name].dtype,
+                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
            ),
        )
 
-        mlc_name = f"language_model.model.layers.{i}.self_attn.q_norm.weight"
-        mlc_param = named_parameters[mlc_name]
+        mlc_name = f"model.layers.{i}.self_attn.q_norm.weight"
+        mlc_param = named_parameters[mlc_prefix + mlc_name]
         mapping.add_mapping(
-            mlc_name,
-            [mlc_name],
+            mlc_prefix + mlc_name,
+            [hf_prefix + mlc_name],
             functools.partial(
                 lambda x, dtype: (x + 1).astype(dtype),
-                dtype=named_parameters[mlc_name].dtype,
+                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
             ),
         )
 
-    mlc_name = "language_model.model.norm.weight"
-    mlc_param = named_parameters[mlc_name]
+    mlc_name = "model.norm.weight"
+    mlc_param = named_parameters[mlc_prefix + mlc_name]
     mapping.add_mapping(
-        mlc_name,
-        [mlc_name],
+        mlc_prefix + mlc_name,
+        [hf_prefix + mlc_name],
         functools.partial(
             lambda x, dtype: (x + 1).astype(dtype),
-            dtype=named_parameters[mlc_name].dtype,
+            dtype=named_parameters[mlc_prefix + mlc_name].dtype,
         ),
     )
 
     for mlc_name, mlc_param in named_parameters.items():
         if mlc_name not in mapping.param_map:
             mapping.add_mapping(
                 mlc_name,
-                [mlc_name],
+                [hf_prefix + mlc_name[len(mlc_prefix) :]],
                 functools.partial(
                     lambda x, dtype: x.astype(dtype),
                     dtype=mlc_param.dtype,
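
Note (not part of the diff): the loader now distinguishes the MLC-side parameter prefix from the HuggingFace-side prefix, so the multimodal checkpoints (which nest the decoder under language_model.) and the text-only 1b checkpoint (which does not) both map correctly. A small self-contained sketch of that name mapping, using only the prefix convention introduced above:

# Illustrative sketch of the prefix handling; not the loader's actual API.
MLC_PREFIX = "language_model."

def hf_source_name(mlc_name: str, is_text_model: bool) -> str:
    # MLC parameter names always carry the "language_model." prefix; the
    # HuggingFace name keeps it only for the multimodal (vision + text) variant.
    hf_prefix = "" if is_text_model else "language_model."
    return hf_prefix + mlc_name[len(MLC_PREFIX):]

print(hf_source_name("language_model.model.norm.weight", is_text_model=True))
# -> model.norm.weight                 (text-only 1b checkpoint)
print(hf_source_name("language_model.model.norm.weight", is_text_model=False))
# -> language_model.model.norm.weight  (multimodal checkpoint)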

python/mlc_llm/model/gemma3/gemma3_model.py

Lines changed: 6 additions & 5 deletions
@@ -95,16 +95,21 @@ def __post_init__(self):
 class Gemma3Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
     """Configuration of the Gemma3 model"""
 
-    text_config: Gemma3TextConfig
+    text_config: Gemma3TextConfig = None
     vocab_size: int = 262_208
     tensor_parallel_shards: int = 1
     max_batch_size: int = 1
     context_window_size: int = -1
     sliding_window_size: int = -1
     prefill_chunk_size: int = -1
+    is_text_model: bool = False
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     def __post_init__(self):
+        if self.text_config is None:
+            self.is_text_model = True
+            self.text_config = Gemma3TextConfig.from_dict(self.kwargs)
+
         text_config_dict: Dict[str, Any]
         if isinstance(self.text_config, Gemma3TextConfig):
             text_config_dict = dataclasses.asdict(self.text_config)
@@ -121,10 +126,6 @@ def __post_init__(self):
             if hasattr(self.text_config, k):
                 setattr(self, k, getattr(self.text_config, k))
 
-        # if getattr(self, "sliding_window_size") <= 0:
-        #     if hasattr(self.text_config, "sliding_window"):
-        #         setattr(self, "sliding_window_size", getattr(self.text_config, "sliding_window"))
-
 
 # pylint: disable=invalid-name,missing-docstring
 
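
Note (not part of the diff): with the new default, a flat text-only config (as shipped with the 1b checkpoint) no longer needs a nested text_config block; __post_init__ rebuilds one from the remaining kwargs and flags the model as text-only. A simplified, self-contained sketch of that fallback using stand-in dataclasses:

# Simplified stand-ins; the real classes are Gemma3Config / Gemma3TextConfig.
import dataclasses
from typing import Any, Dict, Optional

@dataclasses.dataclass
class TextCfg:
    num_hidden_layers: int = 26

@dataclasses.dataclass
class Cfg:
    text_config: Optional[TextCfg] = None
    is_text_model: bool = False
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        # Text-only checkpoints have no nested text_config, so build it from kwargs.
        if self.text_config is None:
            self.is_text_model = True
            self.text_config = TextCfg(**self.kwargs)

print(Cfg(kwargs={"num_hidden_layers": 26}).is_text_model)  # True  (flat, 1b-style config)
print(Cfg(text_config=TextCfg()).is_text_model)             # False (nested, multimodal-style config)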

python/mlc_llm/model/model.py

Lines changed: 13 additions & 0 deletions
@@ -158,6 +158,19 @@ class Model:
             "group-quant": gemma3_quantization.group_quant,
         },
     ),
+    "gemma3_text": Model(
+        name="gemma3_text",
+        model=gemma3_model.Gemma3ForCausalLM,
+        config=gemma3_model.Gemma3Config,
+        source={
+            "huggingface-torch": gemma3_loader.huggingface,
+            "huggingface-safetensor": gemma3_loader.huggingface,
+        },
+        quantize={
+            "no-quant": gemma3_quantization.no_quant,
+            "group-quant": gemma3_quantization.group_quant,
+        },
+    ),
     "gpt2": Model(
         name="gpt2",
         model=gpt2_model.GPT2LMHeadModel,
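
Note (not part of the diff): registering "gemma3_text" means a checkpoint whose config reports that model type resolves to the same Gemma3ForCausalLM implementation and Gemma3Config used by the existing entry. A minimal sketch of this kind of registry dispatch (the registry and values below are illustrative stand-ins, not the real model table):

# Toy registry standing in for the model table this diff extends.
REGISTRY = {
    "gemma3": "Gemma3ForCausalLM (multimodal-style config)",
    "gemma3_text": "Gemma3ForCausalLM (text-only config)",  # entry added by this commit
}

def resolve_model(model_type: str) -> str:
    # Unknown model types are rejected up front instead of failing later in loading.
    if model_type not in REGISTRY:
        raise ValueError(f"Unsupported model_type: {model_type}")
    return REGISTRY[model_type]

print(resolve_model("gemma3_text"))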
