Commit 94a0295

[Model] Support Multi-GPU for Qwen-MoE model (#2573)
This PR introduces multi-GPU support for the Qwen-MoE model, validated on 2x RTX 4090.
1 parent 07c92b0 · commit 94a0295
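
For a rough sense of the sizing involved, here is a minimal sketch; the widths below are assumed for illustration and are not taken from this commit. The routed experts' FFN width is divided across shards, while after this change the shared expert keeps its full configured width and its weights are split through the new sharding hints instead.

    # Illustrative sketch only; the concrete widths are assumptions, not from this commit.
    # Mirrors `config.moe_intermediate_size // config.tensor_parallel_shards` in the diff below.
    tensor_parallel_shards = 2                  # e.g. two RTX 4090s
    moe_intermediate_size = 1408                # assumed routed-expert FFN width
    shared_expert_intermediate_size = 5632      # assumed shared-expert FFN width

    per_shard_moe = moe_intermediate_size // tensor_parallel_shards
    print(per_shard_moe)                        # 704: each GPU holds half of every routed expert
    # The shared expert is no longer pre-divided per shard; its gate_up/down weights are
    # instead sharded via the `shard_strategy` hints added to the decoder layer.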


python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py

Lines changed: 37 additions & 8 deletions
@@ -14,11 +14,10 @@
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp

 logger = logging.getLogger(__name__)

-# TODO(mlc-team): Support Tensor Parallel.
-

 @dataclasses.dataclass
 class Qwen2MoeConfig(QWen2Config):  # pylint: disable=too-many-instance-attributes
@@ -68,10 +67,7 @@ def __init__(self, config: Qwen2MoeConfig):
         )
         self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
         self.norm_topk_prob = config.norm_topk_prob
-        self.share_expert_intermediate_size = (
-            config.shared_expert_intermediate_size // config.tensor_parallel_shards
-        )
-        self.shared_expert = Qwen2MoeMLP(config, self.share_expert_intermediate_size)
+        self.shared_expert = Qwen2MoeMLP(config, config.shared_expert_intermediate_size)
         self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)

         self.gate = nn.Linear(
@@ -154,7 +150,42 @@ def __init__(self, config: Qwen2MoeConfig):
         self.post_attention_layernorm = nn.RMSNorm(
             config.hidden_size, -1, config.rms_norm_eps, bias=False
         )
+
+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_attention_heads * hd
+            k = self.self_attn.num_key_value_heads * hd
+            v = self.self_attn.num_key_value_heads * hd
+            si = self.mlp.shared_expert.intermediate_size
+            mi = self.mlp.moe_intermediate_size
+            _set(
+                self.self_attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(
+                self.self_attn.c_attn.bias,
+                tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.shared_expert.gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_shared_mlp_up", segs=[si, si], dim=0),
+            )
+            _set(
+                self.mlp.shared_expert.down_proj.weight,
+                tp.ShardSingleDim("_shard_shared_mlp_down", dim=1),
+            )
+            _set(
+                self.mlp.moe_gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_moe_mlp_up", segs=[mi, mi], dim=1),
+            )
+            _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2))
+
         self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()

     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
@@ -202,8 +233,6 @@ def __init__(self, config: Qwen2MoeConfig):
         self.vocab_size = config.vocab_size
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.head_dim = config.head_dim
-        if self.tensor_parallel_shards != 1:
-            raise ValueError("Currently only support tensor_parallel_shards=1.")

     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
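
To make the new `shard_strategy` hints concrete, here is a minimal NumPy sketch of the segment-wise split that a hint like `ShardSingleDim(dim=0, segs=[q, k, v])` implies for the fused QKV weight: each segment is cut along the sharded dimension and worker i receives the i-th slice of every segment. The helper `shard_single_dim` and the toy sizes are hypothetical illustrations, not mlc_llm API.

    # Hypothetical illustration of segment-wise sharding; not mlc_llm code.
    import numpy as np

    def shard_single_dim(weight, segs, num_shards, dim=0):
        """Shard i gets the i-th slice of every segment of `weight` along `dim`."""
        pieces = np.split(weight, np.cumsum(segs)[:-1], axis=dim)  # e.g. the Q, K, V blocks
        shards = []
        for i in range(num_shards):
            parts = [np.array_split(piece, num_shards, axis=dim)[i] for piece in pieces]
            shards.append(np.concatenate(parts, axis=dim))
        return shards

    # Toy sizes: hidden_size=8, 4 query heads and 2 KV heads with head_dim=2, 2 shards.
    q, k, v = 4 * 2, 2 * 2, 2 * 2
    qkv_weight = np.arange((q + k + v) * 8).reshape(q + k + v, 8)
    for i, w in enumerate(shard_single_dim(qkv_weight, segs=[q, k, v], num_shards=2)):
        print(f"shard {i}: {w.shape}")  # each shard keeps half of the Q, K and V rows

Segmenting by [q, k, v] rather than splitting dim 0 as one block keeps a matching set of query and key/value heads on every shard, so attention can run locally on each GPU.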
