
Commit eab92d3

mm baselines work
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: bb5a4ed

File tree: 2 files changed, +18 -11 lines

  tests/kernels/moe/test_batched_moe.py
  tests/kernels/quant_utils.py

tests/kernels/moe/test_batched_moe.py

Lines changed: 2 additions & 5 deletions
@@ -166,15 +166,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         B,
         ref_output,
         num_expert_tokens,
-        None,
-        None,
-        None,
     )
 
     q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                       num_expert_tokens,
                                                       A_scale, B_scale,
-                                                      block_shape)
+                                                      block_shape,
+                                                      per_act_token_quant)
 
     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
@@ -183,7 +181,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     }[test_output.dtype]
 
     torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
-
     #torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
     #torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)

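With the scale and block-shape parameters of native_batched_masked_quant_matmul now defaulting to None (see tests/kernels/quant_utils.py below), the unquantized baseline call can drop its trailing None arguments, and the quantized baseline threads through the new per_act_token_quant flag. A minimal sketch of the two calls after this change (the leading A/B arguments are inferred from the function signature; surrounding test setup omitted):

    ref_output = native_batched_masked_quant_matmul(A, B, ref_output,
                                                    num_expert_tokens)

    q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                      num_expert_tokens,
                                                      A_scale, B_scale,
                                                      block_shape,
                                                      per_act_token_quant)
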
tests/kernels/quant_utils.py

Lines changed: 16 additions & 6 deletions
@@ -7,6 +7,7 @@
 
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
 
 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for rocm.
@@ -235,14 +236,23 @@ def per_block_cast_to_fp8(
     return x_scaled_sub, scales
 
 
+def _dequant(t: torch.Tensor, scale: torch.Tensor, block_shape, per_act_token_quant) -> torch.Tensor:
+    f32 = torch.float32
+    if per_act_token_quant or block_shape is None:
+        return t.to(f32) * scale
+    else:
+        return t.to(f32) * group_broadcast(scale, t.shape)
+
+
 def native_batched_masked_quant_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
     C: torch.Tensor,
     num_expert_tokens: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    block_shape: Optional[list[int]],
+    A_scale: Optional[torch.Tensor] = None,
+    B_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    per_act_token_quant: bool = False,
 ) -> torch.Tensor:
     num_expert_tokens_cpu = num_expert_tokens.clone()
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
@@ -259,9 +269,9 @@ def native_batched_masked_quant_matmul(
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         elif A.dtype.itemsize == 1 and block_shape is None:
             assert A_scale is not None and B_scale is not None
-            C[e, :num_tokens, :] = (
-                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(C.dtype)
-                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(C.dtype))
+            A_dq = _dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
+            B_dq = _dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
+            C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
         else:
             assert A_scale is None
             assert B_scale is None

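The new _dequant helper upcasts to float32 and applies the scale directly for per-tensor or per-act-token quantization, and otherwise expands blockwise scales with group_broadcast from vllm's quant utils. A rough sketch of the broadcast semantics assumed here (group_broadcast_sketch and its ceil-division expansion are illustrative, not vllm's actual implementation):

    import torch

    def group_broadcast_sketch(scale: torch.Tensor, shape: torch.Size) -> torch.Tensor:
        # Illustrative stand-in for group_broadcast: repeat each blockwise
        # scale over its block so the result matches `shape` elementwise,
        # trimming any overhang from a ragged final block.
        for dim, size in enumerate(shape):
            reps = -(-size // scale.shape[dim])  # ceil division
            scale = scale.repeat_interleave(reps, dim=dim).narrow(dim, 0, size)
        return scale

    # Example: a 2x2 grid of block scales expanded over a (4, 6) tensor,
    # i.e. block_shape == [2, 3].
    scales = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    expanded = group_broadcast_sketch(scales, torch.Size((4, 6)))
    assert expanded.shape == (4, 6)

With that expansion, _dequant reduces to an elementwise multiply in float32 whether the scales are per-tensor, per-act-token, or blockwise.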