 
 from tests.kernels.quant_utils import per_block_cast_to_fp8
 from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .utils import make_test_weights
 
 has_deep_ep = importlib.util.find_spec("deep_ep") is not None
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@@ -70,43 +71,10 @@ def make_block_quant_fp8_weights(
     block_size: list[int],
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """
-    Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
+    Return weights w1q, w2q, w1_scale, w2_scale
     """
-    dtype = torch.bfloat16
-
-    fp8_info = torch.finfo(torch.float8_e4m3fn)
-    fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
-    w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
-    w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
-
-    w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
-    w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
-
-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
-    k_tiles_w1 = (k + block_k - 1) // block_k
-    n_tiles_w2 = (k + block_n - 1) // block_n
-    k_tiles_w2 = (n + block_k - 1) // block_k
-
-    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
-    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
-
-    w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
-                       device="cuda",
-                       dtype=torch.float32)
-    w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
-                       device="cuda",
-                       dtype=torch.float32)
-
-    assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
-    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
-
-    for i in range(e):
-        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
-        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
-
-    return w1, w2, w1_s, w2_s
+    w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
+    return w1q, w2q, w1_scale, w2_scale
 
 
 @dataclasses.dataclass
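Note: the in-test quantization logic deleted above (random bf16 weights, block-quantized per expert with per_block_cast_to_fp8) is now centralized behind the shared make_test_weights helper. The following is a minimal sketch of what that per-block FP8 quantization amounts to, assuming 128x128 blocks; the helper name and padding details here are illustrative assumptions, not vLLM's actual implementation.

# Illustrative sketch only: per-(128, 128)-block FP8 quantization of one
# weight matrix, mirroring what the deleted per-expert loop did via
# per_block_cast_to_fp8. Name and padding details are assumptions, not
# vLLM's actual make_test_weights implementation.
import torch


def per_block_cast_to_fp8_sketch(w: torch.Tensor,
                                 block_n: int = 128,
                                 block_k: int = 128):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    n, k = w.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k

    # Zero-pad to whole tiles so every block has identical shape.
    padded = torch.zeros((n_tiles * block_n, k_tiles * block_k),
                         dtype=w.dtype, device=w.device)
    padded[:n, :k] = w

    # One fp32 scale per (block_n, block_k) tile: scale each tile so its
    # absolute max maps onto the fp8 e4m3 range, then cast.
    tiles = padded.view(n_tiles, block_n, k_tiles, block_k)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).float().clamp(min=1e-4)
    scale = amax / fp8_max
    wq = (tiles.float() / scale).to(torch.float8_e4m3fn)
    wq = wq.view(n_tiles * block_n, k_tiles * block_k)[:n, :k].contiguous()
    return wq, scale.view(n_tiles, k_tiles)


# Shapes line up with the asserts in the removed code, e.g. for a w2 of
# shape (k, n) the scales have shape ((k + 127) // 128, (n + 127) // 128).
w2q, w2_scale = per_block_cast_to_fp8_sketch(
    torch.randn((1024, 512), dtype=torch.bfloat16) / 10)
assert w2_scale.shape == (8, 4)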
@@ -460,10 +428,14 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
-                                           int], num_experts: int, topk: int,
-                                use_fp8_dispatch: bool, block_size: list[int],
-                                world_dp_size: tuple[int, int]):
+def test_ll_deepep_deepgemm_moe(
+    mnk: tuple[int, int, int],
+    num_experts: int,
+    topk: int,
+    use_fp8_dispatch: bool,
+    block_size: list[int],
+    world_dp_size: tuple[int, int],
+):
     """
     Tests for Low-Latency DeepEP + DeepGemm integration.
     """
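For context, a call site of the slimmed-down helper now unpacks only the quantized weights and their per-block scales. The snippet below is hypothetical (the real test bodies are outside this diff, and the argument order is assumed from the removed body); the shape checks follow the (e, 2*n, k) / (e, k, n) layout used above.

# Hypothetical call site; presupposes the test module above, so this is
# illustrative only. Argument order (e, n, k, block_size) is assumed.
num_experts, n, k = 32, 512, 1024
w1q, w2q, w1_scale, w2_scale = make_block_quant_fp8_weights(
    num_experts, n, k, block_size=[128, 128])
assert w1q.shape == (num_experts, 2 * n, k)
assert w2q.shape == (num_experts, k, n)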