
Commit a4f56fb

a scales working, b scales not working

Signed-off-by: Bill Nell <bnell@redhat.com>

1 parent 90ea3c7 commit a4f56fb

File tree

2 files changed: +60 / -32 lines

tests/kernels/moe/test_batched_moe.py

Lines changed: 12 additions & 3 deletions
@@ -77,8 +77,8 @@ def make_tensors(config: BatchedMMConfig):
 @pytest.mark.parametrize(
     "dtype",
     [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("block_shape", [None])#, [128, 128]])
-@pytest.mark.parametrize("per_act_token_quant", [False])#, True])
+@pytest.mark.parametrize("block_shape", [[128, 128]]) # [None])#, [128, 128]])
+@pytest.mark.parametrize("per_act_token_quant", [False, True])# [False])# ,True])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype,
                     block_shape: Optional[list[int]],
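With the parametrization above, the kernel is now exercised with block_shape=[128, 128] and with per_act_token_quant both off and on. While iterating on the failing cases it can help to run only this test; the snippet below is ordinary pytest usage, and nothing in it beyond the file path and test name comes from this commit:

import pytest

# Run only test_batched_mm from this file, stopping at the first failure.
pytest.main(["tests/kernels/moe/test_batched_moe.py", "-k", "test_batched_mm", "-x"])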
@@ -141,6 +141,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,

     assert A_q.dtype == B_q.dtype

+    #A_scale.fill_(1)
+    B_scale.fill_(1)
+
     invoke_moe_batched_triton_kernel(
         A_q,
         B_q,
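Forcing B_scale to all-ones takes the B-scale path out of the comparison, which matches the commit message ("a scales working, b scales not working"). The complementary experiment, not part of this commit, would neutralize the A scales instead, so any remaining mismatch could only come from the B-scale handling:

# Hypothetical follow-up experiment (not in this commit):
A_scale.fill_(1)
# B_scale.fill_(1)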
@@ -183,7 +186,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]

-    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
+    if False:
+        torch.set_printoptions(profile="full")
+        print(f"REF_OUTPUT {q_ref_output.shape}\n{q_ref_output}")
+        print(f"TRITON {test_output.shape}\n{test_output}")
+
+    #torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
     #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
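The surviving assertion compares the Triton result test_output against the quantized reference q_ref_output. A minimal reference of that kind is dequantize-then-matmul in fp32; the helper below is only a sketch under assumed layouts (A_q: [M, K], B_q: [N, K], scales broadcastable against them), not the reference implementation used by the test:

import torch

def dequant_mm_ref(A_q, A_scale, B_q, B_scale):
    # Dequantize both operands, then multiply in fp32 so the tolerances
    # (atol/rtol above) only have to absorb kernel rounding differences.
    A = A_q.to(torch.float32) * A_scale.to(torch.float32)
    B = B_q.to(torch.float32) * B_scale.to(torch.float32)
    return A @ B.t()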

@@ -239,6 +247,7 @@ def test_fused_moe_batched_experts(
         per_act_token_quant=per_act_token_quant,
     )

+    # TODO remove
     torch.set_printoptions(profile="full")

     with set_current_vllm_config(vllm_config):

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 48 additions & 29 deletions
@@ -50,7 +50,7 @@ def moe_mmk(
         compute_type: tl.constexpr,
         use_w8a8: tl.constexpr,
         use_w8a16: tl.constexpr,
-        per_channel_quant: tl.constexpr,
+        per_act_token_quant: tl.constexpr,
 ):

     offs_k = tl.arange(0, BLOCK_K)
@@ -63,25 +63,33 @@ def moe_mmk(
     if use_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm #+ (expert_id * stride_ase)
+            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
             offs_bsn = offs_n // group_n
-            b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
-                            offs_bsn * stride_bsn)
+            b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn

-        # channel-wise
-        elif per_channel_quant:
-            # TODO: probably not correct
-            b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn
+        # per act token
+        elif per_act_token_quant:
+            # Load per-token scale for activations
+            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
+            a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:,None]
+
+            b_scale_ptrs = b_scale_ptr + offs_n[None, :] * stride_bsn
             b_scale = tl.load(b_scale_ptrs)
+
+
             # Load per-token scale for activations
             # + (expert_id * stride_ase)??
-            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
-            a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]
+            #a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
+            #a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]
+
+            # TODO: probably not correct
+            #b_scale_ptrs = b_scale_ptr + expert_id * stride_bse #+ offs_n[None, :] * stride_bsn
+            #b_scale = tl.load(b_scale_ptrs)

         # tensor-wise
         else:
-            a_scale = tl.load(a_scale_ptr) # + (expert_id * stride_ase)
-            b_scale = tl.load(b_scale_ptr + expert_id * stride_bse)
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
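The new per_act_token_quant branch loads one scale per activation row (indexed by offs_m) and one per output column of B (indexed by offs_n), with the scales applied outside the K loop. In plain PyTorch terms the intended math looks roughly like the sketch below; the tensor names and shapes here are illustrative assumptions, not code from this commit:

import torch

# a_q: [M, K] quantized activations, a_scale: [M, 1] per-token scales
# b_q: [N, K] quantized weights,     b_scale: [N, 1] per-output-channel scales
def per_act_token_mm(a_q, a_scale, b_q, b_scale):
    acc = a_q.to(torch.float32) @ b_q.to(torch.float32).t()   # [M, N]
    return acc * a_scale.to(torch.float32) * b_scale.to(torch.float32).t()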
@@ -108,26 +116,33 @@ def moe_mmk(
                                   other=0.0)
                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

-                accumulator += tl.dot(a, b) * a_scale[:,
-                                                      None] * b_scale[None, :]
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            elif False and per_act_token_quant:
+                a_scale = tl.load(a_scale_ptrs + offs_k[None, :] * stride_ask,
+                                  mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
+                                  other=0.0)
+                b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
             else:
-                if use_w8a8:
-                    # acc used to enable fp8_fast_accum
-                    accumulator = tl.dot(a, b, acc=accumulator)
-                else:
-                    accumulator += tl.dot(a, b)
+                accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
+
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_K * stride_ak
         b_ptrs += BLOCK_K * stride_bk

+        if False and per_act_token_quant:
+            a_scale_ptrs += BLOCK_K * stride_ask
+            b_scale_ptrs += BLOCK_K * stride_bsk
+
     if use_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_w8a8:
         if group_k > 0 and group_n > 0:
             accumulator = accumulator.to(compute_type)
-        else:
+        elif True or not per_act_token_quant:
             accumulator = (accumulator * a_scale * b_scale).to(compute_type)
         else:
             accumulator = accumulator.to(compute_type)
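After the K loop, every w8a8 case that is not block-quantized now goes through the a_scale * b_scale multiply; with the "True or ..." guard in place the trailing else: branch is unreachable, which looks like part of the debugging state this commit captures. The effective flow, written as plain Python for readability (a sketch of the diff above, not the Triton source itself):

if use_w8a16:
    accumulator = (accumulator * b_scale).to(compute_type)
elif use_w8a8:
    if group_k > 0 and group_n > 0:
        # block-wise: scales were already applied per K-group inside the loop
        accumulator = accumulator.to(compute_type)
    else:
        # tensor-wise and per-act-token: apply the scales once at the end
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)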
@@ -169,7 +184,7 @@ def expert_triton_kernel(
         # Quantization schemes
         use_fp8_w8a8: tl.constexpr,
         use_int8_w8a16: tl.constexpr,
-        per_channel_quant: tl.constexpr,
+        per_act_token_quant: tl.constexpr,
         # Kernel config
         BLOCK_M: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -181,6 +196,7 @@ def expert_triton_kernel(
     offs_k = tl.arange(0, BLOCK_K)
     mask_m = offs_m < M

+    # Make grids of a + b pointers
     a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
     b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

@@ -217,7 +233,7 @@ def expert_triton_kernel(
         compute_type,
         use_fp8_w8a8,
         use_int8_w8a16,
-        per_channel_quant)
+        per_act_token_quant)

     # store in C
     offs_cn = tl.arange(0, BLOCK_N)
@@ -266,17 +282,19 @@ def batched_triton_kernel(
         # Quantization schemes
         use_fp8_w8a8: tl.constexpr,
         use_int8_w8a16: tl.constexpr,
-        per_channel_quant: tl.constexpr,
+        per_act_token_quant: tl.constexpr,
         # Kernel config
         BLOCK_M: tl.constexpr,
         BLOCK_N: tl.constexpr,
-        BLOCK_K: tl.constexpr):
+        BLOCK_K: tl.constexpr,
+):
     expert_id = tl.program_id(axis=0)
     e_num_tokens = tl.load(expert_num_tokens + expert_id)
     if e_num_tokens == 0:
         # Early exit
         return

+    # axis 1 is M_blocks * N_blocks
     pid_mn = tl.program_id(axis=1)
     #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
     num_pid_n = tl.cdiv(N, BLOCK_N)
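The new grid comment spells out the launch layout: axis 0 is the expert index and axis 1 flattens the M and N block grids. The usual way such a flattened id is split back into block coordinates is shown below as a plain-Python sketch; pid_m and pid_n are conventional names and an assumption here, since the decomposition itself is outside this hunk:

num_pid_n = (N + BLOCK_N - 1) // BLOCK_N   # ceil-divide, same as tl.cdiv(N, BLOCK_N)
pid_m = pid_mn // num_pid_n                # which M block this program handles
pid_n = pid_mn % num_pid_n                 # which N block this program handles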
@@ -298,14 +316,15 @@ def batched_triton_kernel(
              cta_n_start * stride_cn)

     if use_fp8_w8a8:
-        a_scale_ptr = a_scale_ptr + (expert_id * stride_ase)
+        a_scale_ptr = a_scale_ptr + expert_id * stride_ase
+        b_scale_ptr = b_scale_ptr + expert_id * stride_bse
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
-            b_scale_ptr = b_scale_ptr + (expert_id * stride_bse)
-        elif per_channel_quant:
+            # b group advancement?
+        elif False and per_act_token_quant:
             a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
-            b_scale_ptr = b_scale_ptr + (expert_id * stride_bse) + cta_n_start * stride_bsn
+            b_scale_ptr = b_scale_ptr + cta_n_start * stride_bsn

     expert_triton_kernel(
         a_ptr,
@@ -338,7 +357,7 @@ def batched_triton_kernel(
         # Quantization schemes
         use_fp8_w8a8,
         use_int8_w8a16,
-        per_channel_quant,
+        per_act_token_quant,
         # Kernel config
         BLOCK_M,
         BLOCK_N,
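Hoisting the expert_id offsets means both scale pointers are rebased to the current expert before any block-local offsets are added. On the host side this corresponds to slicing out one expert's scales, roughly as below; the 3-D layouts are an assumption for illustration, not taken from this commit:

# Assume A_scale: [E, max_tokens, 1] and B_scale: [E, N, 1] (illustrative only).
a_scale_e = A_scale[expert_id]   # advancing by stride_ase per expert
b_scale_e = B_scale[expert_id]   # advancing by stride_bse per expert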
