Commit c887b19

Fp8 cleanup (#4)
Small cleanup to still keep the existing codepath if non-fp8 input is passed in.

Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com>
1 parent c1633da commit c887b19

File tree

1 file changed: +22 -2 lines changed

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 22 additions & 2 deletions
@@ -344,8 +344,28 @@ def apply(
         out_dtype = input.dtype
 
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-        input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
-        qinput, x_scale = input_2d, input_scale
+        if self.cutlass_fp8_supported and input.dtype != current_platform.fp8_dtype():
+            assert input.dtype != current_platform.fp8_dtype(
+            ), "FP8 input to cutlass is not currently implemented"
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                scale_ub=input_scale_ub,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+        else:
+            if input.dtype != current_platform.fp8_dtype():
+                # Maybe apply padding to output, see comment in __init__
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic)
+            else:
+                if x_scale is not None:
+                    qinput, x_scale = input_2d, input_scale
+                else:
+                    qinput = input_2d
+                    x_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
 
         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)
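
For context, the behavior the commit message describes amounts to: non-fp8 activations keep the existing quantize-on-the-fly path, while activations that already arrive in fp8 are passed straight through, reusing the caller-provided scale or falling back to a unit scale. Below is a minimal sketch of that dispatch, assuming torch.float8_e4m3fn as the platform fp8 dtype and using a hypothetical quantize_fn as a stand-in for ops.scaled_fp8_quant; it is not part of vLLM's code.

    from typing import Callable, Optional, Tuple

    import torch

    # Assumption: the platform fp8 dtype is float8_e4m3fn; vLLM resolves this
    # via current_platform.fp8_dtype().
    FP8_DTYPE = torch.float8_e4m3fn


    def prepare_activation(
        input_2d: torch.Tensor,
        input_scale: Optional[torch.Tensor],
        quantize_fn: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sketch of the fp8-vs-non-fp8 dispatch added by this commit.

        Hypothetical helper; quantize_fn stands in for ops.scaled_fp8_quant.
        """
        if input_2d.dtype != FP8_DTYPE:
            # Existing codepath: quantize non-fp8 input on the fly.
            return quantize_fn(input_2d, input_scale)
        if input_scale is not None:
            # Input is already fp8 and a scale was provided: pass both through.
            return input_2d, input_scale
        # Input is already fp8 but no scale was provided: fall back to a
        # unit scale, matching the torch.tensor([1.0], ...) fallback above.
        return input_2d, torch.tensor([1.0], dtype=torch.float32,
                                      device=input_2d.device)

The sketch only shows the dtype branching; in the actual diff the cutlass path additionally threads through scale_ub and use_per_token_if_dynamic, and the non-cutlass path adds num_token_padding=self.output_padding.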
