 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 
 logger = logging.getLogger(__name__)
 
-# TODO(mlc-team): Support Tensor Parallel.
-
 
 @dataclasses.dataclass
 class Qwen2MoeConfig(QWen2Config):  # pylint: disable=too-many-instance-attributes
@@ -68,10 +67,7 @@ def __init__(self, config: Qwen2MoeConfig):
         )
         self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
         self.norm_topk_prob = config.norm_topk_prob
-        self.share_expert_intermediate_size = (
-            config.shared_expert_intermediate_size // config.tensor_parallel_shards
-        )
-        self.shared_expert = Qwen2MoeMLP(config, self.share_expert_intermediate_size)
+        self.shared_expert = Qwen2MoeMLP(config, config.shared_expert_intermediate_size)
         self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)
 
         self.gate = nn.Linear(
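A note on the call-site change above: it is consistent with Qwen2MoeMLP dividing its intermediate size by tensor_parallel_shards internally, the way QWen2MLP does for the dense model, so passing the pre-divided value would shrink the shared expert twice. That reading is an assumption rather than something this hunk alone confirms; the arithmetic below is only a toy illustration of the hazard.

# Toy arithmetic with made-up sizes; assumes the MLP divides by
# tensor_parallel_shards itself, which this diff does not show.
shared_expert_intermediate_size = 5632
tensor_parallel_shards = 2

pre_divided_at_call_site = shared_expert_intermediate_size // tensor_parallel_shards  # 2816
divided_again_inside_mlp = pre_divided_at_call_site // tensor_parallel_shards         # 1408: halved twice
divided_once_inside_mlp = shared_expert_intermediate_size // tensor_parallel_shards   # 2816: intended per-shard width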
@@ -154,7 +150,42 @@ def __init__(self, config: Qwen2MoeConfig):
         self.post_attention_layernorm = nn.RMSNorm(
             config.hidden_size, -1, config.rms_norm_eps, bias=False
         )
+
+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_attention_heads * hd
+            k = self.self_attn.num_key_value_heads * hd
+            v = self.self_attn.num_key_value_heads * hd
+            si = self.mlp.shared_expert.intermediate_size
+            mi = self.mlp.moe_intermediate_size
+            _set(
+                self.self_attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(
+                self.self_attn.c_attn.bias,
+                tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.shared_expert.gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_shared_mlp_up", segs=[si, si], dim=0),
+            )
+            _set(
+                self.mlp.shared_expert.down_proj.weight,
+                tp.ShardSingleDim("_shard_shared_mlp_down", dim=1),
+            )
+            _set(
+                self.mlp.moe_gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_moe_mlp_up", segs=[mi, mi], dim=1),
+            )
+            _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2))
+
         self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
 
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
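The segs=[q, k, v] hints above mark the fused c_attn weight as three blocks stacked along dim 0, so each worker should receive its slice of the query, key, and value blocks rather than one contiguous stripe of the whole matrix. The NumPy sketch below shows that idea in isolation; it is a generic illustration of segmented single-dimension sharding, not MLC LLM's ShardSingleDim implementation, and here segs holds each segment's full (unsharded) width.

# Illustrative only: a toy segmented split, not mlc_llm.support.tensor_parallel.
import numpy as np

def shard_single_dim(weight, segs, num_shards):
    """Split each segment along dim 0 across workers, then hand every worker
    the concatenation of its slice from each segment."""
    shards = [[] for _ in range(num_shards)]
    start = 0
    for seg in segs:                                  # segs are full segment widths here
        block = weight[start : start + seg]           # e.g. all Q rows, then K, then V
        piece = seg // num_shards
        for rank in range(num_shards):
            shards[rank].append(block[rank * piece : (rank + 1) * piece])
        start += seg
    return [np.concatenate(parts, axis=0) for parts in shards]

# Hypothetical sizes: 16 query heads, 4 KV heads, head_dim 8, hidden 128, 2 workers.
hd, hidden = 8, 128
q, k, v = 16 * hd, 4 * hd, 4 * hd
qkv_weight = np.random.randn(q + k + v, hidden)       # fused QKV weight, dim 0 = q + k + v
per_worker = shard_single_dim(qkv_weight, segs=[q, k, v], num_shards=2)
assert per_worker[0].shape == ((q + k + v) // 2, hidden)

Without the segments, a plain even split along dim 0 would ignore the Q/K/V boundaries, and a worker could end up holding query rows paired with another worker's key and value rows.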
@@ -202,8 +233,6 @@ def __init__(self, config: Qwen2MoeConfig):
         self.vocab_size = config.vocab_size
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.head_dim = config.head_dim
-        if self.tensor_parallel_shards != 1:
-            raise ValueError("Currently only support tensor_parallel_shards=1.")
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
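Taken together, the shard dimensions above follow the usual pattern: weights whose output feeds the next sharded operator (c_attn, the gate_up projections) are split along their output axis, while weights that map back to the full hidden size (o_proj, the down projections) are split along their input axis so each worker's partial product can be reduced across shards; the batched expert weights are split on the matching axis after the leading expert dimension. The sketch below walks through the per-shard weight shapes under illustrative sizes (not a real Qwen2-MoE configuration), assuming nn.Linear stores weights as (out_features, in_features) and MixtralExperts as (num_experts, out_features, in_features).

# Illustrative sizes only; not a real Qwen2-MoE configuration.
hidden, heads, kv_heads, head_dim = 2048, 16, 4, 128
shared_inter, moe_inter, num_experts = 5632, 1408, 60
shards = 2

def per_shard(shape, dim):
    """Shape one worker holds after splitting `dim` evenly across `shards`."""
    out = list(shape)
    out[dim] //= shards
    return tuple(out)

print(per_shard(((heads + 2 * kv_heads) * head_dim, hidden), dim=0))  # c_attn weight
print(per_shard((hidden, heads * head_dim), dim=1))                   # o_proj weight
print(per_shard((2 * shared_inter, hidden), dim=0))                   # shared expert gate_up_proj
print(per_shard((hidden, shared_inter), dim=1))                       # shared expert down_proj
print(per_shard((num_experts, 2 * moe_inter, hidden), dim=1))         # moe_gate_up_proj
print(per_shard((num_experts, hidden, moe_inter), dim=2))             # moe_down_proj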
|