 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast

 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for rocm.
@@ -235,14 +236,23 @@ def per_block_cast_to_fp8(
     return x_scaled_sub, scales


+def _dequant(t: torch.Tensor, scale: torch.Tensor, block_shape, per_act_token_quant) -> torch.Tensor:
+    f32 = torch.float32
+    if per_act_token_quant or block_shape is None:
+        return t.to(f32) * scale
+    else:
+        return t.to(f32) * group_broadcast(scale, t.shape)
+
+
 def native_batched_masked_quant_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
     C: torch.Tensor,
     num_expert_tokens: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    block_shape: Optional[list[int]],
+    A_scale: Optional[torch.Tensor] = None,
+    B_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    per_act_token_quant: bool = False,
 ) -> torch.Tensor:
     num_expert_tokens_cpu = num_expert_tokens.clone()
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
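
For reviewers unfamiliar with `group_broadcast`: the sketch below shows the behavior `_dequant` relies on, under the assumption that the helper expands a block-wise scale tensor to the full tensor shape by repeating each scale across its quantization block. The name `group_broadcast_ref` and the `repeat_interleave` strategy are illustrative, not the actual implementation in `quant_utils`.

```python
# Illustrative sketch (not the actual vLLM implementation): expand a
# block-wise scale tensor so it matches the shape of the tensor being
# dequantized, repeating each per-block scale across its block.
import torch


def group_broadcast_ref(scale: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    assert scale.ndim == len(shape)
    for dim, (have, want) in enumerate(zip(scale.shape, shape)):
        if have != want:
            assert want % have == 0, "tensor dims must be multiples of block dims"
            # Repeat each scale entry contiguously, e.g. [s0, s1] -> [s0, s0, s1, s1]
            scale = scale.repeat_interleave(want // have, dim=dim)
    return scale


# e.g. one scale per [1, 64] block of a [4, 128] tensor:
scales = torch.rand(4, 2)
expanded = group_broadcast_ref(scales, torch.Size([4, 128]))
assert expanded.shape == (4, 128)
```
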
@@ -259,9 +269,9 @@ def native_batched_masked_quant_matmul(
         C[e, :num_tokens, :] = tmp[:num_tokens, :]
     elif A.dtype.itemsize == 1 and block_shape is None:
         assert A_scale is not None and B_scale is not None
-        C[e, :num_tokens, :] = (
-            (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(C.dtype)
-            @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(C.dtype))
+        A_dq = _dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
+        B_dq = _dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
+        C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
     else:
         assert A_scale is None
         assert B_scale is None
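
With the scale and block-shape arguments now defaulted, call sites on the unquantized path no longer need placeholder `None` arguments. A minimal usage sketch, assuming `native_batched_masked_quant_matmul` is imported from the test utilities; the shapes are inferred from the diff (`B[e]` is transposed before the matmul, so `B` is laid out as `[E, N, K]`) and the values are illustrative:

```python
import torch

E, T, K, N = 4, 16, 128, 256
A = torch.randn(E, T, K, dtype=torch.bfloat16)   # [experts, max_tokens, K]
B = torch.randn(E, N, K, dtype=torch.bfloat16)   # transposed per expert inside
C = torch.empty(E, T, N, dtype=torch.bfloat16)   # written in place
num_expert_tokens = torch.randint(1, T + 1, (E,))

# Unquantized path: the scale arguments can simply be omitted now.
out = native_batched_masked_quant_matmul(A, B, C, num_expert_tokens)
```
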