13
13
14
14
@dataclass
15
15
class BatchedMMConfig :
16
- dtype : torch .dtype
16
+ in_dtype : torch .dtype
17
+ out_dtype : torch .dtype
17
18
num_experts : int
18
19
max_tokens_per_expert : int
19
20
K : int
@@ -29,26 +30,25 @@ class BatchedMMTensors:
29
30
30
31
@staticmethod
31
32
def make_tensors (config : BatchedMMConfig ):
32
- if config .dtype == torch .torch .float8_e4m3fn :
33
- config_dtype = torch .bfloat16
33
+ if config .in_dtype == torch .torch .float8_e4m3fn :
34
+ config_in_dtype = torch .bfloat16
34
35
else :
35
- config_dtype = config .dtype
36
+ config_in_dtype = config .in_dtype
36
37
37
38
A = torch .randn (
38
39
(config .num_experts , config .max_tokens_per_expert , config .K ),
39
40
device = "cuda" ,
40
- dtype = config_dtype ) / 10
41
+ dtype = config_in_dtype ) / 10
41
42
B = torch .randn ((config .num_experts , config .N , config .K ),
42
43
device = "cuda" ,
43
- dtype = config_dtype )
44
+ dtype = config_in_dtype )
44
45
C = torch .zeros (
45
46
(config .num_experts , config .max_tokens_per_expert , config .N ),
46
47
device = "cuda" ,
47
- dtype = config_dtype )
48
+ dtype = config . out_dtype )
48
49
49
- A = A .to (config .dtype )
50
- B = B .to (config .dtype )
51
- C = C .to (config .dtype )
50
+ A = A .to (config .in_dtype )
51
+ B = B .to (config .in_dtype )
52
52
53
53
num_expert_tokens = torch .randint (low = 0 ,
54
54
high = config .max_tokens_per_expert ,
@@ -136,11 +136,19 @@ def ref_impl(
136
136
for e in range (num_experts ):
137
137
num_tokens = num_expert_tokens_cpu [e ]
138
138
if A .dtype == torch .torch .float8_e4m3fn :
139
- tmp = native_w8a8_block_matmul (A [e , :, :],
140
- B [e ].transpose (0 , 1 ),
141
- A_scale ,
142
- B_scale ,
143
- [1 ,1 ])#block_shape)
139
+ if False :
140
+ tmp = native_w8a8_block_matmul (A [e , :, :],
141
+ B [e ].transpose (0 , 1 ),
142
+ A_scale ,
143
+ B_scale ,
144
+ [1 ,1 ])#block_shape)
145
+ else :
146
+ import vllm ._custom_ops as ops
147
+ tmp = ops .cutlass_scaled_mm (A [e , :, :],
148
+ B [e ].transpose (0 , 1 ),
149
+ A_scale ,
150
+ B_scale ,
151
+ C .dtype )
144
152
C [e , :num_tokens , :] = tmp [:num_tokens , :]
145
153
else :
146
154
C [e , :num_tokens , :] = A [e , :num_tokens , :] @ B [e ].transpose (0 , 1 )
@@ -159,14 +167,21 @@ def ref_impl(
159
167
def test_batched_mm (num_experts : int , max_tokens_per_expert : int , K : int ,
160
168
N : int , dtype : torch .dtype ):
161
169
162
- config = BatchedMMConfig (dtype , num_experts , max_tokens_per_expert , K , N )
170
+ if dtype == torch .torch .float8_e4m3fn :
171
+ in_dtype = dtype
172
+ out_dtype = torch .bfloat16
173
+ else :
174
+ in_dtype = dtype
175
+ out_dtype = dtype
176
+
177
+ config = BatchedMMConfig (in_dtype , out_dtype , num_experts , max_tokens_per_expert , K , N )
163
178
tensors = BatchedMMTensors .make_tensors (config )
164
179
165
180
test_output = tensors .C
166
181
ref_output = test_output .clone ()
182
+ ref_output2 = test_output .clone ()
167
183
168
184
compute_tl_dtype = {
169
- torch .torch .float8_e4m3fn : tl .bfloat16 ,
170
185
torch .float16 : tl .float16 ,
171
186
torch .bfloat16 : tl .bfloat16 ,
172
187
torch .float32 : tl .float32
@@ -175,12 +190,14 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
175
190
use_fp8_w8a8 = dtype == torch .torch .float8_e4m3fn
176
191
block_shape = [16 , 16 , 32 ] # 16 for k if not fp8
177
192
178
- print (f"tensors.A { tensors .A .shape } " )
179
- print (f"tensors.B { tensors .B .shape } " )
193
+ # print(f"tensors.A {tensors.A.shape}")
194
+ # print(f"tensors.B {tensors.B.shape}")
180
195
181
196
if use_fp8_w8a8 :
182
- A_scale = torch .ones ((max_tokens_per_expert ,K ), dtype = torch .float32 , device = tensors .A .device )
183
- B_scale = torch .ones ((N , K ), dtype = torch .float32 , device = tensors .A .device )
197
+ #A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
198
+ #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
199
+ A_scale = torch .ones (1 , dtype = torch .float32 , device = tensors .A .device )
200
+ B_scale = torch .ones (1 , dtype = torch .float32 , device = tensors .B .device )
184
201
else :
185
202
A_scale = None
186
203
B_scale = None
@@ -205,19 +222,29 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
205
222
"BLOCK_SIZE_K" : block_shape [2 ],
206
223
})
207
224
208
- ref_output = ref_impl (tensors .A ,
209
- tensors .B ,
225
+ ref_output = ref_output .to (dtype = out_dtype )
226
+ ref_output = ref_impl (tensors .A .to (dtype = out_dtype ),
227
+ tensors .B .to (dtype = out_dtype ),
210
228
ref_output ,
211
229
tensors .num_expert_tokens ,
212
230
A_scale ,
213
231
B_scale ,
214
232
block_shape [- 2 :])
215
233
234
+ ref_output2 = ref_impl (tensors .A ,
235
+ tensors .B ,
236
+ ref_output2 ,
237
+ tensors .num_expert_tokens ,
238
+ A_scale ,
239
+ B_scale ,
240
+ block_shape [- 2 :])
241
+
216
242
rtol , atol = {
217
- torch .torch .float8_e4m3fn : (6e-2 , 6e-2 ),
218
243
torch .float16 : (6e-2 , 6e-2 ),
219
244
torch .bfloat16 : (6e-2 , 6e-2 ),
220
245
torch .float32 : (1e-2 , 1e-2 ),
221
246
}[test_output .dtype ]
222
247
223
- torch .testing .assert_close (test_output , ref_output , atol = atol , rtol = rtol )
248
+ torch .testing .assert_close (ref_output , ref_output2 , atol = atol , rtol = rtol )
249
+ if not use_fp8_w8a8 :
250
+ torch .testing .assert_close (test_output , ref_output2 , atol = atol , rtol = rtol )
0 commit comments