1 file changed: +2 −18 lines

vllm/model_executor/layers/quantization/utils

```diff
@@ -344,24 +344,8 @@ def apply(
         out_dtype = input.dtype
 
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-        if self.cutlass_fp8_supported:
-            assert input.dtype != current_platform.fp8_dtype(
-            ), "FP8 input to cutlass is not currently implemented"
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d,
-                input_scale,
-                scale_ub=input_scale_ub,
-                use_per_token_if_dynamic=use_per_token_if_dynamic)
-        else:
-            if input.dtype != current_platform.fp8_dtype():
-                # Maybe apply padding to output, see comment in __init__
-                qinput, x_scale = ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    num_token_padding=self.output_padding,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic)
-            else:
-                qinput, x_scale = input_2d, input_scale
+        input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
+        qinput, x_scale = input_2d, input_scale
 
         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)
```
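The net effect is that `apply()` no longer calls `ops.scaled_fp8_quant` at all: the activations are assumed to already be FP8, and a unit per-tensor scale is substituted for whatever `input_scale` the caller passed in. A minimal sketch of the new path is below; it assumes a recent PyTorch build with float8 dtypes, and the tensor shape and `float8_e4m3fn` dtype are illustrative rather than taken from the PR:

```python
import torch

# Pre-quantized FP8 activations (illustrative shape and dtype; float8
# tensors require a recent PyTorch).
input_2d = torch.randn(4, 8).to(torch.float8_e4m3fn)

# After this change: no re-quantization step; a unit per-tensor scale
# stands in for the passed-in input_scale.
input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
qinput, x_scale = input_2d, input_scale

# Downstream dispatch keys off the scale shapes: a single-element
# activation scale selects the per-tensor path.
per_tensor_activations = (x_scale.numel() == 1)
assert per_tensor_activations
```

Because `x_scale` is always a single-element tensor on this path, `per_tensor_activations` is always `True`, so the subsequent matmul dispatch treats the activations as per-tensor quantized with scale 1.0.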