Commit a7ca36b

try to fix lint
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 7aeb5c6 commit a7ca36b

File tree

3 files changed: +43 -44 lines changed


tests/kernels/moe/test_batched_moe.py

Lines changed: 6 additions & 41 deletions

@@ -11,7 +11,7 @@
 from tests.kernels.moe.utils import (batched_moe,
                                      make_quantized_test_activations,
                                      make_test_weights, triton_moe)
-from tests.kernels.quant_utils import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_batched_masked_quant_matmul
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@@ -68,43 +68,6 @@ def make_tensors(config: BatchedMMConfig):
     return BatchedMMTensors(A, B, C, num_expert_tokens)


-def ref_impl(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    C: torch.Tensor,
-    num_expert_tokens: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    block_shape: Optional[list[int]],
-) -> torch.Tensor:
-    assert (A.dtype.itemsize > 1
-            or (A_scale is not None and B_scale is not None))
-
-    num_expert_tokens_cpu = num_expert_tokens.clone()
-    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
-    num_experts = num_expert_tokens.size(0)
-
-    f32 = torch.float32
-    bf16 = torch.bfloat16
-
-    for e in range(num_experts):
-        num_tokens = num_expert_tokens_cpu[e]
-        if A.dtype.itemsize == 1 and block_shape is not None:
-            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
-                                           block_shape, C.dtype)
-            C[e, :num_tokens, :] = tmp[:num_tokens, :]
-        elif A.dtype.itemsize == 1 and block_shape is None:
-            C[e, :num_tokens, :] = (
-                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16)
-                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16))
-        else:
-            assert A_scale is None
-            assert B_scale is None
-            C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
-
-    return C
-
-
 @pytest.mark.parametrize("num_experts", [8, 16, 32])
 @pytest.mark.parametrize("max_tokens_per_expert",
                          [32, 64, 128, 192, 224, 256, 512])
@@ -193,7 +156,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         block_shape=block_shape,
     )

-    ref_output = ref_impl(
+    ref_output = native_batched_masked_quant_matmul(
         A,
         B,
         ref_output,
@@ -203,8 +166,10 @@
         None,
     )

-    q_ref_output = ref_impl(A_q, B_q, q_ref_output, num_expert_tokens, A_scale,
-                            B_scale, block_shape)
+    q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
+                                                      num_expert_tokens,
+                                                      A_scale, B_scale,
+                                                      block_shape)

     rtol, atol = {
         torch.float16: (6e-2, 6e-2),

tests/kernels/quant_utils.py

Lines changed: 35 additions & 0 deletions

@@ -233,3 +233,38 @@ def per_block_cast_to_fp8(
     x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
     scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
     return x_scaled_sub, scales
+
+
+def native_batched_masked_quant_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    num_expert_tokens: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    block_shape: Optional[list[int]],
+) -> torch.Tensor:
+    num_expert_tokens_cpu = num_expert_tokens.clone()
+    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
+    num_experts = num_expert_tokens.size(0)
+
+    f32 = torch.float32
+
+    for e in range(num_experts):
+        num_tokens = num_expert_tokens_cpu[e]
+        if A.dtype.itemsize == 1 and block_shape is not None:
+            assert A_scale is not None and B_scale is not None
+            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
+                                           block_shape, C.dtype)
+            C[e, :num_tokens, :] = tmp[:num_tokens, :]
+        elif A.dtype.itemsize == 1 and block_shape is None:
+            assert A_scale is not None and B_scale is not None
+            C[e, :num_tokens, :] = (
+                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(C.dtype)
+                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(C.dtype))
+        else:
+            assert A_scale is None
+            assert B_scale is None
+            C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
+
+    return C
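For orientation, here is a minimal usage sketch of the new helper on its unquantized path. This is not part of the commit; the tensor sizes and the standalone call below are illustrative assumptions based only on the function signature and the test call sites above.

import torch

from tests.kernels.quant_utils import native_batched_masked_quant_matmul

# Illustrative sizes, not taken from the diff.
E, max_tokens, K, N = 4, 8, 16, 32
A = torch.randn(E, max_tokens, K, dtype=torch.bfloat16)
B = torch.randn(E, N, K, dtype=torch.bfloat16)  # B[e] is (N, K); transposed inside the helper
C = torch.zeros(E, max_tokens, N, dtype=torch.bfloat16)
num_expert_tokens = torch.randint(low=1, high=max_tokens + 1, size=(E,))

# With the scales and block_shape all None, the helper takes its unquantized
# branch: for each expert e it writes A[e, :n] @ B[e].T into C[e, :n], where
# n = num_expert_tokens[e], leaving the masked-out rows of C untouched.
out = native_batched_masked_quant_matmul(A, B, C, num_expert_tokens,
                                         None, None, None)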

tests/kernels/utils.py

Lines changed: 2 additions & 3 deletions

@@ -1076,9 +1076,6 @@ def torch_experts(
             or (expert_map is not None
                 and global_num_experts == expert_map.shape[0]))

-    assert (quant_dtype is None
-            or (w1_scale is not None and w2_scale is not None))
-
     M, K = a.shape
     topk = topk_ids.shape[1]

@@ -1103,6 +1100,8 @@
                 tmp2 = SiluAndMul()(tmp1)
                 out[mask] = tmp2 @ w2[i].transpose(0, 1)
             elif block_shape is not None:
+                assert (a_scale is not None and w1_scale is not None
+                        and w2_scale is not None)
                 tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
                                                 w1_scale[i], block_shape,
                                                 out.dtype)
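The net effect in torch_experts is that the broad quant_dtype assert is replaced by an assert inside the block-quantized branch that actually indexes the scales. A plausible reading of "try to fix lint" is Optional narrowing for a static type checker such as mypy; the commit message does not name the tool, so this is an assumption, sketched generically below with a made-up function.

from typing import Optional

import torch


def scale_rows(a: torch.Tensor,
               a_scale: Optional[torch.Tensor]) -> torch.Tensor:
    # Hypothetical example, not from the diff. Asserting right before the
    # scale is used narrows Optional[torch.Tensor] to torch.Tensor in this
    # scope, so expressions such as a_scale[0] pass the type checker without
    # "Item of Optional[...]" errors.
    assert a_scale is not None
    return a * a_scale[0]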
