@@ -63,17 +63,15 @@ def moe_mmk(
63
63
if use_w8a8 :
64
64
# block-wise
65
65
if group_k > 0 and group_n > 0 :
66
- a_scale_ptrs = a_scale_ptr + (offs_m * stride_asm
67
- ) #+ (expert_id * stride_ase)
66
+ a_scale_ptrs = a_scale_ptr + offs_m * stride_asm #+ (expert_id * stride_ase)
68
67
offs_bsn = offs_n // group_n
69
- b_scale_ptrs = (b_scale_ptr +
70
- offs_bsn * stride_bsn ) + expert_id * stride_bse
68
+ b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
69
+ offs_bsn * stride_bsn )
71
70
72
71
# channel-wise
73
72
elif per_channel_quant :
74
73
# TODO: probably not correct
75
- b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n [
76
- None , :] * stride_bsn
74
+ b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n [None , :] * stride_bsn
77
75
b_scale = tl .load (b_scale_ptrs )
78
76
# Load per-token scale for activations
79
77
# + (expert_id * stride_ase)??
@@ -300,16 +298,14 @@ def batched_triton_kernel(
300
298
cta_n_start * stride_cn )
301
299
302
300
if use_fp8_w8a8 :
301
+ a_scale_ptr = a_scale_ptr + (expert_id * stride_ase )
303
302
# block-wise
304
- if (group_k > 0 and group_n > 0 ) or per_channel_quant :
305
- a_scale_ptr = a_scale_ptr + (expert_id *
306
- stride_ase ) + cta_m_start * stride_asm
307
- #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
308
- # (?) b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn
309
- # channel-wise or tensor-wise
310
- else :
311
- a_scale_ptr = a_scale_ptr + (expert_id * stride_ase )
312
- #b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
303
+ if group_k > 0 and group_n > 0 :
304
+ a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
305
+ b_scale_ptr = b_scale_ptr + (expert_id * stride_bse )
306
+ elif per_channel_quant :
307
+ a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
308
+ b_scale_ptr = b_scale_ptr + (expert_id * stride_bse ) + cta_n_start * stride_bsn
313
309
314
310
expert_triton_kernel (
315
311
a_ptr ,
@@ -532,6 +528,7 @@ def prepare(
532
528
self .max_num_tokens ,
533
529
hidden_dim )
534
530
531
+ # empty?
535
532
b_a1_scale = torch .zeros (scale_shape ,
536
533
dtype = torch .float32 ,
537
534
device = a1 .device )
0 commit comments