Commit 2a05d6e

support already quantized input into fp8 kernel
Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com>
1 parent 2c1c8e3 commit 2a05d6e

File tree

1 file changed (+2, -18 lines)

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 2 additions & 18 deletions
@@ -344,24 +344,8 @@ def apply(
         out_dtype = input.dtype
 
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-        if self.cutlass_fp8_supported:
-            assert input.dtype != current_platform.fp8_dtype(
-            ), "FP8 input to cutlass is not currently implemented"
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d,
-                input_scale,
-                scale_ub=input_scale_ub,
-                use_per_token_if_dynamic=use_per_token_if_dynamic)
-        else:
-            if input.dtype != current_platform.fp8_dtype():
-                # Maybe apply padding to output, see comment in __init__
-                qinput, x_scale = ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    num_token_padding=self.output_padding,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic)
-            else:
-                qinput, x_scale = input_2d, input_scale
+        input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
+        qinput, x_scale = input_2d, input_scale
 
         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)
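With this change, apply() no longer quantizes the activation itself: the caller is expected to pass an input_2d that is already in FP8, and the op forwards it with a scale of 1.0 instead of calling ops.scaled_fp8_quant. Below is a minimal sketch of that contract from the caller's side in plain PyTorch; the quantize_per_tensor_fp8 helper, the shapes, and the dtypes are illustrative assumptions, not code from this commit or from vLLM.

import torch

def quantize_per_tensor_fp8(x: torch.Tensor):
    # Illustrative helper (assumption): per-tensor dynamic quantization to
    # float8_e4m3fn, roughly the step the removed code used to do inline.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().amax().float().clamp(min=1e-12) / finfo.max
    q = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

# The caller quantizes the activation once, ahead of the FP8 linear op ...
activation = torch.randn(16, 128, dtype=torch.bfloat16)
input_2d, _ = quantize_per_tensor_fp8(activation)

# ... and the rewritten path simply forwards it, mirroring the two added
# lines: no re-quantization, and a unit input scale.
input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
qinput, x_scale = input_2d, input_scale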
