
Commit 8a84e59

Merge branch 'mlc-ai:main' into main
2 parents 712f1d5 + ab946b8 commit 8a84e59

File tree

7 files changed: +859 −2 lines changed


3rdparty/tvm

Submodule tvm updated 123 files

python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def create_flashinfer_paged_kv_cache(
         if (  # pylint: disable=too-many-boolean-expressions
             not self.flashinfer
             or self.target.kind.name != "cuda"
-            or str(kwargs["dtype"]) not in ["float16"]
+            or str(kwargs["dtype"]) not in ["float16", "bfloat16"]
             or (
                 kwargs["rope_mode"] == RopeMode.INLINE
                 and (
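
The practical effect of this one-line change is that the FlashInfer paged KV cache path now accepts bfloat16 models in addition to float16 ones. A minimal sketch of the gating logic is below; the helper name is hypothetical and not part of the source (the real check is an inline boolean expression inside create_flashinfer_paged_kv_cache):

# Illustrative only: the actual check is written inline in
# create_flashinfer_paged_kv_cache; this standalone helper is hypothetical.
def _flashinfer_kv_cache_supported(flashinfer_enabled: bool, target_kind: str, dtype: str) -> bool:
    """Return True when the FlashInfer paged KV cache can be dispatched."""
    if not flashinfer_enabled:
        return False
    if target_kind != "cuda":  # FlashInfer kernels are CUDA-only
        return False
    # Before this commit only "float16" passed; "bfloat16" is now accepted too.
    return dtype in ("float16", "bfloat16")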

python/mlc_llm/model/gemma3/__init__.py

Whitespace-only changes.

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
"""
This file specifies how MLC's Gemma3 parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .gemma3_model import Gemma3Config, Gemma3ForCausalLM


def huggingface(model_config: Gemma3Config, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : Gemma3Config
        The configuration of the Gemma model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = Gemma3ForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    mlc_prefix = "language_model."
    hf_prefix = "language_model." if not model_config.is_text_model else ""
    for i in range(model_config.text_config.num_hidden_layers):
        # Add gates in MLP
        mlp = f"model.layers.{i}.mlp"
        mlc_name = f"{mlc_prefix + mlp}.gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{hf_prefix + mlp}.gate_proj.weight",
                f"{hf_prefix + mlp}.up_proj.weight",
            ],
            functools.partial(
                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )
        # Modify RMS layernorm weights, since the Gemma model adds 1 to the weights.
        # We add 1 to the weights here for efficiency purposes.
        mlc_name = f"model.layers.{i}.input_layernorm.weight"
        mlc_param = named_parameters[mlc_prefix + mlc_name]
        mapping.add_mapping(
            mlc_prefix + mlc_name,
            [hf_prefix + mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
            ),
        )

        mlc_name = f"model.layers.{i}.post_attention_layernorm.weight"
        mlc_param = named_parameters[mlc_prefix + mlc_name]
        mapping.add_mapping(
            mlc_prefix + mlc_name,
            [hf_prefix + mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
            ),
        )

        mlc_name = f"model.layers.{i}.pre_feedforward_layernorm.weight"
        mlc_param = named_parameters[mlc_prefix + mlc_name]
        mapping.add_mapping(
            mlc_prefix + mlc_name,
            [hf_prefix + mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
            ),
        )

        mlc_name = f"model.layers.{i}.post_feedforward_layernorm.weight"
        mlc_param = named_parameters[mlc_prefix + mlc_name]
        mapping.add_mapping(
            mlc_prefix + mlc_name,
            [hf_prefix + mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
            ),
        )

        mlc_name = f"model.layers.{i}.self_attn.k_norm.weight"
        mlc_param = named_parameters[mlc_prefix + mlc_name]
        mapping.add_mapping(
            mlc_prefix + mlc_name,
            [hf_prefix + mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
            ),
        )

        mlc_name = f"model.layers.{i}.self_attn.q_norm.weight"
        mlc_param = named_parameters[mlc_prefix + mlc_name]
        mapping.add_mapping(
            mlc_prefix + mlc_name,
            [hf_prefix + mlc_name],
            functools.partial(
                lambda x, dtype: (x + 1).astype(dtype),
                dtype=named_parameters[mlc_prefix + mlc_name].dtype,
            ),
        )

    mlc_name = "model.norm.weight"
    mlc_param = named_parameters[mlc_prefix + mlc_name]
    mapping.add_mapping(
        mlc_prefix + mlc_name,
        [hf_prefix + mlc_name],
        functools.partial(
            lambda x, dtype: (x + 1).astype(dtype),
            dtype=named_parameters[mlc_prefix + mlc_name].dtype,
        ),
    )

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [hf_prefix + mlc_name[len(mlc_prefix) :]],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
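
Two weight transformations in this mapping are worth noting: gate_proj and up_proj are fused into a single gate_up_proj tensor by concatenating along the output dimension, and every Gemma-style RMSNorm weight is stored with 1 already added so inference kernels can skip the offset. A small NumPy sketch of what the converter lambdas do (the shapes below are made up for illustration; real Gemma3 tensors are much larger):

import numpy as np

gate = np.random.rand(4, 8).astype("float32")   # stand-in for mlp.gate_proj.weight
up = np.random.rand(4, 8).astype("float32")     # stand-in for mlp.up_proj.weight
norm = np.random.rand(8).astype("float32")      # stand-in for an RMSNorm weight as stored by HF

# gate_up_proj.weight: gate and up stacked along axis 0, cast to the MLC dtype.
gate_up = np.concatenate([gate, up], axis=0).astype("float16")
assert gate_up.shape == (8, 8)

# RMSNorm weights: HF stores w, MLC stores (w + 1), so the kernel computes w' * x
# instead of (w + 1) * x at every step.
norm_mlc = (norm + 1).astype("float16")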
