Commit dc7fd23

Revert "[kernels] move generate_permute_indices to using exact sizes per expert needed, instead of max_len" (#1340)
Reverts #1254 for the reasons @ngimel and @kwen2501 mentioned in the original PR. That change was introduced because of issue #1237, but the right way to fix that issue is to fix my incorrect permutation of both `routed_input` and `token_indices`; instead, we should permute `routed_input` on the way into the experts and unpermute the output on the way out. https://github.com/pytorch/torchtitan/blob/adae6b6fd850054d6e1c04ecc2a4cf84abc68056/torchtitan/experiments/llama4/model/moe.py#L300-L305 I will send out a PR soon to fix the error in the Llama 4 MoE implementation.
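For context, here is a minimal sketch of the permute/unpermute flow the follow-up fix aims for, as opposed to permuting both `routed_input` and `token_indices` (hypothetical shapes, a `randperm` stand-in for the kernel's index output, and a dummy expert computation; not the actual moe.py code):

import torch

num_tokens, dim = 8, 4
routed_input = torch.randn(num_tokens, dim)
permuted_indices = torch.randperm(num_tokens)  # stand-in for the kernel's index output

# permute activations into expert-grouped order
grouped = routed_input[permuted_indices]
routed_output = grouped * 2.0  # stand-in for the expert computation

# unpermute: scatter each row back to its original token position
out = torch.empty_like(routed_output)
out[permuted_indices] = routed_output
assert torch.equal(out, routed_input * 2.0)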
1 parent adae6b6 · commit dc7fd23

1 file changed (+28, −28)

torchtitan/experiments/kernels/moe/indices.py

@@ -72,13 +72,13 @@ def fill_indices_wrapper(
     write_offsets: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
-    total_size: int,
+    max_len: int,
     block_size: int = 128,
-    max_blocks: int = 1024,
+    max_blocks: int = 1024,  # cap on total number of blocks to launch
 ):
-    # Allocate exact size needed instead of max_len
+    # preallocate output
     permuted_indices = torch.full(
-        (total_size,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
+        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
     )
 
     # write offsets is per local expert...
@@ -99,37 +99,39 @@ def fill_indices_wrapper(
     return permuted_indices
 
 
-# used for reference testing only
-
-
+# reference
 def fill_indices_cpu(
     tokens_per_expert_group: torch.Tensor,
     start_index_values: torch.Tensor,
     write_offsets: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
-    total_size: int,  # Changed from max_len to actual required size
+    max_len: int,
 ):
-    # Allocate exact size needed
+    # We need to preallocate the output - we ignore device and force it on cpu
+    # device = tokens_per_expert_group.device
     permuted_indices = torch.full(
-        (total_size,),
+        (max_len,),
         -1,
         dtype=torch.int32,
-    )
-
+    )  # device=device)
     # Fill the permuted indices
+    # For each local expert
     for e in range(experts_per_rank):
         write_start = write_offsets[e].item()
+        # For each remote rank
         for r in range(num_ranks):
             i = r * experts_per_rank + e
             start_index = start_index_values[i].item()
             length = tokens_per_expert_group[i].item()
+            # Fill in the indices
             if length > 0:
-                end_idx = min(write_start + length, total_size)
+                end_idx = min(write_start + length, max_len)
                 permuted_indices[write_start:end_idx] = torch.arange(
                     start_index,
                     start_index + (end_idx - write_start),
                     dtype=torch.int32,
+                    # device=device,
                 )
             write_start += length
     return permuted_indices
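As a side note, this hunk restores the fixed-length output contract: the result always has max_len slots, and any slot that receives no token stays -1. A minimal sketch of a single write from the loop above, with made-up offset values:

import torch

max_len = 8
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
write_start, start_index, length = 0, 2, 3  # one (expert, rank) slice; values made up
end_idx = min(write_start + length, max_len)
permuted_indices[write_start:end_idx] = torch.arange(
    start_index, start_index + (end_idx - write_start), dtype=torch.int32
)
print(permuted_indices)  # tensor([ 2,  3,  4, -1, -1, -1, -1, -1], dtype=torch.int32)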
@@ -139,22 +141,24 @@ def generate_permute_indices(
     tokens_per_expert_group: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
+    max_len: int,
     alignment: int,
     use_cpu: bool = False,
 ):
     """
     Prepare permutation indices and the number of tokens for each expert.
-    Modified version that returns a tensor of size sum(m_sizes) instead of max_len.
 
     Args:
         tokens_per_expert_group: number of tokens for each expert from all ranks.
         experts_per_rank: number of experts per rank.
         num_ranks: number of ranks.
+        max_len: maximum length of the output index vector.
         alignment: alignment for each returned element in `m_sizes` and padding min for zero token experts.
         use_cpu: whether to use CPU implementation.
 
+
     Returns:
-        permuted_indices: Tensor of indices with size sum(m_sizes), that map original token order to the expert-grouped order.
+        permuted_indices: Tensor of indices that map original token order to the expert-grouped order.
         m_sizes: aligned number of tokens for each expert (padded to alignment boundary).
         m_offsets: Cumulative sum of m_sizes. The exclusive ending position for each expert's tokens.
 
@@ -165,7 +169,7 @@ def generate_permute_indices(
     | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 |
     """
 
-    # prefix sum to get start index of each expert
+    # prefix sum to get start index of each expert (parallel scan kernel in future?)
     start_index_values = (
         torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
     )
@@ -182,12 +186,10 @@ def generate_permute_indices(
     )
 
     # additional prefix sum to get write offset of each expert in permuted_indices
+    # write offsets is per local expert, not global
     m_offsets = torch.cumsum(m_sizes, 0)
     write_offsets = m_offsets - m_sizes
 
-    # Calculate the actual total size needed
-    total_size = m_offsets[-1]
-
     # Select the implementation to use
     if use_cpu:
         permuted_indices = fill_indices_cpu(
@@ -196,16 +198,16 @@ def generate_permute_indices(
             write_offsets,
             experts_per_rank,
             num_ranks,
-            total_size,
+            max_len,
         )
-    else:  # gpu
+    else:
         permuted_indices = fill_indices_wrapper(
             tokens_per_expert_group,
             start_index_values,
             write_offsets,
             experts_per_rank,
             num_ranks,
-            total_size,
+            max_len,
         )
 
     return permuted_indices, m_sizes, m_offsets.to(torch.int32)
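To make the two exclusive prefix sums concrete, a small worked example (the token counts and the aligned m_sizes are illustrative; the alignment computation itself sits in a part of the file this diff does not touch):

import torch

# hypothetical counts: 2 ranks x 4 local experts, flattened rank-major
tokens_per_expert_group = torch.tensor([1, 2, 1, 0, 2, 1, 2, 3])
# exclusive prefix sum: read offset of each (rank, expert) slice in the input
start_index_values = torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
print(start_index_values)  # tensor([0, 1, 3, 4, 4, 6, 7, 9])

# suppose the (elided) alignment step pads each local expert's total up to 4
m_sizes = torch.tensor([4, 4, 4, 4], dtype=torch.int32)
# second exclusive prefix sum: write offset of each local expert in permuted_indices
m_offsets = torch.cumsum(m_sizes, 0)
write_offsets = m_offsets - m_sizes
print(write_offsets)  # tensor([0, 4, 8, 12])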
@@ -225,17 +227,14 @@ def simple_test():
     alignment = 32
     # Use the GPU kernel
     permuted_indices_gpu, m_sizes, _ = generate_permute_indices(
-        tokens_per_expert_group,
-        experts_per_rank,
-        num_ranks,
-        alignment,
-        use_cpu=False,
+        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
     )
     # Use the CPU method
     permuted_indices_cpu, m_sizes, _ = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
         use_cpu=True,
     )
@@ -273,15 +272,16 @@ def test_with_zero_tokens():
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
-        use_cpu=False,
     )
 
     # Use the CPU method
     permuted_indices_cpu, m_sizes_cpu, m_offsets_cpu = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
         use_cpu=True,
     )
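With the revert applied, callers again pass max_len explicitly. A hypothetical end-to-end call mirroring simple_test above (the sizes here are made up; simple_test's actual inputs are outside the hunks shown):

import torch

from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

experts_per_rank, num_ranks, alignment = 4, 2, 32
tokens_per_expert_group = torch.arange(1, 9, dtype=torch.int32)  # made-up counts
max_len = 512  # upper bound on the padded output length
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank,
    num_ranks,
    max_len,
    alignment,
    use_cpu=True,  # CPU reference path; drop for the Triton kernel
)
assert permuted_indices.shape == (max_len,)
assert m_offsets[-1] <= max_len  # aligned total must fit within max_len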
