|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | +import torch.nn.functional as F |
| 9 | +from torch import nn |
| 10 | +from torchtitan.models.llama3 import model |
| 11 | + |
| 12 | +from .args import DeepseekV3ModelArgs |
| 13 | + |
| 14 | + |
| 15 | +# Reference: torchtitan/experiments/llama4/model/ |
| 16 | +class GroupedExperts(nn.Module): |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + dim: int, |
| 20 | + hidden_dim: int, |
| 21 | + num_experts: int, |
| 22 | + use_grouped_mm: bool, |
| 23 | + ): |
| 24 | + super().__init__() |
| 25 | + self.num_experts = num_experts |
| 26 | + self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) |
| 27 | + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) |
| 28 | + self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) |
| 29 | + self.use_grouped_mm = use_grouped_mm |
| 30 | + |
| 31 | + def forward( |
| 32 | + self, |
| 33 | + x: torch.Tensor, |
| 34 | + num_local_tokens_per_expert: torch.Tensor | list[int] | None = None, |
| 35 | + ) -> torch.Tensor: |
| 36 | + # TODO: keeping this for loop implementation for comparison |
| 37 | + # and readability, will remove later |
| 38 | + if not self.use_grouped_mm: |
| 39 | + if num_local_tokens_per_expert is not None: |
| 40 | + # a tuple of tensors indexed by experts |
| 41 | + # each with shape (tokens_per_expert(varying), dim) |
| 42 | + x = torch.split( |
| 43 | + x, |
| 44 | + split_size_or_sections=num_local_tokens_per_expert, |
| 45 | + dim=0, |
| 46 | + ) |
| 47 | + out_experts_splits = [] |
| 48 | + for expert_idx, x_expert in enumerate(x): |
| 49 | + w1, w2, w3 = ( |
| 50 | + self.w1[expert_idx], |
| 51 | + self.w2[expert_idx], |
| 52 | + self.w3[expert_idx], |
| 53 | + ) |
| 54 | + h = F.silu(torch.matmul(x_expert, w1)) |
| 55 | + h = h * torch.matmul(x_expert, w3) |
| 56 | + h = torch.matmul(h, w2) |
| 57 | + # h shape (tokens_per_expert(varying), dim) |
| 58 | + out_experts_splits.append(h) |
| 59 | + out = torch.cat(out_experts_splits, dim=0) |
| 60 | + else: |
| 61 | + # x shape (num_experts, tokens_per_expert, dim) |
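+                # (the shared expert also takes this dense bmm path when
+                # use_grouped_mm is False, with num_experts == 1 and all tokens
+                # in a single group)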
| 62 | + h = F.silu(torch.bmm(x, self.w1)) |
| 63 | + h = h * torch.bmm(x, self.w3) |
| 64 | + # out shape (num_experts, tokens_per_expert, dim) |
| 65 | + out = torch.bmm(h, self.w2) |
| 66 | + |
| 67 | + return out |
| 68 | + |
| 69 | + # grouped mm implementation |
| 70 | + if num_local_tokens_per_expert is not None: |
| 71 | + # https://github.com/pytorch/pytorch/pull/150374 |
| 72 | +            # NOTE: torch._grouped_mm requires bf16 dtypes
| 73 | + # and shapes to be multiple of 8 |
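+            # offsets holds the running end offset of each expert's contiguous
+            # block of rows in x along dim 0, which is what the `offs` argument
+            # of torch._grouped_mm consumes.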
| 74 | + offsets = torch.cumsum( |
| 75 | + num_local_tokens_per_expert, dim=0, dtype=torch.int32 |
| 76 | + ) |
| 77 | + # grouped mm between a 2D tensor and a 3D tensor |
| 78 | + assert x.dim() == 2 |
| 79 | + else: |
| 80 | + offsets = None |
| 81 | + # fall back to regular bmm between 3D tensors |
| 82 | + assert x.dim() == 3 |
| 83 | + |
| 84 | + assert ( |
| 85 | + x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 |
| 86 | + ), "torch._grouped_mm only supports bf16 dtypes" |
| 87 | + h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) |
| 88 | + h = h * torch._grouped_mm(x, self.w3, offs=offsets) |
| 89 | + out = torch._grouped_mm(h, self.w2, offs=offsets) |
| 90 | + |
| 91 | + return out |
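+
+    # NOTE: MoE.init_weights below calls experts.init_weights, which this class
+    # does not define above. A minimal sketch, assuming truncated-normal
+    # initialization of the expert weights; the exact init scheme is an assumption.
+    def init_weights(self, init_std: float) -> None:
+        nn.init.trunc_normal_(self.w1, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
+        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)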
| 92 | + |
| 93 | + |
| 94 | +class TokenChoiceTopKRouter(nn.Module): |
| 95 | + """This class implements token-choice routing. In token-choice top-K routing, each token is |
| 96 | + routed to top K experts based on the router scores. |
| 97 | +
|
| 98 | + Args: |
| 99 | + gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). |
| 100 | + num_experts (int): Number of experts in each moe layer. |
| 101 | + top_k (int): Number of experts each token will be routed to in token-choice routing. |
| 102 | + use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. |
| 103 | + """ |
| 104 | + |
| 105 | +    def __init__(
| 106 | +        self,
| 107 | +        dim: int,
| 108 | +        num_experts: int,
| 109 | +        top_k: int,
| 110 | +        use_sigmoid: bool = False,
| 111 | +        route_scaling_factor: float = 1.0,
| 112 | +    ):
| 113 | +        super().__init__()
| 114 | +
| 115 | +        self.num_experts = num_experts
| 116 | +        self.top_k = top_k
| 117 | +        self.use_sigmoid = use_sigmoid
| 118 | +        self.route_scaling_factor = route_scaling_factor
| 119 | +
| 120 | +        self.weight = nn.Parameter(torch.empty((num_experts, dim)))
| 122 | + # TODO: is this needed? This is not "Complementary Sequence-Wise Auxiliary Loss" |
| 123 | + # self.e_score_correction_bias = nn.Parameter(torch.rand((self.num_experts))) |
| 124 | + |
| 125 | +    def forward(
| 126 | +        self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
| 127 | +    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
| 128 | +        """
| 129 | +        TODO: Group-based routing (node-limit routing) is not implemented yet, and EP does not support node-limit routing yet.
| 130 | +
| 131 | +        Args:
| 132 | +            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
| 133 | +            expert_bias (torch.Tensor | None): Optional per-expert bias added to the scores for expert selection only.
| 134 | +
| 135 | +        Returns:
| 136 | +            top_scores (torch.Tensor): Gating scores of the selected experts, sorted by expert index, with shape ``(bs*slen*top_k,)``.
| 137 | +            token_indices_experts_sorted (torch.Tensor): Token indices for the routed tokens, sorted by expert index, with shape ``(bs*slen*top_k,)``.
| 138 | +            num_local_tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert with shape ``(num_experts,)``.
| 139 | +        """
| 141 | + # scores shape (bs*slen, num_experts) |
| 142 | +        scores = F.linear(x, self.weight, None)
| 143 | + |
| 144 | + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion |
| 145 | + if self.use_sigmoid: |
| 146 | + scores = torch.sigmoid(scores.to(torch.float32)) |
| 147 | + else: |
| 148 | + scores = F.softmax(scores.to(torch.float32), dim=1) |
| 149 | + |
| 150 | + # top scores shape (bs*slen, top_k) |
| 151 | + # NOTE: The expert_bias is only used for routing. The gating value |
| 152 | + # top_scores is still derived from the original scores. |
| 153 | +        # expert_bias defaults to None; treat it as zero in that case
| 154 | +        bias = expert_bias if expert_bias is not None else 0.0
| 155 | +        _, selected_experts_indices = torch.topk(scores + bias, k=self.top_k, dim=1)
| 156 | + top_scores = scores.gather(dim=1, index=selected_experts_indices) |
| 157 | + |
| 158 | + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward |
| 159 | + num_local_tokens_per_expert = torch.histc( |
| 160 | + selected_experts_indices.view(-1), |
| 161 | + bins=self.num_experts, |
| 162 | + min=0, |
| 163 | + max=self.num_experts, |
| 164 | + ) |
| 165 | + # token_indices_experts_sorted shape (bs*slen*top_k,) |
| 166 | + token_indices_experts_sorted = torch.argsort( |
| 167 | + selected_experts_indices.view(-1), stable=True |
| 168 | + ) |
| 169 | + top_scores = top_scores.view(-1)[token_indices_experts_sorted] |
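+        # argsort was taken over the flattened (bs*slen, top_k) view, where flat
+        # position = token_id * top_k + k, so integer division by top_k recovers
+        # the original token index for every routed slot.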
| 170 | + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k |
| 171 | + |
| 172 | +        top_scores = (
| 173 | +            top_scores * self.route_scaling_factor
| 174 | +        )  # must multiply by the scaling factor
| 175 | + |
| 176 | + return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert |
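+
+    # NOTE: MoE.init_weights below also calls router.init_weights, which was not
+    # defined here. A minimal sketch, assuming truncated-normal initialization of
+    # the gating weight.
+    def init_weights(self, init_std: float) -> None:
+        nn.init.trunc_normal_(self.weight, mean=0.0, std=init_std)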
| 177 | + |
| 178 | + |
| 179 | +class MoE(nn.Module): |
| 180 | + def __init__(self, model_args: DeepseekV3ModelArgs): |
| 181 | + |
| 182 | + # n_routed_experts: int = 64 |
| 183 | + # n_shared_experts: int = 2 |
| 184 | + # n_activated_experts: int = 6 |
| 185 | + # score_func: Literal["softmax", "sigmoid"] = "softmax" |
| 186 | + # route_scale: float = 1.0 |
| 187 | + |
| 188 | + super().__init__() |
| 189 | + dim = model_args.dim |
| 190 | + |
| 191 | + num_experts = model_args.n_routed_experts |
| 192 | + hidden_dim = model_args.moe_inter_dim |
| 193 | + top_k = model_args.n_activated_experts |
| 194 | + route_scaling_factor = model_args.route_scale |
| 195 | + |
| 196 | + self.use_grouped_mm = model_args.use_grouped_mm |
| 197 | + self.experts = GroupedExperts( |
| 198 | + dim=dim, |
| 199 | + hidden_dim=hidden_dim, |
| 200 | + num_experts=num_experts, |
| 201 | + use_grouped_mm=self.use_grouped_mm, |
| 202 | + ) |
| 203 | +        self.router = TokenChoiceTopKRouter(
| 204 | +            dim=dim,
| 205 | +            num_experts=num_experts,
| 206 | +            top_k=top_k,
| 207 | +            use_sigmoid=model_args.score_func == "sigmoid",
| 208 | +            route_scaling_factor=route_scaling_factor,
| 209 | +        )
| 209 | + self.shared_expert = ( |
| 210 | + # Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py#L517 |
| 211 | + GroupedExperts( |
| 212 | + dim=dim, |
| 213 | + hidden_dim=hidden_dim * model_args.n_shared_experts, |
| 214 | + num_experts=1, |
| 215 | + use_grouped_mm=self.use_grouped_mm, |
| 216 | + ) |
| 217 | + if model_args.n_shared_experts > 0 |
| 218 | + else None |
| 219 | + ) |
| 220 | + |
| 221 | + # auxiliary-loss-free load balancing |
| 222 | + self.load_balance_coeff = model_args.load_balance_coeff |
| 223 | + # the fields below are defined even when load_balance_coeff is None |
| 224 | + # to make initialization and checkpointing code simpler |
| 225 | + self.register_buffer( |
| 226 | + "expert_bias", |
| 227 | + torch.zeros(num_experts, dtype=torch.float32), |
| 228 | + persistent=True, |
| 229 | + ) |
| 230 | + self.register_buffer( |
| 231 | + "tokens_per_expert", |
| 232 | + torch.zeros(num_experts, dtype=torch.float32), |
| 233 | + persistent=True, |
| 234 | + ) |
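+        # expert_bias only biases expert *selection* in TokenChoiceTopKRouter.forward;
+        # tokens_per_expert accumulates per-expert token counts across forward passes
+        # and is consumed (then reset) by _update_expert_bias after each backward.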
| 235 | + |
| 236 | + # NOTE: forward hook, forward pre hook, or backward pre hook |
| 237 | + # would conflict with activation checkpointing |
| 238 | + if self.load_balance_coeff is not None and self.load_balance_coeff > 0: |
| 239 | + self.register_full_backward_hook(self._update_expert_bias) |
| 240 | + |
| 241 | +    # TODO: double-check that the bias update below matches the auxiliary-loss-free balancing rule in the DeepSeek-V3 paper.
| 242 | + def _update_expert_bias(self, *_): |
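+        # Sign-based update: experts that received more tokens than the mean get
+        # their routing bias pushed down, under-loaded experts get it pushed up;
+        # subtracting the mean delta keeps the bias vector centered.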
| 243 | + expert_bias_delta = self.load_balance_coeff * torch.sign( |
| 244 | + self.tokens_per_expert.mean() - self.tokens_per_expert |
| 245 | + ) |
| 246 | + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() |
| 247 | + self.expert_bias.add_(expert_bias_delta) |
| 248 | + |
| 249 | + self.tokens_per_expert.zero_() |
| 250 | + |
| 251 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 252 | + """ |
| 253 | + Args: |
| 254 | + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. |
| 255 | +
| 256 | + Returns: |
| 257 | + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. |
| 258 | + """ |
| 259 | + bs, slen, dim = x.shape |
| 260 | + |
| 261 | + # top_scores and selected_indices shape (bs*slen*top_k,) |
| 262 | + # num_local_tokens_per_expert shape (num_experts,) |
| 263 | + ( |
| 264 | + top_scores, |
| 265 | + token_indices, |
| 266 | + num_local_tokens_per_expert, |
| 267 | + ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) |
| 268 | + |
| 269 | + # will be used to update the expert bias for load balancing |
| 270 | + self.tokens_per_expert += num_local_tokens_per_expert |
| 271 | + |
| 272 | + # shape (bs*slen*top_k, dim) |
| 273 | + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) |
| 274 | + |
| 275 | + # shape (bs*slen*top_k, dim) |
| 276 | + routed_input = torch.gather( |
| 277 | + x.view(-1, dim), |
| 278 | + dim=0, |
| 279 | + index=token_indices, |
| 280 | + ) |
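+        # routed_input holds one copy of each token's hidden state per selected
+        # expert, ordered by expert index, so every expert sees a contiguous
+        # block of rows.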
| 281 | + |
| 282 | + if self.use_grouped_mm: |
| 283 | + # NOTE: In order to use torch._grouped_mm, we need to make sure |
| 284 | + # the number of tokens each expert gets is a multiple of 16. |
| 285 | + # The following kernel helps achieve this via padding, without |
| 286 | + # incurring synchronization between device and host. |
| 287 | + from torchtitan.experiments.kernels.moe.indices import ( |
| 288 | + generate_permute_indices, |
| 289 | + ) |
| 290 | + |
| 291 | + ALIGN_SIZE_M = 16 |
| 292 | + |
| 293 | + with torch.no_grad(): |
| 294 | + ( |
| 295 | + permuted_indices, |
| 296 | + num_local_tokens_per_expert, |
| 297 | + _, |
| 298 | + ) = generate_permute_indices( |
| 299 | + num_local_tokens_per_expert, |
| 300 | + self.experts.num_experts, |
| 301 | + 1, |
| 302 | + ALIGN_SIZE_M, |
| 303 | + ) |
| 304 | + token_indices = torch.vstack( |
| 305 | + (token_indices, token_indices.new_zeros((dim))) |
| 306 | + ) |
| 307 | + token_indices = token_indices[permuted_indices, :] |
| 308 | + routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim)))) |
| 309 | + routed_input = routed_input[permuted_indices, :] |
| 310 | + else: |
| 311 | + # NOTE: this would incur a synchronization between device and host |
| 312 | + num_local_tokens_per_expert = num_local_tokens_per_expert.tolist() |
| 313 | + |
| 314 | + # shape (bs*slen*top_k, dim) |
| 315 | + routed_output = self.experts(routed_input, num_local_tokens_per_expert) |
| 316 | +        routed_output = (routed_output * top_scores.unsqueeze(-1)).to(x.dtype)  # top_scores is float32
| 317 | + |
| 318 | + # shared expert |
| 319 | + if self.shared_expert is not None: |
| 320 | + out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( |
| 321 | + bs * slen, dim |
| 322 | + ) |
| 323 | + else: |
| 324 | + out = torch.zeros_like(x.reshape(bs * slen, dim)) |
| 325 | + |
| 326 | + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) |
| 327 | + out = out.reshape(bs, slen, dim) |
| 328 | + return out |
| 329 | + |
| 330 | + def init_weights( |
| 331 | + self, |
| 332 | + init_std: float, |
| 333 | + buffer_device: torch.device, |
| 334 | + ): |
| 335 | + self.experts.init_weights(init_std) |
| 336 | + self.router.init_weights(init_std) |
| 337 | + if self.shared_expert is not None: |
| 338 | + self.shared_expert.init_weights(init_std) |
| 339 | + |
| 340 | + with torch.device(buffer_device): |
| 341 | + self.expert_bias = torch.zeros( |
| 342 | + self.experts.num_experts, dtype=torch.float32 |
| 343 | + ) |
| 344 | + self.tokens_per_expert = torch.zeros( |
| 345 | + self.experts.num_experts, dtype=torch.float32 |
| 346 | + ) |
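+
+
+# NOTE: illustrative (untested) usage sketch. The DeepseekV3ModelArgs field names
+# below mirror the attributes read in MoE.__init__; constructing the args with
+# keyword arguments and the concrete values shown are assumptions, not canonical
+# configs.
+#
+#   args = DeepseekV3ModelArgs(
+#       dim=2048, moe_inter_dim=1408, n_routed_experts=64, n_shared_experts=2,
+#       n_activated_experts=6, score_func="softmax", route_scale=1.0,
+#       use_grouped_mm=False, load_balance_coeff=1e-3,
+#   )
+#   moe = MoE(args)
+#   moe.init_weights(init_std=0.02, buffer_device=torch.device("cpu"))
+#   out = moe(torch.randn(2, 16, args.dim))  # expected shape: (2, 16, args.dim)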