Skip to content

Commit bd6e2b8

Browse files
committed
add MoE
1 parent e9f925f commit bd6e2b8

File tree

2 files changed

+350
-0
lines changed

2 files changed

+350
-0
lines changed

torchtitan/models/deepseek-v3/model/args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class DeepseekV3ModelArgs(BaseModelArgs):
4141
n_limited_groups (int): Number of limited groups for MoE routing.
4242
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
4343
route_scale (float): Scaling factor for routing scores.
44+
use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers.
45+
load_balance_coeff (float | None): Auxiliary-Loss-Free Load balancing coefficient for MoE layers.
4446
q_lora_rank (int): LoRA rank for query projections.
4547
kv_lora_rank (int): LoRA rank for key-value projections.
4648
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
@@ -73,6 +75,8 @@ class DeepseekV3ModelArgs(BaseModelArgs):
7375
n_limited_groups: int = 1
7476
score_func: Literal["softmax", "sigmoid"] = "softmax"
7577
route_scale: float = 1.0
78+
use_grouped_mm: bool = False
79+
load_balance_coeff: float | None = 1e-3
7680
# Multi-Head Latent Attention (MLA)
7781
q_lora_rank: int = 0
7882
kv_lora_rank: int = 512
Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
import torch.nn.functional as F
9+
from torch import nn
10+
from torchtitan.models.llama3 import model
11+
12+
from .args import DeepseekV3ModelArgs
13+
14+
15+
# Reference: torchtitan/experiments/llama4/model/
16+
class GroupedExperts(nn.Module):
17+
def __init__(
18+
self,
19+
dim: int,
20+
hidden_dim: int,
21+
num_experts: int,
22+
use_grouped_mm: bool,
23+
):
24+
super().__init__()
25+
self.num_experts = num_experts
26+
self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
27+
self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
28+
self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
29+
self.use_grouped_mm = use_grouped_mm
30+
31+
def forward(
32+
self,
33+
x: torch.Tensor,
34+
num_local_tokens_per_expert: torch.Tensor | list[int] | None = None,
35+
) -> torch.Tensor:
36+
# TODO: keeping this for loop implementation for comparison
37+
# and readability, will remove later
38+
if not self.use_grouped_mm:
39+
if num_local_tokens_per_expert is not None:
40+
# a tuple of tensors indexed by experts
41+
# each with shape (tokens_per_expert(varying), dim)
42+
x = torch.split(
43+
x,
44+
split_size_or_sections=num_local_tokens_per_expert,
45+
dim=0,
46+
)
47+
out_experts_splits = []
48+
for expert_idx, x_expert in enumerate(x):
49+
w1, w2, w3 = (
50+
self.w1[expert_idx],
51+
self.w2[expert_idx],
52+
self.w3[expert_idx],
53+
)
54+
h = F.silu(torch.matmul(x_expert, w1))
55+
h = h * torch.matmul(x_expert, w3)
56+
h = torch.matmul(h, w2)
57+
# h shape (tokens_per_expert(varying), dim)
58+
out_experts_splits.append(h)
59+
out = torch.cat(out_experts_splits, dim=0)
60+
else:
61+
# x shape (num_experts, tokens_per_expert, dim)
62+
h = F.silu(torch.bmm(x, self.w1))
63+
h = h * torch.bmm(x, self.w3)
64+
# out shape (num_experts, tokens_per_expert, dim)
65+
out = torch.bmm(h, self.w2)
66+
67+
return out
68+
69+
# grouped mm implementation
70+
if num_local_tokens_per_expert is not None:
71+
# https://github.com/pytorch/pytorch/pull/150374
72+
# NOTE: torch._gouped_mm requires bf16 dtypes
73+
# and shapes to be multiple of 8
74+
offsets = torch.cumsum(
75+
num_local_tokens_per_expert, dim=0, dtype=torch.int32
76+
)
77+
# grouped mm between a 2D tensor and a 3D tensor
78+
assert x.dim() == 2
79+
else:
80+
offsets = None
81+
# fall back to regular bmm between 3D tensors
82+
assert x.dim() == 3
83+
84+
assert (
85+
x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16
86+
), "torch._grouped_mm only supports bf16 dtypes"
87+
h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
88+
h = h * torch._grouped_mm(x, self.w3, offs=offsets)
89+
out = torch._grouped_mm(h, self.w2, offs=offsets)
90+
91+
return out
92+
93+
94+
class TokenChoiceTopKRouter(nn.Module):
95+
"""This class implements token-choice routing. In token-choice top-K routing, each token is
96+
routed to top K experts based on the router scores.
97+
98+
Args:
99+
gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
100+
num_experts (int): Number of experts in each moe layer.
101+
top_k (int): Number of experts each token will be routed to in token-choice routing.
102+
use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
103+
"""
104+
105+
def __init__(
106+
self,
107+
num_experts: int,
108+
top_k: int,
109+
use_sigmoid: bool = False,
110+
route_sclaing_factor: float = 1.0,
111+
):
112+
super().__init__()
113+
114+
self.num_experts = num_experts
115+
self.top_k = top_k
116+
self.use_sigmoid = use_sigmoid
117+
self.route_sclaing_factor
118+
119+
self.weight = nn.Parameter(
120+
torch.empty((self.n_routed_experts, self.gating_dim))
121+
)
122+
# TODO: is this needed? This is not "Complementary Sequence-Wise Auxiliary Loss"
123+
# self.e_score_correction_bias = nn.Parameter(torch.rand((self.num_experts)))
124+
125+
def forward(
126+
self, x: torch.Tensor, expert_bias: torch.Tensor = None
127+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128+
"""
129+
TODO: We haven't implement the group-based routing (node limit routing) yet, and currently EP is not supporting node limit routing yet.
130+
Args:
131+
x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
132+
133+
Returns:
134+
routed_input (torch.Tensor):
135+
Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``.
136+
token_indices (torch.Tensor):
137+
Token indices for routed_input with shape ``(bs*slen*top_k,)``.
138+
num_local_tokens_per_expert (torch.Tensor):
139+
Number of tokens assigned to each expert with shape ``(num_experts,)``.
140+
"""
141+
# scores shape (bs*slen, num_experts)
142+
scores = F.linear(x.type, self.weight, None)
143+
144+
# By default, sigmoid or softmax is performed in float32 to avoid loss explosion
145+
if self.use_sigmoid:
146+
scores = torch.sigmoid(scores.to(torch.float32))
147+
else:
148+
scores = F.softmax(scores.to(torch.float32), dim=1)
149+
150+
# top scores shape (bs*slen, top_k)
151+
# NOTE: The expert_bias is only used for routing. The gating value
152+
# top_scores is still derived from the original scores.
153+
_, selected_experts_indices = torch.topk(
154+
scores + expert_bias, k=self.top_k, dim=1
155+
)
156+
top_scores = scores.gather(dim=1, index=selected_experts_indices)
157+
158+
# group tokens together by expert indices from 0 to num_experts and pass that to experts forward
159+
num_local_tokens_per_expert = torch.histc(
160+
selected_experts_indices.view(-1),
161+
bins=self.num_experts,
162+
min=0,
163+
max=self.num_experts,
164+
)
165+
# token_indices_experts_sorted shape (bs*slen*top_k,)
166+
token_indices_experts_sorted = torch.argsort(
167+
selected_experts_indices.view(-1), stable=True
168+
)
169+
top_scores = top_scores.view(-1)[token_indices_experts_sorted]
170+
token_indices_experts_sorted = token_indices_experts_sorted // self.top_k
171+
172+
top_scores = (
173+
top_scores * self.route_sclaing_factor
174+
) # must multiply the scaling factor
175+
176+
return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert
177+
178+
179+
class MoE(nn.Module):
180+
def __init__(self, model_args: DeepseekV3ModelArgs):
181+
182+
# n_routed_experts: int = 64
183+
# n_shared_experts: int = 2
184+
# n_activated_experts: int = 6
185+
# score_func: Literal["softmax", "sigmoid"] = "softmax"
186+
# route_scale: float = 1.0
187+
188+
super().__init__()
189+
dim = model_args.dim
190+
191+
num_experts = model_args.n_routed_experts
192+
hidden_dim = model_args.moe_inter_dim
193+
top_k = model_args.n_activated_experts
194+
route_scaling_factor = model_args.route_scale
195+
196+
self.use_grouped_mm = model_args.use_grouped_mm
197+
self.experts = GroupedExperts(
198+
dim=dim,
199+
hidden_dim=hidden_dim,
200+
num_experts=num_experts,
201+
use_grouped_mm=self.use_grouped_mm,
202+
)
203+
self.router = TokenChoiceTopKRouter(
204+
num_experts=num_experts,
205+
top_k=top_k,
206+
use_sigmoid=model_args.score_func == "sigmoid",
207+
route_sclaing_factor=route_scaling_factor,
208+
)
209+
self.shared_expert = (
210+
# Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py#L517
211+
GroupedExperts(
212+
dim=dim,
213+
hidden_dim=hidden_dim * model_args.n_shared_experts,
214+
num_experts=1,
215+
use_grouped_mm=self.use_grouped_mm,
216+
)
217+
if model_args.n_shared_experts > 0
218+
else None
219+
)
220+
221+
# auxiliary-loss-free load balancing
222+
self.load_balance_coeff = model_args.load_balance_coeff
223+
# the fields below are defined even when load_balance_coeff is None
224+
# to make initialization and checkpointing code simpler
225+
self.register_buffer(
226+
"expert_bias",
227+
torch.zeros(num_experts, dtype=torch.float32),
228+
persistent=True,
229+
)
230+
self.register_buffer(
231+
"tokens_per_expert",
232+
torch.zeros(num_experts, dtype=torch.float32),
233+
persistent=True,
234+
)
235+
236+
# NOTE: forward hook, forward pre hook, or backward pre hook
237+
# would conflict with activation checkpointing
238+
if self.load_balance_coeff is not None and self.load_balance_coeff > 0:
239+
self.register_full_backward_hook(self._update_expert_bias)
240+
241+
# TODO: double check the bias update logic. It aligns with the paper.
242+
def _update_expert_bias(self, *_):
243+
expert_bias_delta = self.load_balance_coeff * torch.sign(
244+
self.tokens_per_expert.mean() - self.tokens_per_expert
245+
)
246+
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
247+
self.expert_bias.add_(expert_bias_delta)
248+
249+
self.tokens_per_expert.zero_()
250+
251+
def forward(self, x: torch.Tensor) -> torch.Tensor:
252+
"""
253+
Args:
254+
x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.
255+
256+
Returns:
257+
out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
258+
"""
259+
bs, slen, dim = x.shape
260+
261+
# top_scores and selected_indices shape (bs*slen*top_k,)
262+
# num_local_tokens_per_expert shape (num_experts,)
263+
(
264+
top_scores,
265+
token_indices,
266+
num_local_tokens_per_expert,
267+
) = self.router(x.reshape(bs * slen, dim), self.expert_bias)
268+
269+
# will be used to update the expert bias for load balancing
270+
self.tokens_per_expert += num_local_tokens_per_expert
271+
272+
# shape (bs*slen*top_k, dim)
273+
token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
274+
275+
# shape (bs*slen*top_k, dim)
276+
routed_input = torch.gather(
277+
x.view(-1, dim),
278+
dim=0,
279+
index=token_indices,
280+
)
281+
282+
if self.use_grouped_mm:
283+
# NOTE: In order to use torch._grouped_mm, we need to make sure
284+
# the number of tokens each expert gets is a multiple of 16.
285+
# The following kernel helps achieve this via padding, without
286+
# incurring synchronization between device and host.
287+
from torchtitan.experiments.kernels.moe.indices import (
288+
generate_permute_indices,
289+
)
290+
291+
ALIGN_SIZE_M = 16
292+
293+
with torch.no_grad():
294+
(
295+
permuted_indices,
296+
num_local_tokens_per_expert,
297+
_,
298+
) = generate_permute_indices(
299+
num_local_tokens_per_expert,
300+
self.experts.num_experts,
301+
1,
302+
ALIGN_SIZE_M,
303+
)
304+
token_indices = torch.vstack(
305+
(token_indices, token_indices.new_zeros((dim)))
306+
)
307+
token_indices = token_indices[permuted_indices, :]
308+
routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim))))
309+
routed_input = routed_input[permuted_indices, :]
310+
else:
311+
# NOTE: this would incur a synchronization between device and host
312+
num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()
313+
314+
# shape (bs*slen*top_k, dim)
315+
routed_output = self.experts(routed_input, num_local_tokens_per_expert)
316+
routed_output = routed_output * top_scores.unsqueeze(-1)
317+
318+
# shared expert
319+
if self.shared_expert is not None:
320+
out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
321+
bs * slen, dim
322+
)
323+
else:
324+
out = torch.zeros_like(x.reshape(bs * slen, dim))
325+
326+
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
327+
out = out.reshape(bs, slen, dim)
328+
return out
329+
330+
def init_weights(
331+
self,
332+
init_std: float,
333+
buffer_device: torch.device,
334+
):
335+
self.experts.init_weights(init_std)
336+
self.router.init_weights(init_std)
337+
if self.shared_expert is not None:
338+
self.shared_expert.init_weights(init_std)
339+
340+
with torch.device(buffer_device):
341+
self.expert_bias = torch.zeros(
342+
self.experts.num_experts, dtype=torch.float32
343+
)
344+
self.tokens_per_expert = torch.zeros(
345+
self.experts.num_experts, dtype=torch.float32
346+
)

0 commit comments

Comments
 (0)