Commit 77f95b9

test
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent bbe888d commit 77f95b9

File tree

1 file changed, +52 −25 lines changed
tests/kernels/moe/test_batched_moe.py

Lines changed: 52 additions & 25 deletions
@@ -13,7 +13,8 @@

 @dataclass
 class BatchedMMConfig:
-    dtype: torch.dtype
+    in_dtype: torch.dtype
+    out_dtype: torch.dtype
     num_experts: int
     max_tokens_per_expert: int
     K: int
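
For reference, a minimal standalone sketch of the updated config (assuming the remaining fields, including N, are unchanged from the original dataclass); the fp8 case pairs float8_e4m3fn inputs with a bfloat16 output dtype, which is the dtype split this commit introduces:

import torch
from dataclasses import dataclass


@dataclass
class BatchedMMConfig:
    in_dtype: torch.dtype
    out_dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


# fp8 inputs, bfloat16 outputs (values here are illustrative only)
cfg = BatchedMMConfig(in_dtype=torch.float8_e4m3fn,
                      out_dtype=torch.bfloat16,
                      num_experts=8,
                      max_tokens_per_expert=64,
                      K=128,
                      N=256)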
@@ -29,26 +30,25 @@ class BatchedMMTensors:

     @staticmethod
     def make_tensors(config: BatchedMMConfig):
-        if config.dtype == torch.torch.float8_e4m3fn:
-            config_dtype = torch.bfloat16
+        if config.in_dtype == torch.torch.float8_e4m3fn:
+            config_in_dtype = torch.bfloat16
         else:
-            config_dtype = config.dtype
+            config_in_dtype = config.in_dtype

         A = torch.randn(
             (config.num_experts, config.max_tokens_per_expert, config.K),
             device="cuda",
-            dtype=config_dtype) / 10
+            dtype=config_in_dtype) / 10
         B = torch.randn((config.num_experts, config.N, config.K),
                         device="cuda",
-                        dtype=config_dtype)
+                        dtype=config_in_dtype)
         C = torch.zeros(
             (config.num_experts, config.max_tokens_per_expert, config.N),
             device="cuda",
-            dtype=config_dtype)
+            dtype=config.out_dtype)

-        A = A.to(config.dtype)
-        B = B.to(config.dtype)
-        C = C.to(config.dtype)
+        A = A.to(config.in_dtype)
+        B = B.to(config.in_dtype)

         num_expert_tokens = torch.randint(low=0,
                                           high=config.max_tokens_per_expert,
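
One detail worth noting in make_tensors, shown as a tiny sketch (nothing beyond torch and a CUDA device is assumed): torch.randn cannot sample directly into float8_e4m3fn, so A and B are generated in a higher-precision dtype and then cast down to in_dtype, while C is allocated directly in out_dtype.

import torch

a_hi = torch.randn((4, 8), device="cuda", dtype=torch.bfloat16) / 10
a_fp8 = a_hi.to(torch.float8_e4m3fn)  # cast down to the fp8 input dtype
c = torch.zeros((4, 16), device="cuda", dtype=torch.bfloat16)  # fp8 matmuls accumulate into a wider output dtype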
@@ -136,11 +136,19 @@ def ref_impl(
     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
         if A.dtype == torch.torch.float8_e4m3fn:
-            tmp = native_w8a8_block_matmul(A[e, :, :],
-                                           B[e].transpose(0, 1),
-                                           A_scale,
-                                           B_scale,
-                                           [1,1])#block_shape)
+            if False:
+                tmp = native_w8a8_block_matmul(A[e, :, :],
+                                               B[e].transpose(0, 1),
+                                               A_scale,
+                                               B_scale,
+                                               [1,1])#block_shape)
+            else:
+                import vllm._custom_ops as ops
+                tmp = ops.cutlass_scaled_mm(A[e, :, :],
+                                            B[e].transpose(0, 1),
+                                            A_scale,
+                                            B_scale,
+                                            C.dtype)
             C[e, :num_tokens, :] = tmp[:num_tokens, :]
         else:
             C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
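
Below is a hedged, self-contained sketch of the new fp8 reference path, assuming vllm._custom_ops.cutlass_scaled_mm accepts (a, b, scale_a, scale_b, out_dtype) exactly as it is called in the hunk above, with b passed in (K, N) layout via the transpose and per-tensor float32 scales; the shapes are illustrative.

import torch
import vllm._custom_ops as ops

E, M, K, N = 2, 16, 128, 256
A = (torch.randn((E, M, K), device="cuda", dtype=torch.bfloat16) / 10).to(torch.float8_e4m3fn)
B = torch.randn((E, N, K), device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
A_scale = torch.ones(1, dtype=torch.float32, device="cuda")
B_scale = torch.ones(1, dtype=torch.float32, device="cuda")

C = torch.empty((E, M, N), device="cuda", dtype=torch.bfloat16)
for e in range(E):
    # per-expert scaled matmul; B[e].transpose(0, 1) provides the (K, N) operand
    C[e] = ops.cutlass_scaled_mm(A[e], B[e].transpose(0, 1),
                                 A_scale, B_scale, torch.bfloat16)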
@@ -159,14 +167,21 @@ def ref_impl(
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype):

-    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
+    if dtype == torch.torch.float8_e4m3fn:
+        in_dtype = dtype
+        out_dtype = torch.bfloat16
+    else:
+        in_dtype = dtype
+        out_dtype = dtype
+
+    config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N)
     tensors = BatchedMMTensors.make_tensors(config)

     test_output = tensors.C
     ref_output = test_output.clone()
+    ref_output2 = test_output.clone()

     compute_tl_dtype = {
-        torch.torch.float8_e4m3fn: tl.bfloat16,
         torch.float16: tl.float16,
         torch.bfloat16: tl.bfloat16,
         torch.float32: tl.float32
@@ -175,12 +190,14 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
     block_shape = [16, 16, 32] # 16 for k if not fp8

-    print(f"tensors.A {tensors.A.shape}")
-    print(f"tensors.B {tensors.B.shape}")
+    #print(f"tensors.A {tensors.A.shape}")
+    #print(f"tensors.B {tensors.B.shape}")

     if use_fp8_w8a8:
-        A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
-        B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
+        #A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
+        #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
+        A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
+        B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
     else:
         A_scale = None
         B_scale = None
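
The hunk above swaps the earlier 2-D scale tensors for scalar per-tensor scales. As a rough sketch of what per-tensor scaling means for the reference result (standard w8a8 dequantization semantics, shown in plain float math; with all-ones scales the scaled product reduces to an ordinary matmul):

import torch

A = torch.randn(16, 32)
B = torch.randn(8, 32)
a_scale = torch.ones(1, dtype=torch.float32)
b_scale = torch.ones(1, dtype=torch.float32)

# per-tensor dequantized product: each operand is scaled by its single scale factor
ref = (A * a_scale) @ (B * b_scale).t()
assert torch.allclose(ref, A @ B.t())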
@@ -205,19 +222,29 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         "BLOCK_SIZE_K": block_shape[2],
     })

-    ref_output = ref_impl(tensors.A,
-                          tensors.B,
+    ref_output = ref_output.to(dtype=out_dtype)
+    ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
+                          tensors.B.to(dtype=out_dtype),
                           ref_output,
                           tensors.num_expert_tokens,
                           A_scale,
                           B_scale,
                           block_shape[-2:])

+    ref_output2 = ref_impl(tensors.A,
+                           tensors.B,
+                           ref_output2,
+                           tensors.num_expert_tokens,
+                           A_scale,
+                           B_scale,
+                           block_shape[-2:])
+
     rtol, atol = {
-        torch.torch.float8_e4m3fn: (6e-2, 6e-2),
         torch.float16: (6e-2, 6e-2),
         torch.bfloat16: (6e-2, 6e-2),
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]

-    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
+    if not use_fp8_w8a8:
+        torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
