
Commit e8088c6

lint
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 2a69594 commit e8088c6

16 files changed: +105 -103 lines changed

tests/kernels/moe/test_block_fp8.py

Lines changed: 22 additions & 12 deletions

@@ -6,10 +6,9 @@
 import pytest
 import torch
 
-from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
-                                       native_w8a8_block_matmul,
-                                       per_block_cast_to_fp8)
 from tests.kernels.moe.utils import make_test_weights
+from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
+                                       native_w8a8_block_matmul)
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_experts
@@ -56,7 +55,8 @@
 SEEDS = [0]
 
 
-def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
+def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
+                             block_shape):
     """Fused moe with block-wise quantization using native torch."""
     B, D = a.shape
     topk = topk_ids.size(1)
@@ -116,7 +116,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.float8_e4m3fn,
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
+                                                 N,
+                                                 K,
+                                                 dtype,
+                                                 torch.float8_e4m3fn,
                                                  per_act_token_quant=False,
                                                  block_shape=block_size)
 
@@ -203,8 +207,8 @@ def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
     return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
 
 
-def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
-                                 block_shape):
+def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, topk_weight,
+                                 topk_ids, block_shape):
     """Fused moe with block-wise quantization using DeepGemm grouped gemm."""
     num_groups = w1.shape[0]
     M, K = a.shape
@@ -265,7 +269,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.float8_e4m3fn,
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
+                                                 N,
+                                                 K,
+                                                 dtype,
+                                                 torch.float8_e4m3fn,
                                                  per_act_token_quant=False,
                                                  block_shape=block_size)
 
@@ -281,12 +289,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
 
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
-        if False and M >= 128:
+        if M >= 128:
             ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
-                                                   topk_weights, topk_ids, block_size)
+                                                   topk_weights, topk_ids,
+                                                   block_size)
         else:
-            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
-                                               topk_ids, block_size)
+            ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
+                                               topk_weights, topk_ids,
+                                               block_size)
 
     if use_compile:
         deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
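The hunks above only reflow the existing make_test_weights call. For context, a minimal sketch of how the tests invoke it after this change, assuming the vLLM test environment; the E, N, K values and the [128, 128] block size are illustrative, not taken from the commit:

import torch

from tests.kernels.moe.utils import make_test_weights

E, N, K = 8, 1024, 512   # example sizes, not from the commit
block_size = [128, 128]  # block-wise quantization tile

# make_test_weights returns (w1, w1_q, w1_scale, w2, w2_q, w2_scale); the
# tests above keep only the quantized weights and their per-block scales.
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                             N,
                                             K,
                                             torch.bfloat16,
                                             torch.float8_e4m3fn,
                                             per_act_token_quant=False,
                                             block_shape=block_size)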

tests/kernels/moe/test_block_int8.py

Lines changed: 6 additions & 25 deletions

@@ -6,9 +6,9 @@
 import pytest
 import torch
 
+from tests.kernels.moe.utils import make_test_weights
 from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
                                        native_w8a8_block_matmul)
-from tests.kernels.moe.utils import make_test_weights
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
@@ -84,34 +84,15 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     """Tests the fused_moe kernel with W8A8 INT8 block quantization against a
     native torch reference."""
     torch.manual_seed(seed)
-    # Use a smaller factor for scale initialization to prevent large
-    # values/overflow especially when output dtype might be float16
-    # factor_for_scale = 1e-2
-    # int8_info = torch.iinfo(torch.int8)
-    # int8_max, int8_min = int8_info.max, int8_info.min
 
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    # w1_fp32 = (torch.rand(
-    #     (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
-    # w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
-
-    # w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
-    # w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
-
-    # block_n, block_k = block_size[0], block_size[1]
-    # n_tiles_w1 = (2 * N + block_n - 1) // block_n
-    # n_tiles_w2 = (K + block_n - 1) // block_n
-    # k_tiles_w1 = (K + block_k - 1) // block_k
-    # k_tiles_w2 = (N + block_k - 1) // block_k
-
-    # w1_s = (torch.rand(
-    #     (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
-    # w2_s = (torch.rand(
-    #     (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
-
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K, dtype, torch.int8,
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
+                                                 N,
+                                                 K,
+                                                 dtype,
+                                                 torch.int8,
                                                  per_act_token_quant=False,
                                                  block_shape=block_size)
 

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 11 additions & 8 deletions

@@ -100,13 +100,15 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
     if False:
         _, a_scale = ops.scaled_fp8_quant(
             moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
-        a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
-                                      a_scale,
-                                      use_per_token_if_dynamic=per_act_token)
+        a_q, _ = ops.scaled_fp8_quant(
+            moe_tensors_fp16.a,
+            a_scale,
+            use_per_token_if_dynamic=per_act_token)
     else:
-        a_q, a_scale = ops.scaled_fp8_quant(moe_tensors_fp16.a,
-                                            None,
-                                            use_per_token_if_dynamic=per_act_token)
+        a_q, a_scale = ops.scaled_fp8_quant(
+            moe_tensors_fp16.a,
+            None,
+            use_per_token_if_dynamic=per_act_token)
 
     w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
     w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
@@ -209,7 +211,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
-        'a1_scale': None #moe_tensors.a_scale
+        'a1_scale': None  #moe_tensors.a_scale
     }
 
     num_experts = moe_tensors.w1.size(0)
@@ -262,7 +264,8 @@ def test_cutlass_moe_8_bit_no_graph(
 
         cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
 
-        # Note 5.5 only needed for larger problem sizes, 5 works ok for the rest.
+        # Note 5.5 only needed for larger problem sizes, 5 works ok for
+        # the rest.
        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=5.5e-2,
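As a side note on the call being reflowed in the first hunk above: a minimal sketch of dynamic FP8 activation quantization through ops.scaled_fp8_quant, assuming a CUDA build of vLLM; the tensor shape and dtype here are illustrative:

import torch

import vllm._custom_ops as ops

a = torch.randn((32, 256), dtype=torch.float16, device="cuda")

# With scale=None the kernel computes the scale dynamically; setting
# use_per_token_if_dynamic=True yields one scale per row (token) instead of
# a single per-tensor scale.
a_q, a_scale = ops.scaled_fp8_quant(a,
                                    None,
                                    use_per_token_if_dynamic=True)
print(a_q.dtype, a_scale.shape)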

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 2 additions & 3 deletions

@@ -21,9 +21,7 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
-from vllm.utils import cdiv
 
-from tests.kernels.quant_utils import per_block_cast_to_fp8
 from .deepep_utils import ProcessGroupInfo, parallel_launch
 from .utils import make_test_weights
 
@@ -73,7 +71,8 @@ def make_block_quant_fp8_weights(
     """
     Return weights w1q, w2q, w1_scale, w2_scale
     """
-    w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
+    w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
+        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
     return w1q, w2q, w1_scale, w2_scale
 

tests/kernels/moe/test_moe.py

Lines changed: 1 addition & 1 deletion

@@ -17,8 +17,8 @@
 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.forward_context import set_forward_context
 from vllm.distributed.parallel_state import init_distributed_environment
+from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, modular_triton_fused_moe)

tests/kernels/moe/utils.py

Lines changed: 24 additions & 13 deletions

@@ -5,17 +5,17 @@
 import torch
 
 import vllm._custom_ops as ops
+from tests.kernels.quant_utils import (per_block_cast_to_fp8,
+                                       per_block_cast_to_int8)
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
 from vllm.utils import round_up
-from tests.kernels.quant_utils import per_block_cast_to_fp8, per_block_cast_to_int8
+
 
 def triton_moe(
     a: torch.Tensor,
@@ -70,7 +70,7 @@ def batched_moe(
             max_num_tokens=max_num_tokens,
             world_size=1,
             dp_size=1,
-            use_fp8_w8a8=quant_dtype==torch.float8_e4m3fn,
+            use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
             per_act_token_quant=per_act_token_quant,
             block_shape=block_shape,
         ),
@@ -112,14 +112,19 @@ def naive_batched_moe(
            max_num_tokens=max_num_tokens,
            dp_size=1,
            world_size=1,
-           use_fp8_w8a8=quant_dtype==torch.float8_e4m3fn,
+           use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        ),
    )
 
-    return fused_experts(a, w1, w2, topk_weight, topk_ids,
-                         w1_scale=w1_scale, w2_scale=w2_scale,
+    return fused_experts(a,
+                         w1,
+                         w2,
+                         topk_weight,
+                         topk_ids,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale,
                          a1_scale=a1_scale,
                          a2_scale=a2_scale)
 
@@ -148,7 +153,8 @@ def make_quantized_test_activations(
     a_scale = None
 
     if quant_dtype is not None:
-        assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, "only fp8/int8 supported"
+        assert (quant_dtype == torch.float8_e4m3fn
+                or quant_dtype == torch.int8), "only fp8/int8 supported"
         a_q = torch.zeros_like(a, dtype=quant_dtype)
         a_scale = [None] * E
         for e in range(E):
@@ -169,7 +175,8 @@ def moe_quantize_weights(
     per_token_quant: bool,
     block_shape: Optional[list[int]],
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, "only fp8/int8 supported"
+    assert (quant_dtype == torch.float8_e4m3fn
+            or quant_dtype == torch.int8), "only fp8/int8 supported"
 
     if block_shape is not None:
         assert not per_token_quant
@@ -179,9 +186,11 @@ def moe_quantize_weights(
             w, w_s = per_block_cast_to_fp8(w, block_shape)
     else:
         if quant_dtype == torch.int8:
-            w, w_s = ops.scaled_int8_quant(w, w_s, use_per_token_if_dynamic=per_token_quant)
+            w, w_s = ops.scaled_int8_quant(
+                w, w_s, use_per_token_if_dynamic=per_token_quant)
         else:
-            w, w_s = ops.scaled_fp8_quant(w, w_s, use_per_token_if_dynamic=per_token_quant)
+            w, w_s = ops.scaled_fp8_quant(
+                w, w_s, use_per_token_if_dynamic=per_token_quant)
 
     return w, w_s
 
@@ -233,6 +242,8 @@ def make_test_weights(
 ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
            torch.Tensor, Optional[torch.Tensor]]:
     return (
-        *make_test_weight(e, 2*n, k, in_dtype, quant_dtype, block_shape, per_act_token_quant),
-        *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_act_token_quant),
+        *make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
+                          per_act_token_quant),
+        *make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
+                          per_act_token_quant),
     )

tests/kernels/quant_utils.py

Lines changed: 11 additions & 12 deletions

@@ -5,9 +5,10 @@
 
 import torch
 
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    group_broadcast)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
-from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
 
 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for rocm.
@@ -220,17 +221,17 @@ def native_per_token_group_quant_int8(x,
 
 DEFAULT_BLOCK_SHAPE = [128, 128]
 
+
 def per_block_cast_to_fp8(
     x: torch.Tensor,
     block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     block_m, block_n = block_shape
     assert x.dim() == 2
     m, n = x.shape
-    x_padded = torch.zeros(
-        (round_up(m, block_m), round_up(n, block_n)),
-        dtype=x.dtype,
-        device=x.device)
+    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
+                           dtype=x.dtype,
+                           device=x.device)
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
@@ -248,10 +249,9 @@ def per_block_cast_to_int8(
     block_m, block_n = block_shape
     assert x.dim() == 2
     m, n = x.shape
-    x_padded = torch.zeros(
-        (round_up(m, block_m), round_up(n, block_n)),
-        dtype=x.dtype,
-        device=x.device)
+    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
+                           dtype=x.dtype,
+                           device=x.device)
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
@@ -292,8 +292,6 @@ def native_batched_masked_quant_matmul(
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
     num_experts = num_expert_tokens.size(0)
 
-    f32 = torch.float32
-
     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
         if A.dtype.itemsize == 1 and block_shape is not None:
@@ -305,7 +303,8 @@
             assert A_scale is not None and B_scale is not None
             A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
             B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
-            C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
+            C[e, :num_tokens, :] = (
+                A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
         else:
             assert A_scale is None
             assert B_scale is None
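For readers unfamiliar with the helper being reformatted above: the pad-then-tiled-amax pattern is the core of per_block_cast_to_fp8/int8. A self-contained sketch of just that step follows; the helper name and example sizes are illustrative, and the actual functions go on to scale and cast the tiles:

import torch


def per_block_amax(x: torch.Tensor, block_shape=(128, 128)) -> torch.Tensor:
    """Pad a 2-D tensor to multiples of the block shape and return one
    abs-max value per (block_m, block_n) tile."""
    block_m, block_n = block_shape
    m, n = x.shape
    pad_m = (m + block_m - 1) // block_m * block_m  # round_up(m, block_m)
    pad_n = (n + block_n - 1) // block_n * block_n
    x_padded = torch.zeros((pad_m, pad_n), dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    # One tile per (row-block, col-block); reduce over the in-tile dims.
    x_view = x_padded.view(-1, block_m, pad_n // block_n, block_n)
    return x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)


# A 200x300 weight pads to 256x384, i.e. 2x3 tiles -> shape (2, 1, 3, 1).
print(per_block_amax(torch.randn(200, 300)).shape)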

vllm/_custom_ops.py

Lines changed: 2 additions & 1 deletion

@@ -1276,7 +1276,8 @@ def scaled_fp8_quant(
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 and num_token_padding is None), f"{scale.shape} {num_token_padding}"
+        assert (scale.numel() == 1 and num_token_padding
+                is None), f"{scale.shape} {num_token_padding}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
    return output, scale

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 0 additions & 2 deletions

@@ -7,8 +7,6 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, per_token_group_quant_fp8)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 
