Commit 4d54240

[Feature]: Allow Granite MoE Hybrid models with _only_ shared experts. (#19652)
Signed-off-by: Shawn Tan <shawntan@ibm.com>
1 parent 3e75069 commit 4d54240

File tree

1 file changed (+40 / -24 lines)


vllm/model_executor/models/granitemoehybrid.py

Lines changed: 40 additions & 24 deletions
```diff
@@ -67,13 +67,15 @@ def __init__(self,
                              activation=config.hidden_act,
                              quant_config=quant_config)
 
-        self.block_sparse_moe = GraniteMoeMoE(
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
+            self.block_sparse_moe = GraniteMoeMoE(
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.block_sparse_moe")
 
         self.shared_mlp = None if \
             getattr(config, 'shared_intermediate_size', 0) == 0 \
```
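
In both decoder layer classes, the routed-expert block is now constructed only when the config actually declares local experts; otherwise `self.block_sparse_moe` stays `None`. A minimal standalone sketch of that guard (the `SimpleNamespace` config below is illustrative, not a real vLLM config object):

```python
from types import SimpleNamespace

# Illustrative config for a hybrid checkpoint with *only* shared experts:
# no num_local_experts attribute, non-zero shared_intermediate_size.
config = SimpleNamespace(hidden_size=2048, shared_intermediate_size=1024)

# Mirrors the guarded construction in the diff: the routed-expert block is
# built only when the config declares at least one local expert.
block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:
    block_sparse_moe = "GraniteMoeMoE(...)"  # stand-in for the real module

print(block_sparse_moe)  # None -> the layer will rely on shared_mlp alone
```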
```diff
@@ -105,13 +107,19 @@ def forward(
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
-            hidden_states = self.block_sparse_moe(hidden_states)
+            if self.block_sparse_moe is not None:
+                hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
-            moe_hidden_states = hidden_states.clone()
-            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
-            del moe_hidden_states
+            if self.block_sparse_moe is not None:
+                moe_hidden_states = hidden_states.clone()
+                moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
+                del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         return hidden_states, residual
```
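
The `forward` change makes the post-attention MLP dispatch handle every combination of routed and shared experts. A simplified, self-contained sketch of that logic (plain callables stand in for the real `GraniteMoeMoE` and shared-MLP modules; variable names follow the diff):

```python
from typing import Callable, Optional

import torch


def mlp_dispatch(
    hidden_states: torch.Tensor,
    block_sparse_moe: Optional[Callable[[torch.Tensor], torch.Tensor]],
    shared_mlp: Optional[Callable[[torch.Tensor], torch.Tensor]],
) -> torch.Tensor:
    """Simplified stand-in for the post-attention MLP dispatch in the diff."""
    if shared_mlp is None:
        if block_sparse_moe is not None:
            hidden_states = block_sparse_moe(hidden_states)
        # else: no experts of either kind -> hidden_states passes through
    else:
        if block_sparse_moe is not None:
            # Clone first: the real MoE kernel modifies its input in place,
            # and shared_mlp still needs the original activations.
            moe_hidden_states = block_sparse_moe(hidden_states.clone())
            hidden_states = moe_hidden_states + shared_mlp(hidden_states)
        else:
            # Shared-experts-only checkpoint: the path this commit enables.
            hidden_states = shared_mlp(hidden_states)
    return hidden_states


# Example: a model with only shared experts.
x = torch.randn(2, 8)
out = mlp_dispatch(x, block_sparse_moe=None, shared_mlp=lambda t: 2.0 * t)
assert torch.allclose(out, 2.0 * x)
```

The `.clone()` mirrors the existing comment in the file: the fused MoE path writes into its input, so the shared MLP must see an untouched copy of `hidden_states`.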
```diff
@@ -137,13 +145,15 @@ def __init__(
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.self_attn")
 
-        self.block_sparse_moe = GraniteMoeMoE(
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+        self.block_sparse_moe = None
+        if getattr(config, "num_local_experts", 0) > 0:
+            self.block_sparse_moe = GraniteMoeMoE(
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.block_sparse_moe")
 
         self.shared_mlp = None if \
             getattr(config, 'shared_intermediate_size', 0) == 0 \
```
```diff
@@ -178,13 +188,19 @@ def forward(
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         if self.shared_mlp is None:
-            hidden_states = self.block_sparse_moe(hidden_states)
+            if self.block_sparse_moe is not None:
+                hidden_states = self.block_sparse_moe(hidden_states)
+            # else: skip
         else:
             # create a copy since block_sparse_moe modifies in-place
-            moe_hidden_states = hidden_states.clone()
-            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
-            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
-            del moe_hidden_states
+            if self.block_sparse_moe is not None:
+                moe_hidden_states = hidden_states.clone()
+                moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
+                hidden_states = moe_hidden_states + self.shared_mlp(
+                    hidden_states)
+                del moe_hidden_states
+            else:
+                hidden_states = self.shared_mlp(hidden_states)
         hidden_states = residual + hidden_states * self.residual_multiplier
 
         return hidden_states, residual
```
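
Taken together, the two pairs of hunks let a checkpoint omit routed experts entirely. Roughly, a config along these lines (hypothetical values; the attribute names are the ones the diff reads via `getattr`) now loads and runs through the shared-MLP-only branch:

```python
# Hypothetical excerpt of a GraniteMoeHybrid config exercising the new path.
config_overrides = {
    "num_local_experts": 0,            # guard fails -> block_sparse_moe is None
    "shared_intermediate_size": 4096,  # non-zero    -> shared_mlp is built
}
```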
