Commit 0275957

test

1 parent 488ef61 commit 0275957
7 files changed: +106 -78 lines changed

torchtitan/experiments/kernels/moe/indices.py

Lines changed: 28 additions & 28 deletions
@@ -72,13 +72,13 @@ def fill_indices_wrapper(
     write_offsets: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
-    total_size: int,
+    max_len: int,
     block_size: int = 128,
-    max_blocks: int = 1024,
+    max_blocks: int = 1024,  # cap on total number of blocks to launch
 ):
-    # Allocate exact size needed instead of max_len
+    # preallocate output
     permuted_indices = torch.full(
-        (total_size,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
+        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
     )

     # write offsets is per local expert...
@@ -99,37 +99,39 @@ def fill_indices_wrapper(
     return permuted_indices


-# used for reference testing only
-
-
+# reference
 def fill_indices_cpu(
     tokens_per_expert_group: torch.Tensor,
     start_index_values: torch.Tensor,
     write_offsets: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
-    total_size: int,  # Changed from max_len to actual required size
+    max_len: int,
 ):
-    # Allocate exact size needed
+    # We need to preallocate the output - we ignore device and force it on cpu
+    # device = tokens_per_expert_group.device
     permuted_indices = torch.full(
-        (total_size,),
+        (max_len,),
        -1,
        dtype=torch.int32,
-    )
-
+    )  # device=device)
     # Fill the permuted indices
+    # For each local expert
     for e in range(experts_per_rank):
         write_start = write_offsets[e].item()
+        # For each remote rank
         for r in range(num_ranks):
             i = r * experts_per_rank + e
             start_index = start_index_values[i].item()
             length = tokens_per_expert_group[i].item()
+            # Fill in the indices
             if length > 0:
-                end_idx = min(write_start + length, total_size)
+                end_idx = min(write_start + length, max_len)
                 permuted_indices[write_start:end_idx] = torch.arange(
                     start_index,
                     start_index + (end_idx - write_start),
                     dtype=torch.int32,
+                    # device=device,
                 )
             write_start += length
     return permuted_indices
@@ -139,22 +141,24 @@ def generate_permute_indices(
     tokens_per_expert_group: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
+    max_len: int,
     alignment: int,
     use_cpu: bool = False,
 ):
     """
     Prepare permutation indices and the number of tokens for each expert.
-    Modified version that returns a tensor of size sum(m_sizes) instead of max_len.

     Args:
         tokens_per_expert_group: number of tokens for each expert from all ranks.
         experts_per_rank: number of experts per rank.
         num_ranks: number of ranks.
+        max_len: maximum length of the output index vector.
         alignment: alignment for each returned element in `m_sizes` and padding min for zero token experts.
         use_cpu: whether to use CPU implementation.

+
     Returns:
-        permuted_indices: Tensor of indices with size sum(m_sizes), that map original token order to the expert-grouped order.
+        permuted_indices: Tensor of indices that map original token order to the expert-grouped order.
         m_sizes: aligned number of tokens for each expert (padded to alignment boundary).
         m_offsets: Cumulative sum of m_sizes. The exclusive ending position for each expert's tokens.

@@ -165,7 +169,7 @@ def generate_permute_indices(
     | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 |
     """

-    # prefix sum to get start index of each expert
+    # prefix sum to get start index of each expert (parallel scan kernel in future?)
     start_index_values = (
         torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
     )
@@ -182,12 +186,10 @@ def generate_permute_indices(
     )

     # additional prefix sum to get write offset of each expert in permuted_indices
+    # write offsets is per local expert, not global
     m_offsets = torch.cumsum(m_sizes, 0)
     write_offsets = m_offsets - m_sizes

-    # Calculate the actual total size needed
-    total_size = m_offsets[-1]
-
     # Select the implementation to use
     if use_cpu:
         permuted_indices = fill_indices_cpu(
@@ -196,16 +198,16 @@ def generate_permute_indices(
             write_offsets,
             experts_per_rank,
             num_ranks,
-            total_size,
+            max_len,
         )
-    else:  # gpu
+    else:
         permuted_indices = fill_indices_wrapper(
             tokens_per_expert_group,
             start_index_values,
             write_offsets,
             experts_per_rank,
             num_ranks,
-            total_size,
+            max_len,
         )

     return permuted_indices, m_sizes, m_offsets.to(torch.int32)
@@ -225,17 +227,14 @@ def simple_test():
     alignment = 32
     # Use the GPU kernel
     permuted_indices_gpu, m_sizes, _ = generate_permute_indices(
-        tokens_per_expert_group,
-        experts_per_rank,
-        num_ranks,
-        alignment,
-        use_cpu=False,
+        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
     )
     # Use the CPU method
     permuted_indices_cpu, m_sizes, _ = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
         use_cpu=True,
     )
@@ -273,15 +272,16 @@ def test_with_zero_tokens():
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
-        use_cpu=False,
     )

     # Use the CPU method
     permuted_indices_cpu, m_sizes_cpu, m_offsets_cpu = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
         use_cpu=True,
     )
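
With this change the caller supplies max_len again and the output is preallocated to that length (padded with -1) instead of being sized to sum(m_sizes). A minimal usage sketch of the new signature; the module path matches the file above, but the tensor values and the max_len bound are illustrative assumptions, not taken from this commit's tests:

import torch

from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

# Illustrative setup: 2 ranks x 2 local experts = 4 (rank, expert) pairs.
experts_per_rank = 2
num_ranks = 2
alignment = 32
tokens_per_expert_group = torch.tensor([4, 2, 1, 3], dtype=torch.int32)

# The output buffer is preallocated to max_len and padded with -1, so max_len
# must cover every token plus up to `alignment` of padding per local expert.
max_len = int(tokens_per_expert_group.sum()) + experts_per_rank * alignment

permuted_indices, m_sizes, m_offsets = generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank,
    num_ranks,
    max_len,
    alignment,
    use_cpu=True,  # CPU reference path; omit to run the Triton kernel
)
print(permuted_indices.shape, m_sizes, m_offsets)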

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -32,12 +32,12 @@
         dim=256,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=3,
-        n_dense_layers=1,
+        n_layers=1,
+        n_dense_layers=0,  # no FFN layer, all MoE layers
         n_heads=16,
-        n_routed_experts=8,
-        n_shared_experts=2,
-        n_activated_experts=3,
+        n_routed_experts=2,  # hang only happens when n_routed_experts > n_activated_experts
+        n_shared_experts=1,
+        n_activated_experts=1,
         route_scale=1.0,
         q_lora_rank=0,
         kv_lora_rank=512,

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 2 additions & 2 deletions
@@ -75,8 +75,8 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
     route_scale: float = 1.0
-    use_grouped_mm: bool = False
-    load_balance_coeff: float | None = 1e-3
+    use_grouped_mm: bool = True
+    load_balance_coeff: float = 1e-3
     # Multi-Head Latent Attention (MLA)
     q_lora_rank: int = 0
     kv_lora_rank: int = 512

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 5 additions & 37 deletions
@@ -14,7 +14,7 @@
 from torchtitan.protocols.train_spec import ModelProtocol

 from .args import DeepSeekV3ModelArgs
-from .moe import MoE
+from .moe import FeedForward, MoE


 # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
@@ -260,42 +260,6 @@ def init_weights(self, init_std: float):
         self.q_norm.reset_parameters()


-class FeedForward(nn.Module):
-    """
-    FeedForward module
-
-    Args:
-        dim (int): Input dimension.
-        hidden_dim (int): Hidden dimension of the feedforward layer.
-        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-        ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
-
-    Attributes:
-        w1 (Linear): Linear transformation for the first layer.
-        w2 (Linear): Linear transformation for the second layer.
-        w3 (Linear): Linear transformation for the third layer.
-
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-    ):
-        super().__init__()
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-    def init_weights(self, init_std: float = 0.02):
-        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
-        for linear in (self.w2, self.w3):
-            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
-
-
 class TransformerBlock(nn.Module):
     """
     Transformer block with attention and feed-forward layers.
@@ -316,6 +280,7 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):

         # TODO: Need to revisit the weight initialization for the TransformerBlock
         self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5
+        self.layer_id = layer_id

     def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         """
@@ -330,8 +295,10 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
         """
         x = x + self.attention(self.attention_norm(x), freqs_cis)
         if self.moe_enabled:
+            print(f"In TransformerBlock {self.layer_id}: MoE is enabled")
             x = x + self.moe(self.ffn_norm(x))
         else:
+            print(f"In TransformerBlock {self.layer_id}: FFN is enabled")
             x = x + self.feed_forward(self.ffn_norm(x))
         return x

@@ -360,6 +327,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):

         self.layers = torch.nn.ModuleDict()
         for layer_id in range(model_args.n_layers):
+            print(f"Create layer: {layer_id}")
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

         self.norm = nn.RMSNorm(model_args.dim)
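
The FeedForward class deleted here is re-added verbatim in moe.py below, so model.py now imports it from .moe instead of defining it. Its forward pass is the usual SwiGLU form, w2(silu(w1(x)) * w3(x)); a tiny standalone shape check of that formula with made-up dimensions (not taken from the commit):

import torch
import torch.nn.functional as F
from torch import nn

dim, hidden_dim = 256, 1408  # illustrative sizes only

# The three projections used by FeedForward
w1 = nn.Linear(dim, hidden_dim, bias=False)
w2 = nn.Linear(hidden_dim, dim, bias=False)
w3 = nn.Linear(dim, hidden_dim, bias=False)

x = torch.randn(8, dim)
y = w2(F.silu(w1(x)) * w3(x))  # gate with silu(w1(x)), then project back via w2
assert y.shape == (8, dim)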

torchtitan/models/deepseek_v3/model/moe.py

Lines changed: 63 additions & 3 deletions
@@ -11,6 +11,42 @@
 from .args import DeepSeekV3ModelArgs


+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+    ):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float = 0.02):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
 # Reference: torchtitan/experiments/llama4/model/
 class GroupedExperts(nn.Module):
     def __init__(
@@ -212,11 +248,17 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
             GroupedExperts(
                 dim=dim,
                 hidden_dim=hidden_dim * model_args.n_shared_experts,
-                num_experts=1,
+                num_experts=1,  # Here needs to be 1 to make it equivalent to the MLP
                 use_grouped_mm=self.use_grouped_mm,
             )
             if model_args.n_shared_experts > 0
             else None
+            # FeedForward(
+            #     dim=dim,
+            #     hidden_dim=hidden_dim * model_args.n_shared_experts,
+            # )
+            # if model_args.n_shared_experts > 0
+            # else None
         )

         # auxiliary-loss-free load balancing
@@ -266,6 +308,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             num_local_tokens_per_expert,
         ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)

+        print(
+            "In MoE, top_scores shape: ",
+            top_scores.shape,
+            "token_indices: ",
+            token_indices.shape,
+            "num_local_tokens: ",
+            num_local_tokens_per_expert.shape,
+        )
+
         # will be used to update the expert bias for load balancing
         self.tokens_per_expert += num_local_tokens_per_expert

@@ -299,6 +350,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 num_local_tokens_per_expert,
                 self.experts.num_experts,
                 1,
+                token_indices[0] + self.experts.num_experts * ALIGN_SIZE_M,
                 ALIGN_SIZE_M,
             )
             token_indices = torch.vstack(
@@ -311,8 +363,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             # NOTE: this would incur a synchronization between device and host
             num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()

+        print("Num local tokens per expert: ", num_local_tokens_per_expert)
         # shape (bs*slen*top_k, dim)
-        routed_output = self.experts(routed_input, num_local_tokens_per_expert)
+        routed_output = self.experts(
+            routed_input, num_local_tokens_per_expert
+        )  # torch.Size([16384(bsz), 256])
+        print("Routed output shape: ", routed_output.shape)
         routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
             x.dtype
         )
@@ -321,10 +377,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.shared_expert is not None:
             out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                 bs * slen, dim
-            )
+            )  # torch.Size([16384, 256]) None
         else:
             out = torch.zeros_like(x.reshape(bs * slen, dim))

+        print(
+            "Out shape: ", out.shape, out.grad.shape if out.grad is not None else None
+        )
+
         out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
         out = out.reshape(bs, slen, dim)
         return out
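
The substantive change in this forward pass is the extra argument threaded into generate_permute_indices, which becomes its new max_len: token_indices[0] + self.experts.num_experts * ALIGN_SIZE_M. Reading that loosely as "routed token count plus one alignment block per expert" (an assumption, not stated in the commit), a bound of this shape is always large enough for the alignment-padded expert groups, since rounding each expert's group up adds at most ALIGN_SIZE_M. A small self-contained check of that arithmetic with made-up numbers:

ALIGN_SIZE_M = 16  # illustrative; the real value is defined in moe.py
num_experts = 4
tokens_per_expert = [37, 0, 5, 90]  # made-up per-expert token counts

def round_up(n: int, align: int) -> int:
    return ((n + align - 1) // align) * align

# Each expert's slot is its token count rounded up to ALIGN_SIZE_M,
# with at least one alignment block reserved for zero-token experts.
aligned_total = sum(
    max(round_up(t, ALIGN_SIZE_M), ALIGN_SIZE_M) for t in tokens_per_expert
)
bound = sum(tokens_per_expert) + num_experts * ALIGN_SIZE_M
assert aligned_total <= bound, (aligned_total, bound)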
