
Commit 46f9bc4

[Model] Support Phi-4 (#3268)
This PR adds support for Phi-4 using the existing Phi3 architecture. It updates the conversation template and model preset for Phi-4-mini-instruct, adds support for tie_word_embeddings and partial_rotary_factor, and fixes the rotary_dim usage in position_embedding.py, sizing the ext_factors buffer as rotary_dim // 2 instead of head_dim // 2.
1 parent 2ecc5f5 commit 46f9bc4

File tree

6 files changed: +196 −16 lines

python/mlc_llm/conversation_template/phi.py

Lines changed: 17 additions & 0 deletions
@@ -51,3 +51,20 @@
         stop_token_ids=[2, 32000, 32001, 32007],
     )
 )
+
+# Phi-4
+ConvTemplateRegistry.register_conv_template(
+    Conversation(
+        name="phi-4",
+        system_template=f"<|system|>\n{MessagePlaceholders.SYSTEM.value}",
+        system_message="You are a helpful digital assistant. Please provide safe, "
+        "ethical and accurate information to the user.",
+        roles={"user": "<|user|>", "assistant": "<|assistant|>"},
+        seps=["<|end|>\n"],
+        role_content_sep="\n",
+        role_empty_sep="\n",
+        system_prefix_token_ids=[200022],  # <|system|>
+        stop_str=["<|endoftext|>", "<|end|>"],
+        stop_token_ids=[199999, 200020],  # <|endoftext|>, <|end|>
+    )
+)

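For orientation, the fields registered above determine the final prompt layout. The sketch below is a hand-rolled illustration of roughly how they combine for a single user turn; it is not mlc_llm's actual rendering logic (that lives in the Conversation class), and the helper name render_phi4_prompt is hypothetical.

# Rough illustration only: approximate Phi-4 prompt shape implied by the template fields.
SYSTEM_MESSAGE = (
    "You are a helpful digital assistant. Please provide safe, "
    "ethical and accurate information to the user."
)
SEP = "<|end|>\n"  # seps[0]


def render_phi4_prompt(user_message: str) -> str:
    """Assemble a one-turn prompt: system block, user turn, empty assistant header."""
    prompt = f"<|system|>\n{SYSTEM_MESSAGE}{SEP}"  # system_template + system_message
    prompt += f"<|user|>\n{user_message}{SEP}"     # roles["user"] + role_content_sep
    prompt += "<|assistant|>\n"                    # roles["assistant"] + role_empty_sep
    return prompt


print(render_phi4_prompt("What is MLC LLM?"))

Generation then stops on <|end|> or <|endoftext|> (token ids 200020 and 199999), matching the stop_str and stop_token_ids fields above.
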
python/mlc_llm/interface/gen_config.py

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     "phi-2",
     "phi-3",
     "phi-3-vision",
+    "phi-4",
     "stablelm-2",
     "gemma_instruction",
     "gemma3_instruction",

python/mlc_llm/model/model_preset.py

Lines changed: 142 additions & 0 deletions
@@ -757,6 +757,148 @@
         "vocab_size": 32064,
         "_attn_implementation": "flash_attention_2",
     },
+    "phi-4": {
+        "_name_or_path": "Phi-4-mini-instruct",
+        "architectures": ["Phi3ForCausalLM"],
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "auto_map": {
+            "AutoConfig": "configuration_phi3.Phi3Config",
+            "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
+            "AutoTokenizer": "Xenova/gpt-4o",
+        },
+        "bos_token_id": 199999,
+        "embd_pdrop": 0.0,
+        "eos_token_id": 199999,
+        "full_attn_mod": 1,
+        "hidden_act": "silu",
+        "hidden_size": 3072,
+        "initializer_range": 0.02,
+        "intermediate_size": 8192,
+        "interpolate_factor": 1,
+        "lm_head_bias": False,
+        "max_position_embeddings": 131072,
+        "mlp_bias": False,
+        "model_type": "phi3",
+        "num_attention_heads": 24,
+        "num_hidden_layers": 32,
+        "num_key_value_heads": 8,
+        "original_max_position_embeddings": 4096,
+        "pad_token_id": 199999,
+        "partial_rotary_factor": 0.75,
+        "resid_pdrop": 0.0,
+        "rms_norm_eps": 1e-05,
+        "rope_scaling": {
+            "long_factor": [
+                1,
+                1.118320672,
+                1.250641126,
+                1.398617824,
+                1.564103225,
+                1.74916897,
+                1.956131817,
+                2.187582649,
+                2.446418898,
+                2.735880826,
+                3.059592084,
+                3.421605075,
+                3.826451687,
+                4.279200023,
+                4.785517845,
+                5.351743533,
+                5.984965424,
+                6.693110555,
+                7.485043894,
+                8.370679318,
+                9.36110372,
+                10.4687158,
+                11.70738129,
+                13.09260651,
+                14.64173252,
+                16.37415215,
+                18.31155283,
+                20.47818807,
+                22.90118105,
+                25.61086418,
+                28.64115884,
+                32.03,
+                32.1,
+                32.13,
+                32.23,
+                32.6,
+                32.61,
+                32.64,
+                32.66,
+                32.7,
+                32.71,
+                32.93,
+                32.97,
+                33.28,
+                33.49,
+                33.5,
+                44.16,
+                47.77,
+            ],
+            "short_factor": [
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+                1.0,
+            ],
+            "type": "longrope",
+        },
+        "rope_theta": 10000.0,
+        "sliding_window": 262144,
+        "tie_word_embeddings": True,
+        "torch_dtype": "bfloat16",
+        "transformers_version": "4.45.0",
+        "use_cache": True,
+        "vocab_size": 200064,
+    },
     "qwen": {
         "architectures": ["QWenLMHeadModel"],
         "auto_map": {

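A quick consistency check on the numbers in this preset (plain Python over the values shown above, not mlc_llm code): head_dim follows from hidden_size and num_attention_heads, rotary_dim from partial_rotary_factor, and each longrope factor list needs one entry per rotary frequency.

# Derived quantities for the "phi-4" preset above.
hidden_size = 3072
num_attention_heads = 24
partial_rotary_factor = 0.75

head_dim = hidden_size // num_attention_heads       # 128
rotary_dim = int(head_dim * partial_rotary_factor)  # 96, passed to the paged KV cache

# longrope scales one factor per rotary frequency, i.e. rotary_dim // 2 of them,
# which matches the 48 entries in long_factor and short_factor.
assert head_dim == 128 and rotary_dim == 96 and rotary_dim // 2 == 48
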
python/mlc_llm/model/phi3/phi3_loader.py

Lines changed: 3 additions & 1 deletion
@@ -48,7 +48,9 @@ def _add(mlc_name, hf_name):
             ),
         )
 
-    _add("lm_head.weight", "lm_head.weight")
+    # Skip lm_head.weight if tie_word_embeddings is enabled
+    if not getattr(model_config, "tie_word_embeddings", False):
+        _add("lm_head.weight", "lm_head.weight")
     _add("transformer.norm.weight", "model.norm.weight")
     _add("transformer.embd.weight", "model.embed_tokens.weight")

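The conditional exists because checkpoints with tie_word_embeddings enabled (such as Phi-4-mini-instruct) typically ship without a separate lm_head.weight tensor; the embedding weight stands in for it. The snippet below is a standalone sketch of that idea with hypothetical names, not the real ExternMapping code:

# Hypothetical sketch: why the lm_head mapping must be skipped for tied checkpoints.
def build_mapping(tie_word_embeddings: bool):
    mapping = []
    if not tie_word_embeddings:
        mapping.append(("lm_head.weight", "lm_head.weight"))
    mapping.append(("transformer.norm.weight", "model.norm.weight"))
    mapping.append(("transformer.embd.weight", "model.embed_tokens.weight"))
    return mapping


# A tied checkpoint exposes the embedding and norm weights but no lm_head tensor.
tied_checkpoint_keys = {"model.embed_tokens.weight", "model.norm.weight"}
for _, hf_name in build_mapping(tie_word_embeddings=True):
    assert hf_name in tied_checkpoint_keys  # would fail if lm_head.weight were requested
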
python/mlc_llm/model/phi3/phi3_model.py

Lines changed: 32 additions & 14 deletions
@@ -40,6 +40,8 @@ class Phi3Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
     head_dim: int = 0
     tensor_parallel_shards: int = 1
     max_batch_size: int = 1
+    tie_word_embeddings: bool = False
+    partial_rotary_factor: float = 1.0
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     def __post_init__(self):
@@ -94,6 +96,17 @@ def __post_init__(self):
 # pylint: disable=invalid-name,missing-docstring
 
 
+class Phi3Embedding(nn.Embedding):
+    """The embedding module that can be shared with the final lm_head."""
+
+    def lm_head_forward(self, x: nn.Tensor):
+        """The lm_head forwarding, which transposes the weight and multiplies
+        with the input tensor.
+        """
+        weight = nn.op.permute_dims(self.weight)
+        return nn.op.matmul(x, weight, out_dtype="float32")
+
+
 class Phi3MLP(nn.Module):
     def __init__(self, config: Phi3Config):
         super().__init__()
@@ -195,7 +208,7 @@ def _apply_parallel_residual(self, mlp_out, residual):
 class Phi3Model(nn.Module):
     def __init__(self, config: Phi3Config) -> None:
         super().__init__()
-        self.embd = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.embd = Phi3Embedding(config.vocab_size, config.hidden_size)
         self.h = nn.ModuleList([Phi3ParallelBlock(config) for _ in range(config.num_hidden_layers)])
         self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
 
@@ -213,7 +226,9 @@ def __init__(self, config: Phi3Config) -> None:
         super().__init__()
 
         self.transformer = Phi3Model(config)
-        self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False)
+        self.tie_word_embeddings = config.tie_word_embeddings
+        if not config.tie_word_embeddings:
+            self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False)
         self.num_hidden_layers = config.num_hidden_layers
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
@@ -226,13 +241,24 @@ def __init__(self, config: Phi3Config) -> None:
             config.rope_scaling["long_factor"] if config.rope_scaling is not None else None
         )
         self.tensor_parallel_shards = config.tensor_parallel_shards
+        self.partial_rotary_factor = config.partial_rotary_factor
         self.dtype = "float32"
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
         if dtype is not None:
             self.dtype = dtype
 
+    def get_logits(self, hidden_states: Tensor):
+        op_ext.configure()
+        if self.tie_word_embeddings:
+            logits = self.transformer.embd.lm_head_forward(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states)
+        if logits.dtype != "float32":
+            logits = logits.astype("float32")
+        return logits
+
     def batch_forward(
         self,
         input_embeds: Tensor,
@@ -244,10 +270,7 @@ def batch_forward(
         hidden_states = self.transformer(input_embeds, paged_kv_cache)
         if logit_positions is not None:
             hidden_states = op.take(hidden_states, logit_positions, axis=1)
-        lm_logits = self.lm_head(hidden_states)
-        if lm_logits.dtype != "float32":
-            lm_logits = lm_logits.astype("float32")
-        return lm_logits
+        return self.get_logits(hidden_states)
 
     def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
         op_ext.configure()
@@ -258,20 +281,14 @@ def _index(x: te.Tensor):
 
         hidden_states = self.transformer(input_embed, paged_kv_cache)
         hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
-        logits = self.lm_head(hidden_states)
-
-        if logits.dtype != "float32":
-            logits = logits.astype("float32")
-
+        logits = self.get_logits(hidden_states)
         return logits, paged_kv_cache
 
     def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
         op_ext.configure()
 
         hidden_states = self.transformer(input_embed, paged_kv_cache)
-        logits = self.lm_head(hidden_states)
-        if logits.dtype != "float32":
-            logits = logits.astype("float32")
+        logits = self.get_logits(hidden_states)
         return logits, paged_kv_cache
 
     def batch_prefill(
@@ -321,6 +338,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
             rope_scale=1,
            rope_theta=self.rope_theta,
             rope_ext_factors=self.rope_ext_factors,
+            rotary_dim=int(self.head_dim * self.partial_rotary_factor),
             dtype=self.dtype,
         )

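For intuition on the tied-embedding path: Phi3Embedding.lm_head_forward transposes the (vocab_size, hidden_size) embedding table and multiplies the hidden states against it, producing float32 logits. Below is a standalone NumPy sketch of the same arithmetic (toy shapes, illustrative only; the model code goes through TVM's nn.op):

import numpy as np

# NumPy equivalent of the tied lm_head: reuse the embedding table as the output projection.
# Toy shapes here; Phi-4-mini-instruct uses vocab_size=200064, hidden_size=3072.
vocab_size, hidden_size, seq_len = 16, 8, 4
embedding_weight = np.random.randn(vocab_size, hidden_size).astype(np.float16)
hidden_states = np.random.randn(1, seq_len, hidden_size).astype(np.float16)

# permute_dims(weight) -> (hidden_size, vocab_size); accumulate in float32,
# mirroring out_dtype="float32" and the astype("float32") guard in get_logits.
logits = hidden_states.astype(np.float32) @ embedding_weight.T.astype(np.float32)
assert logits.shape == (1, seq_len, vocab_size)
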
python/mlc_llm/op/position_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -468,7 +468,7 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        ext_factors: T.Buffer((head_dim // 2,), "float32"),  # type: ignore
+        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
     ):
         T.func_attr(
             {

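The buffer size matters because, with partial_rotary_factor < 1, only rotary_dim of the head_dim channels receive rotary embeddings, so there are rotary_dim // 2 frequencies to scale and ext_factors must have exactly that many entries. Below is a schematic NumPy sketch of longrope-style frequency scaling under the common formulation where each factor divides one inverse frequency; it is an illustration, not the TIR kernel above.

import numpy as np

# Schematic longrope frequency scaling (illustrative, not the TIR implementation).
head_dim, partial_rotary_factor, rope_theta = 128, 0.75, 10000.0
rotary_dim = int(head_dim * partial_rotary_factor)        # 96 for Phi-4-mini

ext_factors = np.ones(rotary_dim // 2, dtype=np.float32)  # 48 entries, e.g. long_factor
base_inv_freq = rope_theta ** (-np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim)
scaled_inv_freq = base_inv_freq / ext_factors             # one scale factor per frequency

# Only the first rotary_dim channels of each head are rotated; the rest pass through.
assert scaled_inv_freq.shape == (rotary_dim // 2,)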