@@ -170,7 +170,6 @@ class Float8LinearConfig:
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
     # `grad_weight`
-    # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling
     #
     gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
     gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
@@ -320,21 +319,10 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
 
-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )
 
     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
@@ -362,24 +350,13 @@ def recipe_name_to_linear_config(
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
         cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
 
-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
             cast_config_input_for_grad_weight=cc_i_gw,
             cast_config_weight_for_grad_input=cc_w_gi,
             cast_config_grad_output_for_grad_weight=cc_go_gw,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )
 
     else:
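
With the hardcoded gc_o/gc_gi/gc_gw assignments removed, the axiswise recipes fall back to Float8LinearConfig's own field defaults from the first hunk: fast accumulation stays enabled for the `output` gemm, while the gradient gemms use the argument-free Float8GemmConfig() default. A minimal sketch of how a caller could still opt into fast accumulation for all three gemms explicitly; the import path is an assumption (it is not shown in this diff), only the identifiers come from the hunks above:

# Sketch only: request fast accumulation per gemm explicitly, instead of
# relying on the recipe helper to hardcode it. Import path is assumed.
from torchao.float8.config import (
    CastConfig,
    Float8GemmConfig,
    Float8LinearConfig,
    ScalingGranularity,
)

config = Float8LinearConfig(
    # Axiswise (rowwise) scaling for input, weight, and grad_output,
    # mirroring the cc_i / cc_w / cc_go casts built by the recipe.
    cast_config_input=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
    cast_config_weight=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
    cast_config_grad_output=CastConfig(scaling_granularity=ScalingGranularity.AXISWISE),
    # Per-gemm fast accumulation, no longer set by the recipe itself.
    gemm_config_output=Float8GemmConfig(use_fast_accum=True),
    gemm_config_grad_input=Float8GemmConfig(use_fast_accum=True),
    gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=True),
)

Note that the deleted comments state the rowwise CUTLASS kernels in `torch._scaled_mm` are only fast with `use_fast_accum=True`, and the deleted TODO asked for a warning when it is False under axiswise scaling, so callers relying on the defaults may want to benchmark the gradient gemms.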