Commit d638b9e

Authored by Zichong Li (zichongli5), committed by Isotr0py
[Model] Adds support for SlimMoE models Phi-tiny-MoE-instruct (vllm-project#20286)
Signed-off-by: Zichong Li <t-lizichong@microsoft.com@Reasoning-H100-VM3.drbuo4tcjzruhloch3eo0b25ef.cx.internal.cloudapp.net>
Co-authored-by: Zichong Li <t-lizichong@microsoft.com@Reasoning-H100-VM3.drbuo4tcjzruhloch3eo0b25ef.cx.internal.cloudapp.net>
Co-authored-by: Isotr0py <2037008807@qq.com>
1 parent: 79a764b · commit: d638b9e

File tree

1 file changed (+10, -1 lines)


vllm/model_executor/models/phimoe.py

Lines changed: 10 additions & 1 deletion
@@ -68,6 +68,7 @@ def __init__(
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=8,
+        head_dim=None,
         hidden_act="silu",
         max_position_embeddings=4096 * 32,
         initializer_range=0.02,
@@ -101,8 +102,11 @@ def __init__(
         # for backward compatibility
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
+        if head_dim is None:
+            head_dim = hidden_size // num_attention_heads
 
         self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
@@ -294,6 +298,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        head_dim: Optional[int] = None,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
@@ -317,7 +322,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        if head_dim is None:
+            head_dim = hidden_size // num_heads
+        self.head_dim = head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -387,6 +394,8 @@ def __init__(
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
+            head_dim=getattr(config, "head_dim",
+                             self.hidden_size // config.num_attention_heads),
             rope_theta=rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
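
Taken together, these hunks make head_dim an explicit, optional setting instead of a value always derived as hidden_size // num_attention_heads; SlimMoE checkpoints such as Phi-tiny-MoE-instruct need this because their head dimension does not follow that formula. Below is a minimal, self-contained sketch of the fallback pattern, using hypothetical stand-in names (DummyConfig, resolve_head_dim) and illustrative numbers rather than the actual vLLM classes:

from types import SimpleNamespace
from typing import Optional


class DummyConfig:
    """Stand-in for a PhiMoE-style config where head_dim is optional."""

    def __init__(self, hidden_size: int, num_attention_heads: int,
                 head_dim: Optional[int] = None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if head_dim is None:
            # Backward-compatible default, as in the config hunk above.
            head_dim = hidden_size // num_attention_heads
        self.head_dim = head_dim


def resolve_head_dim(config) -> int:
    # Mirrors the decoder-layer hunk: prefer an explicit head_dim
    # attribute, else derive it from the hidden size.
    return getattr(config, "head_dim",
                   config.hidden_size // config.num_attention_heads)


slim = DummyConfig(hidden_size=4096, num_attention_heads=32, head_dim=96)
legacy = DummyConfig(hidden_size=4096, num_attention_heads=32)
bare = SimpleNamespace(hidden_size=4096, num_attention_heads=32)  # no head_dim

assert resolve_head_dim(slim) == 96     # explicit value wins
assert resolve_head_dim(legacy) == 128  # 4096 // 32 config-side fallback
assert resolve_head_dim(bare) == 128    # getattr default for older configs

Explicit values win, and configs without the attribute fall back to the legacy formula, so older PhiMoE checkpoints keep loading unchanged.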
