@@ -50,7 +50,7 @@ def moe_mmk(
50
50
compute_type : tl .constexpr ,
51
51
use_w8a8 : tl .constexpr ,
52
52
use_w8a16 : tl .constexpr ,
53
- per_channel_quant : tl .constexpr ,
53
+ per_act_token_quant : tl .constexpr ,
54
54
):
55
55
56
56
offs_k = tl .arange (0 , BLOCK_K )
@@ -63,25 +63,33 @@ def moe_mmk(
63
63
if use_w8a8 :
64
64
# block-wise
65
65
if group_k > 0 and group_n > 0 :
66
- a_scale_ptrs = a_scale_ptr + offs_m * stride_asm #+ (expert_id * stride_ase)
66
+ a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
67
67
offs_bsn = offs_n // group_n
68
- b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
69
- offs_bsn * stride_bsn )
68
+ b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn
70
69
71
- # channel-wise
72
- elif per_channel_quant :
73
- # TODO: probably not correct
74
- b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n [None , :] * stride_bsn
70
+ # per act token
71
+ elif per_act_token_quant :
72
+ # Load per-token scale for activations
73
+ a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
74
+ a_scale = tl .load (a_scale_ptrs , mask = mask_m , other = 0.0 )[:,None ]
75
+
76
+ b_scale_ptrs = b_scale_ptr + offs_n [None , :] * stride_bsn
75
77
b_scale = tl .load (b_scale_ptrs )
78
+
79
+
76
80
# Load per-token scale for activations
77
81
# + (expert_id * stride_ase)??
78
- a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
79
- a_scale = tl .load (a_scale_ptrs , mask = mask_m , other = 0.0 )[:, None ]
82
+ #a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
83
+ #a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]
84
+
85
+ # TODO: probably not correct
86
+ #b_scale_ptrs = b_scale_ptr + expert_id * stride_bse #+ offs_n[None, :] * stride_bsn
87
+ #b_scale = tl.load(b_scale_ptrs)
80
88
81
89
# tensor-wise
82
90
else :
83
- a_scale = tl .load (a_scale_ptr ) # + (expert_id * stride_ase)
84
- b_scale = tl .load (b_scale_ptr + expert_id * stride_bse )
91
+ a_scale = tl .load (a_scale_ptr )
92
+ b_scale = tl .load (b_scale_ptr )
85
93
86
94
# -----------------------------------------------------------
87
95
# Iterate to compute a block of the C matrix.
@@ -108,26 +116,33 @@ def moe_mmk(
108
116
other = 0.0 )
109
117
b_scale = tl .load (b_scale_ptrs + offs_ks * stride_bsk )
110
118
111
- accumulator += tl .dot (a , b ) * a_scale [:,
112
- None ] * b_scale [None , :]
119
+ accumulator += tl .dot (a , b ) * a_scale [:, None ] * b_scale [None , :]
120
+ elif False and per_act_token_quant :
121
+ a_scale = tl .load (a_scale_ptrs + offs_k [None , :] * stride_ask ,
122
+ mask = mask_m [:, None ] & (offs_k [None , :] < K - k * BLOCK_K ),
123
+ other = 0.0 )
124
+ b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - k * BLOCK_K , other = 0.0 )
125
+
126
+ accumulator += tl .dot (a , b ) * a_scale [:, None ] * b_scale [None , :]
113
127
else :
114
- if use_w8a8 :
115
- # acc used to enable fp8_fast_accum
116
- accumulator = tl .dot (a , b , acc = accumulator )
117
- else :
118
- accumulator += tl .dot (a , b )
128
+ accumulator = tl .dot (a , b , acc = accumulator )
119
129
else :
120
130
accumulator += tl .dot (a , b )
131
+
121
132
# Advance the ptrs to the next K block.
122
133
a_ptrs += BLOCK_K * stride_ak
123
134
b_ptrs += BLOCK_K * stride_bk
124
135
136
+ if False and per_act_token_quant :
137
+ a_scale_ptrs += BLOCK_K * stride_ask
138
+ b_scale_ptrs += BLOCK_K * stride_bsk
139
+
125
140
if use_w8a16 :
126
141
accumulator = (accumulator * b_scale ).to (compute_type )
127
142
elif use_w8a8 :
128
143
if group_k > 0 and group_n > 0 :
129
144
accumulator = accumulator .to (compute_type )
130
- else :
145
+ elif True or not per_act_token_quant :
131
146
accumulator = (accumulator * a_scale * b_scale ).to (compute_type )
132
147
else :
133
148
accumulator = accumulator .to (compute_type )
@@ -169,7 +184,7 @@ def expert_triton_kernel(
169
184
# Quantization schemes
170
185
use_fp8_w8a8 : tl .constexpr ,
171
186
use_int8_w8a16 : tl .constexpr ,
172
- per_channel_quant : tl .constexpr ,
187
+ per_act_token_quant : tl .constexpr ,
173
188
# Kernel config
174
189
BLOCK_M : tl .constexpr ,
175
190
BLOCK_N : tl .constexpr ,
@@ -181,6 +196,7 @@ def expert_triton_kernel(
181
196
offs_k = tl .arange (0 , BLOCK_K )
182
197
mask_m = offs_m < M
183
198
199
+ # Make grids of a + b pointers
184
200
a_ptrs = a_ptr + offs_m [:, None ] * stride_am + offs_k [None , :] * stride_ak
185
201
b_ptrs = b_ptr + offs_k [:, None ] * stride_bk + offs_n [None , :] * stride_bn
186
202
@@ -217,7 +233,7 @@ def expert_triton_kernel(
217
233
compute_type ,
218
234
use_fp8_w8a8 ,
219
235
use_int8_w8a16 ,
220
- per_channel_quant )
236
+ per_act_token_quant )
221
237
222
238
# store in C
223
239
offs_cn = tl .arange (0 , BLOCK_N )
@@ -266,17 +282,19 @@ def batched_triton_kernel(
266
282
# Quantization schemes
267
283
use_fp8_w8a8 : tl .constexpr ,
268
284
use_int8_w8a16 : tl .constexpr ,
269
- per_channel_quant : tl .constexpr ,
285
+ per_act_token_quant : tl .constexpr ,
270
286
# Kernel config
271
287
BLOCK_M : tl .constexpr ,
272
288
BLOCK_N : tl .constexpr ,
273
- BLOCK_K : tl .constexpr ):
289
+ BLOCK_K : tl .constexpr ,
290
+ ):
274
291
expert_id = tl .program_id (axis = 0 )
275
292
e_num_tokens = tl .load (expert_num_tokens + expert_id )
276
293
if e_num_tokens == 0 :
277
294
# Early exit
278
295
return
279
296
297
+ # axis 1 is M_blocks * N_blocks
280
298
pid_mn = tl .program_id (axis = 1 )
281
299
#num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
282
300
num_pid_n = tl .cdiv (N , BLOCK_N )
@@ -298,14 +316,15 @@ def batched_triton_kernel(
298
316
cta_n_start * stride_cn )
299
317
300
318
if use_fp8_w8a8 :
301
- a_scale_ptr = a_scale_ptr + (expert_id * stride_ase )
319
+ a_scale_ptr = a_scale_ptr + expert_id * stride_ase
320
+ b_scale_ptr = b_scale_ptr + expert_id * stride_bse
302
321
# block-wise
303
322
if group_k > 0 and group_n > 0 :
304
323
a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
305
- b_scale_ptr = b_scale_ptr + ( expert_id * stride_bse )
306
- elif per_channel_quant :
324
+ # b group advancement?
325
+ elif False and per_act_token_quant :
307
326
a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
308
- b_scale_ptr = b_scale_ptr + ( expert_id * stride_bse ) + cta_n_start * stride_bsn
327
+ b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn
309
328
310
329
expert_triton_kernel (
311
330
a_ptr ,
@@ -338,7 +357,7 @@ def batched_triton_kernel(
338
357
# Quantization schemes
339
358
use_fp8_w8a8 ,
340
359
use_int8_w8a16 ,
341
- per_channel_quant ,
360
+ per_act_token_quant ,
342
361
# Kernel config
343
362
BLOCK_M ,
344
363
BLOCK_N ,
0 commit comments