@@ -39,6 +39,7 @@ def moe_mmk(
39
39
# Offsets and masks
40
40
offs_m ,
41
41
offs_n ,
42
+ offs_bn ,
42
43
mask_m ,
43
44
# Block size for block-wise quantization
44
45
group_n : tl .constexpr ,
@@ -64,7 +65,7 @@ def moe_mmk(
64
65
# block-wise
65
66
if group_k > 0 and group_n > 0 :
66
67
a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
67
- offs_bsn = offs_n // group_n
68
+ offs_bsn = offs_bn // group_n
68
69
b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn
69
70
70
71
# per act token
@@ -142,7 +143,7 @@ def moe_mmk(
142
143
elif use_w8a8 :
143
144
if group_k > 0 and group_n > 0 :
144
145
accumulator = accumulator .to (compute_type )
145
- elif True or not per_act_token_quant :
146
+ else : #if True or not per_act_token_quant:
146
147
accumulator = (accumulator * a_scale * b_scale ).to (compute_type )
147
148
else :
148
149
accumulator = accumulator .to (compute_type )
@@ -178,6 +179,8 @@ def expert_triton_kernel(
178
179
stride_bse ,
179
180
stride_bsk ,
180
181
stride_bsn ,
182
+ # offsets
183
+ offs_bn ,
181
184
# Blockwise quantization data
182
185
group_n ,
183
186
group_k ,
@@ -222,6 +225,7 @@ def expert_triton_kernel(
222
225
# Offsets and masks
223
226
offs_m ,
224
227
offs_n ,
228
+ offs_bn ,
225
229
mask_m ,
226
230
# Block size for block-wise quantization
227
231
group_n ,
@@ -315,12 +319,15 @@ def batched_triton_kernel(
315
319
c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
316
320
cta_n_start * stride_cn )
317
321
322
+ offs_bn = (pid_n * BLOCK_N + tl .arange (0 , BLOCK_N ).to (tl .int64 )) % N
323
+
318
324
if use_fp8_w8a8 :
319
325
a_scale_ptr = a_scale_ptr + expert_id * stride_ase
320
326
b_scale_ptr = b_scale_ptr + expert_id * stride_bse
321
327
# block-wise
322
328
if group_k > 0 and group_n > 0 :
323
329
a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
330
+ #b_scale_ptr = b_scale_ptr + offs_bn * stride_bsn
324
331
# b group advancement?
325
332
elif False and per_act_token_quant :
326
333
a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
@@ -351,6 +358,8 @@ def batched_triton_kernel(
351
358
stride_bse ,
352
359
stride_bsk ,
353
360
stride_bsn ,
361
+ # offsets
362
+ offs_bn ,
354
363
# Blockwise quantization data
355
364
group_n ,
356
365
group_k ,
@@ -404,12 +413,13 @@ def invoke_moe_batched_triton_kernel(
404
413
if B_scale is not None :
405
414
if B_scale .ndim == 1 :
406
415
stride_bse = 1
407
- stride_bsn = 0
408
416
stride_bsk = 0
417
+ stride_bsn = 0
409
418
else :
410
419
stride_bse = B_scale .stride (0 )
411
- stride_bsn = B_scale .stride (1 )
412
420
stride_bsk = B_scale .stride (2 )
421
+ stride_bsn = B_scale .stride (1 )
422
+
413
423
else :
414
424
stride_bse = 0
415
425
stride_bsk = 0
0 commit comments