diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 8d4d58fd6b..c8812e29ec 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -4,6 +4,7 @@ Float8GemmConfig, Float8LinearConfig, ScalingType, + Float8LinearRecipeName, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -38,6 +39,7 @@ "Float8GemmConfig", "Float8LinearConfig", "CastConfig", + "Float8LinearRecipeName", # top level UX "convert_to_float8_training", "precompute_float8_dynamic_scale_for_fsdp",