@@ -72,13 +72,13 @@ def fill_indices_wrapper(
     write_offsets: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
-    total_size: int,
+    max_len: int,
     block_size: int = 128,
-    max_blocks: int = 1024,
+    max_blocks: int = 1024,  # cap on total number of blocks to launch
 ):
-    # Allocate exact size needed instead of max_len
+    # preallocate output
     permuted_indices = torch.full(
-        (total_size,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
+        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
     )

     # write offsets is per local expert...
@@ -99,37 +99,39 @@ def fill_indices_wrapper(
     return permuted_indices


-# used for reference testing only
-
-
+# reference
 def fill_indices_cpu(
     tokens_per_expert_group: torch.Tensor,
     start_index_values: torch.Tensor,
     write_offsets: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
-    total_size: int,  # Changed from max_len to actual required size
+    max_len: int,
 ):
-    # Allocate exact size needed
+    # We need to preallocate the output - we ignore device and force it on cpu
+    # device = tokens_per_expert_group.device
     permuted_indices = torch.full(
-        (total_size,),
+        (max_len,),
         -1,
         dtype=torch.int32,
-    )
-
+    )  # device=device)
     # Fill the permuted indices
+    # For each local expert
     for e in range(experts_per_rank):
         write_start = write_offsets[e].item()
+        # For each remote rank
         for r in range(num_ranks):
             i = r * experts_per_rank + e
             start_index = start_index_values[i].item()
             length = tokens_per_expert_group[i].item()
+            # Fill in the indices
             if length > 0:
-                end_idx = min(write_start + length, total_size)
+                end_idx = min(write_start + length, max_len)
                 permuted_indices[write_start:end_idx] = torch.arange(
                     start_index,
                     start_index + (end_idx - write_start),
                     dtype=torch.int32,
+                    # device=device,
                 )
                 write_start += length
     return permuted_indices
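For intuition about the loop above: each (remote rank r, local expert e) pair contributes a contiguous run of source positions starting at start_index_values[r * experts_per_rank + e], and those runs are written back-to-back into the output starting at write_offsets[e]; slots never written keep the -1 fill up to max_len. A minimal self-contained sketch of that behavior (toy sizes and write offsets chosen here for illustration, not taken from the PR):

import torch

# 2 ranks x 2 local experts; tokens_per_expert_group is rank-major: [r0e0, r0e1, r1e0, r1e1]
tokens_per_expert_group = torch.tensor([2, 1, 3, 2], dtype=torch.int32)
experts_per_rank, num_ranks, max_len = 2, 2, 12

# exclusive prefix sum: where each (rank, expert) chunk starts in the incoming buffer
start_index_values = torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
# assume expert 0 writes at offset 0 and expert 1 at offset 6
# (in the real code these come from the aligned m_sizes / m_offsets)
write_offsets = torch.tensor([0, 6])

permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
for e in range(experts_per_rank):  # local expert
    write_start = write_offsets[e].item()
    for r in range(num_ranks):  # remote rank
        i = r * experts_per_rank + e
        start = start_index_values[i].item()
        length = tokens_per_expert_group[i].item()
        end = min(write_start + length, max_len)
        permuted_indices[write_start:end] = torch.arange(start, start + (end - write_start), dtype=torch.int32)
        write_start += length

print(permuted_indices)
# tensor([ 0,  1,  3,  4,  5, -1,  2,  6,  7, -1, -1, -1], dtype=torch.int32)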
@@ -139,22 +141,24 @@ def generate_permute_indices(
     tokens_per_expert_group: torch.Tensor,
     experts_per_rank: int,
     num_ranks: int,
+    max_len: int,
     alignment: int,
     use_cpu: bool = False,
 ):
     """
     Prepare permutation indices and the number of tokens for each expert.
-    Modified version that returns a tensor of size sum(m_sizes) instead of max_len.

     Args:
         tokens_per_expert_group: number of tokens for each expert from all ranks.
         experts_per_rank: number of experts per rank.
         num_ranks: number of ranks.
+        max_len: maximum length of the output index vector.
         alignment: alignment for each returned element in `m_sizes` and padding min for zero token experts.
         use_cpu: whether to use CPU implementation.

+
     Returns:
-        permuted_indices: Tensor of indices with size sum(m_sizes), that map original token order to the expert-grouped order.
+        permuted_indices: Tensor of indices that map original token order to the expert-grouped order.
         m_sizes: aligned number of tokens for each expert (padded to alignment boundary).
         m_offsets: Cumulative sum of m_sizes. The exclusive ending position for each expert's tokens.

@@ -165,7 +169,7 @@ def generate_permute_indices(
     | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 |
     """

-    # prefix sum to get start index of each expert
+    # prefix sum to get start index of each expert (parallel scan kernel in future?)
     start_index_values = (
         torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
     )
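The m_sizes / alignment computation sits between these two hunks and is untouched by this diff. As a rough sketch of what the docstring's "aligned number of tokens for each expert (padded to alignment boundary)" and the write-offset prefix sum amount to (the helper name and rounding details below are assumptions, not the file's actual code):

import torch

def aligned_sizes_sketch(tokens_per_expert_group, experts_per_rank, num_ranks, alignment):
    # total tokens per local expert, summed over ranks
    totals = tokens_per_expert_group.view(num_ranks, experts_per_rank).sum(dim=0)
    # round up to the alignment boundary; zero-token experts still get one block
    padded = ((totals + alignment - 1) // alignment) * alignment
    return torch.clamp(padded, min=alignment).to(torch.int32)

tokens = torch.tensor([4, 0, 2, 3], dtype=torch.int32)  # 2 ranks x 2 local experts
m_sizes = aligned_sizes_sketch(tokens, experts_per_rank=2, num_ranks=2, alignment=8)
m_offsets = torch.cumsum(m_sizes, 0)
write_offsets = m_offsets - m_sizes
print(m_sizes, write_offsets)  # e.g. aligned sizes [8, 8] and write offsets [0, 8]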
@@ -182,12 +186,10 @@ def generate_permute_indices(
     )

     # additional prefix sum to get write offset of each expert in permuted_indices
+    # write offsets is per local expert, not global
     m_offsets = torch.cumsum(m_sizes, 0)
     write_offsets = m_offsets - m_sizes

-    # Calculate the actual total size needed
-    total_size = m_offsets[-1]
-
     # Select the implementation to use
     if use_cpu:
         permuted_indices = fill_indices_cpu(
@@ -196,16 +198,16 @@ def generate_permute_indices(
             write_offsets,
             experts_per_rank,
             num_ranks,
-            total_size,
+            max_len,
         )
-    else:  # gpu
+    else:
         permuted_indices = fill_indices_wrapper(
             tokens_per_expert_group,
             start_index_values,
             write_offsets,
             experts_per_rank,
             num_ranks,
-            total_size,
+            max_len,
         )

     return permuted_indices, m_sizes, m_offsets.to(torch.int32)
@@ -225,17 +227,14 @@ def simple_test():
     alignment = 32
     # Use the GPU kernel
     permuted_indices_gpu, m_sizes, _ = generate_permute_indices(
-        tokens_per_expert_group,
-        experts_per_rank,
-        num_ranks,
-        alignment,
-        use_cpu=False,
+        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
     )
     # Use the CPU method
     permuted_indices_cpu, m_sizes, _ = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
         use_cpu=True,
     )
@@ -273,15 +272,16 @@ def test_with_zero_tokens():
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
-        use_cpu=False,
     )

     # Use the CPU method
     permuted_indices_cpu, m_sizes_cpu, m_offsets_cpu = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
+        max_len,
         alignment,
         use_cpu=True,
     )
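Since this change goes back to a fixed max_len output, downstream code again has to treat the -1 tail as padding. An illustrative (not from this PR) way to consume the outputs when gathering tokens into the expert-grouped layout:

import torch

# toy inputs (hypothetical): 8 tokens of dim 4, and a padded index vector shaped the way
# generate_permute_indices returns it, with -1 marking unused slots
tokens = torch.randn(8, 4)
permuted_indices = torch.tensor([3, 7, 1, -1, -1, -1, 0, 2, 5, -1, -1, -1], dtype=torch.int32)

valid = permuted_indices >= 0
grouped = torch.zeros(permuted_indices.numel(), tokens.shape[1], dtype=tokens.dtype)
grouped[valid] = tokens[permuted_indices[valid].long()]  # padded rows stay zero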