diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 939f68e59a..ddb6a2b5d7 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -183,6 +183,9 @@ class Float8LinearConfig: # implements pre/post-all-gather methods to do float8 all-gather with FSDP2. enable_fsdp_float8_all_gather: bool = False + # If True, then precompute the scale of the weights in the fp8 linear module + pre_compute_fp8_all_gather_weights_scale: bool = False + # If True, then prior to performing the fp8 scaled mamtmul we will pad the # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16.