@@ -344,8 +344,28 @@ def apply(
         out_dtype = input.dtype

         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-        input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)
-        qinput, x_scale = input_2d, input_scale
+        if self.cutlass_fp8_supported and input.dtype != current_platform.fp8_dtype():
+            assert input.dtype != current_platform.fp8_dtype(
+            ), "FP8 input to cutlass is not currently implemented"
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                scale_ub=input_scale_ub,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+        else:
+            if input.dtype != current_platform.fp8_dtype():
+                # Maybe apply padding to output, see comment in __init__
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic)
+            else:
+                if input_scale is not None:
+                    qinput, x_scale = input_2d, input_scale
+                else:
+                    qinput = input_2d
+                    x_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device)

         per_tensor_weights = (weight_scale.numel() == 1)
         per_tensor_activations = (x_scale.numel() == 1)
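For context on what the quantizing branches dispatch to: `ops.scaled_fp8_quant` with `use_per_token_if_dynamic=True` computes one scale per token (row) and quantizes the activations to FP8 before the matmul. Below is a minimal pure-PyTorch sketch of that per-token dynamic quantization, not the fused CUDA kernel vLLM actually calls; `per_token_fp8_quant` is a hypothetical helper, and `torch.float8_e4m3fn` stands in for `current_platform.fp8_dtype()`.

    import torch
    from typing import Optional

    FP8_DTYPE = torch.float8_e4m3fn
    FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

    def per_token_fp8_quant(x: torch.Tensor,
                            scale_ub: Optional[torch.Tensor] = None):
        # One scale per token: max |x| over the hidden dimension, in float32.
        amax = x.abs().amax(dim=-1, keepdim=True).float()
        if scale_ub is not None:
            # Cap the dynamic range, mirroring the scale_ub argument above.
            amax = amax.clamp(max=scale_ub)
        scale = (amax / FP8_MAX).clamp(min=1e-12)  # guard all-zero rows
        q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
        # Dequantization is q.float() * scale, so x is approximated row-wise.
        return q, scale

    x = torch.randn(4, 16)  # [num_tokens, hidden_size]
    qinput, x_scale = per_token_fp8_quant(x)
    assert qinput.dtype == FP8_DTYPE and x_scale.shape == (4, 1)

The per-tensor path is the same computation with a single global amax, which is why the code above only needs `x_scale.numel() == 1` to tell the two activation-scale layouts apart.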