
Commit 8e36b11

[float8] Re-enable slow-accum in the bwd of axis-wise scaling schemes
And circumvent the issue with the slow CUTLASS kernel by using the cuBLAS kernel + manual scaling.

ghstack-source-id: 3bcf2de
Pull Request resolved: #1325
1 parent 06ad55a commit 8e36b11
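
The workaround relies on a simple identity: with a per-row scale on the first operand and a per-column scale on the second, the scaling factors do not depend on the contraction dimension, so they can be pulled out of the matmul and applied once to the result. A minimal sketch of that identity in plain float32 (my own illustration; no fp8 hardware and no `torch._scaled_mm` call involved):

import torch

# Axis-wise (rowwise) scales, as in the backward gemms this commit targets.
M, K, N = 4, 8, 5
a_data = torch.randn(M, K)
b_data = torch.randn(K, N)
a_scale = torch.rand(M, 1) + 0.5   # per-row scale for a
b_scale = torch.rand(1, N) + 0.5   # per-column scale for b
a_inv, b_inv = a_scale.reciprocal(), b_scale.reciprocal()

# "Scaled" matmul: dequantize each operand, then multiply.
scaled = (a_data * a_inv) @ (b_data * b_inv)

# Workaround: plain matmul, then one elementwise post-scale (the outer product
# of the two inverse scales), which a compiler such as Inductor can fuse.
post_scaled = (a_data @ b_data) * (a_inv * b_inv)

torch.testing.assert_close(scaled, post_scaled)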

2 files changed: +24 -35 lines

torchao/float8/config.py

Lines changed: 0 additions & 23 deletions
@@ -170,7 +170,6 @@ class Float8LinearConfig:
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
     # `grad_weight`
-    # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling
     #
     gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
     gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
@@ -320,21 +319,10 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )

     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
@@ -362,24 +350,13 @@ def recipe_name_to_linear_config(
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
         cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
             cast_config_input_for_grad_weight=cc_i_gw,
             cast_config_weight_for_grad_input=cc_w_gi,
             cast_config_grad_output_for_grad_weight=cc_go_gw,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )

     else:

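With the hard-coded `use_fast_accum=True` gemm configs removed, the axis-wise recipes now fall back to the defaults declared on `Float8LinearConfig` (first hunk above): fast accumulation stays on for the output gemm, while the two backward gemms use slow, more precise accumulation. A hedged check of that effect, assuming `Float8GemmConfig()` defaults `use_fast_accum` to `False` and that these symbols are importable from the module shown above:

from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# Recipe name taken from the hunk above; the assertions reflect the assumed defaults.
config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
assert config.gemm_config_output.use_fast_accum is True
assert config.gemm_config_grad_input.use_fast_accum is False   # slow-accum re-enabled
assert config.gemm_config_grad_weight.use_fast_accum is False  # slow-accum re-enabled
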
torchao/float8/float8_python_api.py

Lines changed: 24 additions & 12 deletions
@@ -37,19 +37,25 @@ def addmm_float8_unwrapped(
     a_inverse_scale = a_scale.reciprocal()
     b_inverse_scale = b_scale.reciprocal()

-    if output_dtype == torch.float32 and bias is not None:
+    post_inverse_scale = None
+    if (
+        a_scale.shape == (a_data.shape[0], 1)
+        and b_scale.shape == (1, b_data.shape[1])
+        and not use_fast_accum
+    ):
+        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
+        # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
+        # manually afterwards (hoping Inductor will be able to fuse it).
+        post_inverse_scale = a_inverse_scale * b_inverse_scale
+        a_inverse_scale = a_inverse_scale.new_ones(())
+        b_inverse_scale = a_inverse_scale.new_ones(())
+
+    post_bias = None
+    if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
-        output = torch._scaled_mm(
-            a_data,
-            b_data,
-            scale_a=a_inverse_scale,
-            scale_b=b_inverse_scale,
-            scale_result=output_scale,
-            out_dtype=output_dtype,
-            use_fast_accum=use_fast_accum,
-        )
-        output += bias
-        return output
+        post_bias = bias
+        bias = None

     output = torch._scaled_mm(
         a_data,
         b_data,
@@ -60,4 +66,10 @@ def addmm_float8_unwrapped(
         out_dtype=output_dtype,
         use_fast_accum=use_fast_accum,
     )
+
+    if post_inverse_scale is not None:
+        output *= post_inverse_scale
+    if post_bias is not None:
+        output += post_bias
+
     return output

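Putting the two changes together: a plain-float32 mock of the control flow that `addmm_float8_unwrapped` now follows (my own sketch, not torchao code; `torch._scaled_mm` is replaced by an ordinary matmul on the descaled operands), checked against a direct dequantized reference:

import torch

def mock_addmm_unwrapped(a_data, a_scale, b_data, b_scale, bias=None, use_fast_accum=False):
    a_inv, b_inv = a_scale.reciprocal(), b_scale.reciprocal()
    post_inverse_scale = None
    if (
        a_scale.shape == (a_data.shape[0], 1)
        and b_scale.shape == (1, b_data.shape[1])
        and not use_fast_accum
    ):
        # Axis-wise scales without fast-accum: defer the scaling so the matmul
        # itself runs with unit scales (the cuBLAS path in the real code).
        post_inverse_scale = a_inv * b_inv
        a_inv, b_inv = a_inv.new_ones(()), b_inv.new_ones(())
    # Stand-in for torch._scaled_mm: matmul of the descaled operands.
    output = (a_data * a_inv) @ (b_data * b_inv)
    if post_inverse_scale is not None:
        output *= post_inverse_scale
    if bias is not None:
        output += bias
    return output

M, K, N = 4, 8, 5
a, b = torch.randn(M, K), torch.randn(K, N)
sa, sb = torch.rand(M, 1) + 0.5, torch.rand(1, N) + 0.5
bias = torch.randn(N)
reference = (a * sa.reciprocal()) @ (b * sb.reciprocal()) + bias
torch.testing.assert_close(mock_addmm_unwrapped(a, sa, b, sb, bias=bias), reference)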